xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision 66fd12cf4896eb08ad8e7a2627537f84ead84dd3)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "llvm/Analysis/AssumptionCache.h"
24 #include "llvm/Analysis/DependenceAnalysis.h"
25 #include "llvm/Analysis/DomTreeUpdater.h"
26 #include "llvm/Analysis/LoopInfo.h"
27 #include "llvm/Analysis/LoopIterator.h"
28 #include "llvm/Analysis/MustExecute.h"
29 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
30 #include "llvm/Analysis/ScalarEvolution.h"
31 #include "llvm/IR/BasicBlock.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/DebugLoc.h"
34 #include "llvm/IR/DiagnosticInfo.h"
35 #include "llvm/IR/Dominators.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/Instruction.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/IntrinsicInst.h"
40 #include "llvm/IR/User.h"
41 #include "llvm/IR/Value.h"
42 #include "llvm/IR/ValueHandle.h"
43 #include "llvm/IR/ValueMap.h"
44 #include "llvm/Support/Casting.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/ErrorHandling.h"
47 #include "llvm/Support/GenericDomTree.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
50 #include "llvm/Transforms/Utils/Cloning.h"
51 #include "llvm/Transforms/Utils/LoopUtils.h"
52 #include "llvm/Transforms/Utils/UnrollLoop.h"
53 #include "llvm/Transforms/Utils/ValueMapper.h"
54 #include <assert.h>
55 #include <memory>
56 #include <type_traits>
57 #include <vector>
58 
59 using namespace llvm;
60 
61 #define DEBUG_TYPE "loop-unroll-and-jam"
62 
63 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
64 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
65 
66 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
67 
68 // Partition blocks in an outer/inner loop pair into blocks before and after
69 // the loop
70 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
71                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
72   Loop *SubLoop = L.getSubLoops()[0];
73   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
74 
75   for (BasicBlock *BB : L.blocks()) {
76     if (!SubLoop->contains(BB)) {
77       if (DT.dominates(SubLoopLatch, BB))
78         AftBlocks.insert(BB);
79       else
80         ForeBlocks.insert(BB);
81     }
82   }
83 
84   // Check that all blocks in ForeBlocks together dominate the subloop
85   // TODO: This might ideally be done better with a dominator/postdominators.
86   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
87   for (BasicBlock *BB : ForeBlocks) {
88     if (BB == SubLoopPreHeader)
89       continue;
90     Instruction *TI = BB->getTerminator();
91     for (BasicBlock *Succ : successors(TI))
92       if (!ForeBlocks.count(Succ))
93         return false;
94   }
95 
96   return true;
97 }
98 
99 /// Partition blocks in a loop nest into blocks before and after each inner
100 /// loop.
101 static bool partitionOuterLoopBlocks(
102     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
103     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
104     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
105   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
106 
107   for (Loop *L : Root.getLoopsInPreorder()) {
108     if (L == &JamLoop)
109       break;
110 
111     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
112       return false;
113   }
114 
115   return true;
116 }
117 
118 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
119 // than 2 levels loop.
120 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
121                                      BasicBlockSet &ForeBlocks,
122                                      BasicBlockSet &SubLoopBlocks,
123                                      BasicBlockSet &AftBlocks,
124                                      DominatorTree *DT) {
125   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
126   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
127 }
128 
129 // Looks at the phi nodes in Header for values coming from Latch. For these
130 // instructions and all their operands calls Visit on them, keeping going for
131 // all the operands in AftBlocks. Returns false if Visit returns false,
132 // otherwise returns true. This is used to process the instructions in the
133 // Aft blocks that need to be moved before the subloop. It is used in two
134 // places. One to check that the required set of instructions can be moved
135 // before the loop. Then to collect the instructions to actually move in
136 // moveHeaderPhiOperandsToForeBlocks.
137 template <typename T>
138 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
139                                      BasicBlockSet &AftBlocks, T Visit) {
140   SmallPtrSet<Instruction *, 8> VisitedInstr;
141 
142   std::function<bool(Instruction * I)> ProcessInstr = [&](Instruction *I) {
143     if (VisitedInstr.count(I))
144       return true;
145 
146     VisitedInstr.insert(I);
147 
148     if (AftBlocks.count(I->getParent()))
149       for (auto &U : I->operands())
150         if (Instruction *II = dyn_cast<Instruction>(U))
151           if (!ProcessInstr(II))
152             return false;
153 
154     return Visit(I);
155   };
156 
157   for (auto &Phi : Header->phis()) {
158     Value *V = Phi.getIncomingValueForBlock(Latch);
159     if (Instruction *I = dyn_cast<Instruction>(V))
160       if (!ProcessInstr(I))
161         return false;
162   }
163 
164   return true;
165 }
166 
167 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
168 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
169                                               BasicBlock *Latch,
170                                               Instruction *InsertLoc,
171                                               BasicBlockSet &AftBlocks) {
172   // We need to ensure we move the instructions in the correct order,
173   // starting with the earliest required instruction and moving forward.
174   processHeaderPhiOperands(Header, Latch, AftBlocks,
175                            [&AftBlocks, &InsertLoc](Instruction *I) {
176                              if (AftBlocks.count(I->getParent()))
177                                I->moveBefore(InsertLoc);
178                              return true;
179                            });
180 }
181 
182 /*
183   This method performs Unroll and Jam. For a simple loop like:
184   for (i = ..)
185     Fore(i)
186     for (j = ..)
187       SubLoop(i, j)
188     Aft(i)
189 
190   Instead of doing normal inner or outer unrolling, we do:
191   for (i = .., i+=2)
192     Fore(i)
193     Fore(i+1)
194     for (j = ..)
195       SubLoop(i, j)
196       SubLoop(i+1, j)
197     Aft(i)
198     Aft(i+1)
199 
200   So the outer loop is essetially unrolled and then the inner loops are fused
201   ("jammed") together into a single loop. This can increase speed when there
202   are loads in SubLoop that are invariant to i, as they become shared between
203   the now jammed inner loops.
204 
205   We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
206   Fore blocks are those before the inner loop, Aft are those after. Normal
207   Unroll code is used to copy each of these sets of blocks and the results are
208   combined together into the final form above.
209 
210   isSafeToUnrollAndJam should be used prior to calling this to make sure the
211   unrolling will be valid. Checking profitablility is also advisable.
212 
213   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
214   necessary to create one and not fully unrolled).
215 */
216 LoopUnrollResult
217 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
218                        unsigned TripMultiple, bool UnrollRemainder,
219                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
220                        AssumptionCache *AC, const TargetTransformInfo *TTI,
221                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
222 
223   // When we enter here we should have already checked that it is safe
224   BasicBlock *Header = L->getHeader();
225   assert(Header && "No header.");
226   assert(L->getSubLoops().size() == 1);
227   Loop *SubLoop = *L->begin();
228 
229   // Don't enter the unroll code if there is nothing to do.
230   if (TripCount == 0 && Count < 2) {
231     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
232     return LoopUnrollResult::Unmodified;
233   }
234 
235   assert(Count > 0);
236   assert(TripMultiple > 0);
237   assert(TripCount == 0 || TripCount % TripMultiple == 0);
238 
239   // Are we eliminating the loop control altogether?
240   bool CompletelyUnroll = (Count == TripCount);
241 
242   // We use the runtime remainder in cases where we don't know trip multiple
243   if (TripMultiple % Count != 0) {
244     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
245                                     /*UseEpilogRemainder*/ true,
246                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
247                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
248       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
249                            "generated when assuming runtime trip count\n");
250       return LoopUnrollResult::Unmodified;
251     }
252   }
253 
254   // Notify ScalarEvolution that the loop will be substantially changed,
255   // if not outright eliminated.
256   if (SE) {
257     SE->forgetLoop(L);
258     SE->forgetBlockAndLoopDispositions();
259   }
260 
261   using namespace ore;
262   // Report the unrolling decision.
263   if (CompletelyUnroll) {
264     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
265                       << Header->getName() << " with trip count " << TripCount
266                       << "!\n");
267     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
268                                  L->getHeader())
269               << "completely unroll and jammed loop with "
270               << NV("UnrollCount", TripCount) << " iterations");
271   } else {
272     auto DiagBuilder = [&]() {
273       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
274                               L->getHeader());
275       return Diag << "unroll and jammed loop by a factor of "
276                   << NV("UnrollCount", Count);
277     };
278 
279     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
280                       << " by " << Count);
281     if (TripMultiple != 1) {
282       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
283       ORE->emit([&]() {
284         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
285                              << " trips per branch";
286       });
287     } else {
288       LLVM_DEBUG(dbgs() << " with run-time trip count");
289       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
290     }
291     LLVM_DEBUG(dbgs() << "!\n");
292   }
293 
294   BasicBlock *Preheader = L->getLoopPreheader();
295   BasicBlock *LatchBlock = L->getLoopLatch();
296   assert(Preheader && "No preheader");
297   assert(LatchBlock && "No latch block");
298   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
299   assert(BI && !BI->isUnconditional());
300   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
301   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
302   bool SubLoopContinueOnTrue = SubLoop->contains(
303       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
304 
305   // Partition blocks in an outer/inner loop pair into blocks before and after
306   // the loop
307   BasicBlockSet SubLoopBlocks;
308   BasicBlockSet ForeBlocks;
309   BasicBlockSet AftBlocks;
310   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
311                            DT);
312 
313   // We keep track of the entering/first and exiting/last block of each of
314   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
315   // blocks easier.
316   std::vector<BasicBlock *> ForeBlocksFirst;
317   std::vector<BasicBlock *> ForeBlocksLast;
318   std::vector<BasicBlock *> SubLoopBlocksFirst;
319   std::vector<BasicBlock *> SubLoopBlocksLast;
320   std::vector<BasicBlock *> AftBlocksFirst;
321   std::vector<BasicBlock *> AftBlocksLast;
322   ForeBlocksFirst.push_back(Header);
323   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
324   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
325   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
326   AftBlocksFirst.push_back(SubLoop->getExitBlock());
327   AftBlocksLast.push_back(L->getExitingBlock());
328   // Maps Blocks[0] -> Blocks[It]
329   ValueToValueMapTy LastValueMap;
330 
331   // Move any instructions from fore phi operands from AftBlocks into Fore.
332   moveHeaderPhiOperandsToForeBlocks(
333       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
334 
335   // The current on-the-fly SSA update requires blocks to be processed in
336   // reverse postorder so that LastValueMap contains the correct value at each
337   // exit.
338   LoopBlocksDFS DFS(L);
339   DFS.perform(LI);
340   // Stash the DFS iterators before adding blocks to the loop.
341   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
342   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
343 
344   // When a FSDiscriminator is enabled, we don't need to add the multiply
345   // factors to the discriminators.
346   if (Header->getParent()->shouldEmitDebugInfoForProfiling() &&
347       !EnableFSDiscriminator)
348     for (BasicBlock *BB : L->getBlocks())
349       for (Instruction &I : *BB)
350         if (!isa<DbgInfoIntrinsic>(&I))
351           if (const DILocation *DIL = I.getDebugLoc()) {
352             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
353             if (NewDIL)
354               I.setDebugLoc(*NewDIL);
355             else
356               LLVM_DEBUG(dbgs()
357                          << "Failed to create new discriminator: "
358                          << DIL->getFilename() << " Line: " << DIL->getLine());
359           }
360 
361   // Copy all blocks
362   for (unsigned It = 1; It != Count; ++It) {
363     SmallVector<BasicBlock *, 8> NewBlocks;
364     // Maps Blocks[It] -> Blocks[It-1]
365     DenseMap<Value *, Value *> PrevItValueMap;
366     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
367     NewLoops[L] = L;
368     NewLoops[SubLoop] = SubLoop;
369 
370     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
371       ValueToValueMapTy VMap;
372       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
373       Header->getParent()->insert(Header->getParent()->end(), New);
374 
375       // Tell LI about New.
376       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
377 
378       if (ForeBlocks.count(*BB)) {
379         if (*BB == ForeBlocksFirst[0])
380           ForeBlocksFirst.push_back(New);
381         if (*BB == ForeBlocksLast[0])
382           ForeBlocksLast.push_back(New);
383       } else if (SubLoopBlocks.count(*BB)) {
384         if (*BB == SubLoopBlocksFirst[0])
385           SubLoopBlocksFirst.push_back(New);
386         if (*BB == SubLoopBlocksLast[0])
387           SubLoopBlocksLast.push_back(New);
388       } else if (AftBlocks.count(*BB)) {
389         if (*BB == AftBlocksFirst[0])
390           AftBlocksFirst.push_back(New);
391         if (*BB == AftBlocksLast[0])
392           AftBlocksLast.push_back(New);
393       } else {
394         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
395       }
396 
397       // Update our running maps of newest clones
398       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
399       LastValueMap[*BB] = New;
400       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
401            VI != VE; ++VI) {
402         PrevItValueMap[VI->second] =
403             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
404         LastValueMap[VI->first] = VI->second;
405       }
406 
407       NewBlocks.push_back(New);
408 
409       // Update DomTree:
410       if (*BB == ForeBlocksFirst[0])
411         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
412       else if (*BB == SubLoopBlocksFirst[0])
413         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
414       else if (*BB == AftBlocksFirst[0])
415         DT->addNewBlock(New, AftBlocksLast[It - 1]);
416       else {
417         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
418         // structure.
419         auto BBDomNode = DT->getNode(*BB);
420         auto BBIDom = BBDomNode->getIDom();
421         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
422         assert(OriginalBBIDom);
423         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
424         DT->addNewBlock(
425             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
426       }
427     }
428 
429     // Remap all instructions in the most recent iteration
430     remapInstructionsInBlocks(NewBlocks, LastValueMap);
431     for (BasicBlock *NewBlock : NewBlocks) {
432       for (Instruction &I : *NewBlock) {
433         if (auto *II = dyn_cast<AssumeInst>(&I))
434           AC->registerAssumption(II);
435       }
436     }
437 
438     // Alter the ForeBlocks phi's, pointing them at the latest version of the
439     // value from the previous iteration's phis
440     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
441       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
442       assert(OldValue && "should have incoming edge from Aft[It]");
443       Value *NewValue = OldValue;
444       if (Value *PrevValue = PrevItValueMap[OldValue])
445         NewValue = PrevValue;
446 
447       assert(Phi.getNumOperands() == 2);
448       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
449       Phi.setIncomingValue(0, NewValue);
450       Phi.removeIncomingValue(1);
451     }
452   }
453 
454   // Now that all the basic blocks for the unrolled iterations are in place,
455   // finish up connecting the blocks and phi nodes. At this point LastValueMap
456   // is the last unrolled iterations values.
457 
458   // Update Phis in BB from OldBB to point to NewBB and use the latest value
459   // from LastValueMap
460   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
461                                      BasicBlock *NewBB,
462                                      ValueToValueMapTy &LastValueMap) {
463     for (PHINode &Phi : BB->phis()) {
464       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
465         if (Phi.getIncomingBlock(b) == OldBB) {
466           Value *OldValue = Phi.getIncomingValue(b);
467           if (Value *LastValue = LastValueMap[OldValue])
468             Phi.setIncomingValue(b, LastValue);
469           Phi.setIncomingBlock(b, NewBB);
470           break;
471         }
472       }
473     }
474   };
475   // Move all the phis from Src into Dest
476   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
477     Instruction *insertPoint = Dest->getFirstNonPHI();
478     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
479       Phi->moveBefore(insertPoint);
480   };
481 
482   // Update the PHI values outside the loop to point to the last block
483   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
484                            LastValueMap);
485 
486   // Update ForeBlocks successors and phi nodes
487   BranchInst *ForeTerm =
488       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
489   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
490   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
491 
492   if (CompletelyUnroll) {
493     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
494       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
495       Phi->eraseFromParent();
496     }
497   } else {
498     // Update the PHI values to point to the last aft block
499     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
500                              AftBlocksLast.back(), LastValueMap);
501   }
502 
503   for (unsigned It = 1; It != Count; It++) {
504     // Remap ForeBlock successors from previous iteration to this
505     BranchInst *ForeTerm =
506         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
507     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
508     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
509   }
510 
511   // Subloop successors and phis
512   BranchInst *SubTerm =
513       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
514   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
515   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
516   SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
517                                             ForeBlocksLast.back());
518   SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
519                                             SubLoopBlocksLast.back());
520 
521   for (unsigned It = 1; It != Count; It++) {
522     // Replace the conditional branch of the previous iteration subloop with an
523     // unconditional one to this one
524     BranchInst *SubTerm =
525         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
526     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
527     SubTerm->eraseFromParent();
528 
529     SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
530                                                ForeBlocksLast.back());
531     SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
532                                                SubLoopBlocksLast.back());
533     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
534   }
535 
536   // Aft blocks successors and phis
537   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
538   if (CompletelyUnroll) {
539     BranchInst::Create(LoopExit, AftTerm);
540     AftTerm->eraseFromParent();
541   } else {
542     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
543     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
544            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
545   }
546   AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
547                                         SubLoopBlocksLast.back());
548 
549   for (unsigned It = 1; It != Count; It++) {
550     // Replace the conditional branch of the previous iteration subloop with an
551     // unconditional one to this one
552     BranchInst *AftTerm =
553         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
554     BranchInst::Create(AftBlocksFirst[It], AftTerm);
555     AftTerm->eraseFromParent();
556 
557     AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
558                                            SubLoopBlocksLast.back());
559     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
560   }
561 
562   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
563   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
564   // new ones required.
565   if (Count != 1) {
566     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
567     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
568                            SubLoopBlocksFirst[0]);
569     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
570                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
571 
572     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
573                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
574     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
575                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
576     DTU.applyUpdatesPermissive(DTUpdates);
577   }
578 
579   // Merge adjacent basic blocks, if possible.
580   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
581   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
582   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
583   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
584 
585   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
586 
587   // Apply updates to the DomTree.
588   DT = &DTU.getDomTree();
589 
590   // At this point, the code is well formed.  We now do a quick sweep over the
591   // inserted code, doing constant propagation and dead code elimination as we
592   // go.
593   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
594   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
595                           TTI);
596 
597   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
598   ++NumUnrolledAndJammed;
599 
600   // Update LoopInfo if the loop is completely removed.
601   if (CompletelyUnroll)
602     LI->erase(L);
603 
604 #ifndef NDEBUG
605   // We shouldn't have done anything to break loop simplify form or LCSSA.
606   Loop *OutestLoop = SubLoop->getParentLoop()
607                          ? SubLoop->getParentLoop()->getParentLoop()
608                                ? SubLoop->getParentLoop()->getParentLoop()
609                                : SubLoop->getParentLoop()
610                          : SubLoop;
611   assert(DT->verify());
612   LI->verify(*DT);
613   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
614   if (!CompletelyUnroll)
615     assert(L->isLoopSimplifyForm());
616   assert(SubLoop->isLoopSimplifyForm());
617   SE->verify();
618 #endif
619 
620   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
621                           : LoopUnrollResult::PartiallyUnrolled;
622 }
623 
624 static bool getLoadsAndStores(BasicBlockSet &Blocks,
625                               SmallVector<Instruction *, 4> &MemInstr) {
626   // Scan the BBs and collect legal loads and stores.
627   // Returns false if non-simple loads/stores are found.
628   for (BasicBlock *BB : Blocks) {
629     for (Instruction &I : *BB) {
630       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
631         if (!Ld->isSimple())
632           return false;
633         MemInstr.push_back(&I);
634       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
635         if (!St->isSimple())
636           return false;
637         MemInstr.push_back(&I);
638       } else if (I.mayReadOrWriteMemory()) {
639         return false;
640       }
641     }
642   }
643   return true;
644 }
645 
646 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
647                                        unsigned UnrollLevel, unsigned JamLevel,
648                                        bool Sequentialized, Dependence *D) {
649   // UnrollLevel might carry the dependency Src --> Dst
650   // Does a different loop after unrolling?
651   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
652        ++CurLoopDepth) {
653     auto JammedDir = D->getDirection(CurLoopDepth);
654     if (JammedDir == Dependence::DVEntry::LT)
655       return true;
656 
657     if (JammedDir & Dependence::DVEntry::GT)
658       return false;
659   }
660 
661   return true;
662 }
663 
664 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
665                                         unsigned UnrollLevel, unsigned JamLevel,
666                                         bool Sequentialized, Dependence *D) {
667   // UnrollLevel might carry the dependency Dst --> Src
668   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
669        ++CurLoopDepth) {
670     auto JammedDir = D->getDirection(CurLoopDepth);
671     if (JammedDir == Dependence::DVEntry::GT)
672       return true;
673 
674     if (JammedDir & Dependence::DVEntry::LT)
675       return false;
676   }
677 
678   // Backward dependencies are only preserved if not interleaved.
679   return Sequentialized;
680 }
681 
682 // Check whether it is semantically safe Src and Dst considering any potential
683 // dependency between them.
684 //
685 // @param UnrollLevel The level of the loop being unrolled
686 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
687 // different levels, the outermost common loop counts as jammed level
688 //
689 // @return true if is safe and false if there is a dependency violation.
690 static bool checkDependency(Instruction *Src, Instruction *Dst,
691                             unsigned UnrollLevel, unsigned JamLevel,
692                             bool Sequentialized, DependenceInfo &DI) {
693   assert(UnrollLevel <= JamLevel &&
694          "Expecting JamLevel to be at least UnrollLevel");
695 
696   if (Src == Dst)
697     return true;
698   // Ignore Input dependencies.
699   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
700     return true;
701 
702   // Check whether unroll-and-jam may violate a dependency.
703   // By construction, every dependency will be lexicographically non-negative
704   // (if it was, it would violate the current execution order), such as
705   //   (0,0,>,*,*)
706   // Unroll-and-jam changes the GT execution of two executions to the same
707   // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
708   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
709   //   (0,0,>=,*,*)
710   // Now, the dependency is not necessarily non-negative anymore, i.e.
711   // unroll-and-jam may violate correctness.
712   std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
713   if (!D)
714     return true;
715   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
716 
717   if (D->isConfused()) {
718     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
719                       << "  " << *Src << "\n"
720                       << "  " << *Dst << "\n");
721     return false;
722   }
723 
724   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
725   // non-equal direction, then the locations accessed in the inner levels cannot
726   // overlap in memory. We assumes the indexes never overlap into neighboring
727   // dimensions.
728   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
729     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
730       return true;
731 
732   auto UnrollDirection = D->getDirection(UnrollLevel);
733 
734   // If the distance carried by the unrolled loop is 0, then after unrolling
735   // that distance will become non-zero resulting in non-overlapping accesses in
736   // the inner loops.
737   if (UnrollDirection == Dependence::DVEntry::EQ)
738     return true;
739 
740   if (UnrollDirection & Dependence::DVEntry::LT &&
741       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
742                                   Sequentialized, D.get()))
743     return false;
744 
745   if (UnrollDirection & Dependence::DVEntry::GT &&
746       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
747                                    Sequentialized, D.get()))
748     return false;
749 
750   return true;
751 }
752 
753 static bool
754 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
755                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
756                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
757                   DependenceInfo &DI, LoopInfo &LI) {
758   SmallVector<BasicBlockSet, 8> AllBlocks;
759   for (Loop *L : Root.getLoopsInPreorder())
760     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
761       AllBlocks.push_back(ForeBlocksMap.lookup(L));
762   AllBlocks.push_back(SubLoopBlocks);
763   for (Loop *L : Root.getLoopsInPreorder())
764     if (AftBlocksMap.find(L) != AftBlocksMap.end())
765       AllBlocks.push_back(AftBlocksMap.lookup(L));
766 
767   unsigned LoopDepth = Root.getLoopDepth();
768   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
769   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
770   for (BasicBlockSet &Blocks : AllBlocks) {
771     CurrentLoadsAndStores.clear();
772     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
773       return false;
774 
775     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
776     unsigned CurLoopDepth = CurLoop->getLoopDepth();
777 
778     for (auto *Earlier : EarlierLoadsAndStores) {
779       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
780       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
781       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
782       for (auto *Later : CurrentLoadsAndStores) {
783         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
784                              DI))
785           return false;
786       }
787     }
788 
789     size_t NumInsts = CurrentLoadsAndStores.size();
790     for (size_t I = 0; I < NumInsts; ++I) {
791       for (size_t J = I; J < NumInsts; ++J) {
792         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
793                              LoopDepth, CurLoopDepth, true, DI))
794           return false;
795       }
796     }
797 
798     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
799                                  CurrentLoadsAndStores.end());
800   }
801   return true;
802 }
803 
804 static bool isEligibleLoopForm(const Loop &Root) {
805   // Root must have a child.
806   if (Root.getSubLoops().size() != 1)
807     return false;
808 
809   const Loop *L = &Root;
810   do {
811     // All loops in Root need to be in simplify and rotated form.
812     if (!L->isLoopSimplifyForm())
813       return false;
814 
815     if (!L->isRotatedForm())
816       return false;
817 
818     if (L->getHeader()->hasAddressTaken()) {
819       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
820       return false;
821     }
822 
823     unsigned SubLoopsSize = L->getSubLoops().size();
824     if (SubLoopsSize == 0)
825       return true;
826 
827     // Only one child is allowed.
828     if (SubLoopsSize != 1)
829       return false;
830 
831     // Only loops with a single exit block can be unrolled and jammed.
832     // The function getExitBlock() is used for this check, rather than
833     // getUniqueExitBlock() to ensure loops with mulitple exit edges are
834     // disallowed.
835     if (!L->getExitBlock()) {
836       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
837                            "blocks can be unrolled and jammed.\n");
838       return false;
839     }
840 
841     // Only loops with a single exiting block can be unrolled and jammed.
842     if (!L->getExitingBlock()) {
843       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
844                            "exiting blocks can be unrolled and jammed.\n");
845       return false;
846     }
847 
848     L = L->getSubLoops()[0];
849   } while (L);
850 
851   return true;
852 }
853 
854 static Loop *getInnerMostLoop(Loop *L) {
855   while (!L->getSubLoops().empty())
856     L = L->getSubLoops()[0];
857   return L;
858 }
859 
860 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
861                                 DependenceInfo &DI, LoopInfo &LI) {
862   if (!isEligibleLoopForm(*L)) {
863     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
864     return false;
865   }
866 
867   /* We currently handle outer loops like this:
868         |
869     ForeFirst    <------\   }
870      Blocks             |   } ForeBlocks of L
871     ForeLast            |   }
872         |               |
873        ...              |
874         |               |
875     ForeFirst    <----\ |   }
876      Blocks           | |   } ForeBlocks of a inner loop of L
877     ForeLast          | |   }
878         |             | |
879     JamLoopFirst  <\  | |   }
880      Blocks        |  | |   } JamLoopBlocks of the innermost loop
881     JamLoopLast   -/  | |   }
882         |             | |
883     AftFirst          | |   }
884      Blocks           | |   } AftBlocks of a inner loop of L
885     AftLast     ------/ |   }
886         |               |
887        ...              |
888         |               |
889     AftFirst            |   }
890      Blocks             |   } AftBlocks of L
891     AftLast     --------/   }
892         |
893 
894     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
895     and AftBlocks, providing that there is one edge from Fores to SubLoops,
896     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
897     In practice we currently limit Aft blocks to a single block, and limit
898     things further in the profitablility checks of the unroll and jam pass.
899 
900     Because of the way we rearrange basic blocks, we also require that
901     the Fore blocks of L on all unrolled iterations are safe to move before the
902     blocks of the direct child of L of all iterations. So we require that the
903     phi node looping operands of ForeHeader can be moved to at least the end of
904     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
905     match up Phi's correctly.
906 
907     i.e. The old order of blocks used to be
908            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
909          It needs to be safe to transform this to
910            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
911 
912     There are then a number of checks along the lines of no calls, no
913     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
914     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
915     UnrollAndJamLoop if the trip count cannot be easily calculated.
916   */
917 
918   // Split blocks into Fore/SubLoop/Aft based on dominators
919   Loop *JamLoop = getInnerMostLoop(L);
920   BasicBlockSet SubLoopBlocks;
921   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
922   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
923   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
924                                 AftBlocksMap, DT)) {
925     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
926     return false;
927   }
928 
929   // Aft blocks may need to move instructions to fore blocks, which becomes more
930   // difficult if there are multiple (potentially conditionally executed)
931   // blocks. For now we just exclude loops with multiple aft blocks.
932   if (AftBlocksMap[L].size() != 1) {
933     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
934                          "multiple blocks after the loop\n");
935     return false;
936   }
937 
938   // Check inner loop backedge count is consistent on all iterations of the
939   // outer loop
940   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
941         return !hasIterationCountInvariantInParent(SubLoop, SE);
942       })) {
943     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
944                          "not consistent on each iteration\n");
945     return false;
946   }
947 
948   // Check the loop safety info for exceptions.
949   SimpleLoopSafetyInfo LSI;
950   LSI.computeLoopSafetyInfo(L);
951   if (LSI.anyBlockMayThrow()) {
952     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
953     return false;
954   }
955 
956   // We've ruled out the easy stuff and now need to check that there are no
957   // interdependencies which may prevent us from moving the:
958   //  ForeBlocks before Subloop and AftBlocks.
959   //  Subloop before AftBlocks.
960   //  ForeBlock phi operands before the subloop
961 
962   // Make sure we can move all instructions we need to before the subloop
963   BasicBlock *Header = L->getHeader();
964   BasicBlock *Latch = L->getLoopLatch();
965   BasicBlockSet AftBlocks = AftBlocksMap[L];
966   Loop *SubLoop = L->getSubLoops()[0];
967   if (!processHeaderPhiOperands(
968           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
969             if (SubLoop->contains(I->getParent()))
970               return false;
971             if (AftBlocks.count(I->getParent())) {
972               // If we hit a phi node in afts we know we are done (probably
973               // LCSSA)
974               if (isa<PHINode>(I))
975                 return false;
976               // Can't move instructions with side effects or memory
977               // reads/writes
978               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
979                 return false;
980             }
981             // Keep going
982             return true;
983           })) {
984     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
985                          "instructions after subloop to before it\n");
986     return false;
987   }
988 
989   // Check for memory dependencies which prohibit the unrolling we are doing.
990   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
991   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
992   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
993                          LI)) {
994     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
995     return false;
996   }
997 
998   return true;
999 }
1000