xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/Analysis/AssumptionCache.h"
23 #include "llvm/Analysis/DependenceAnalysis.h"
24 #include "llvm/Analysis/DomTreeUpdater.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/LoopIterator.h"
27 #include "llvm/Analysis/MustExecute.h"
28 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
29 #include "llvm/Analysis/ScalarEvolution.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/DebugInfoMetadata.h"
32 #include "llvm/IR/DebugLoc.h"
33 #include "llvm/IR/DiagnosticInfo.h"
34 #include "llvm/IR/Dominators.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/IntrinsicInst.h"
39 #include "llvm/IR/User.h"
40 #include "llvm/IR/Value.h"
41 #include "llvm/IR/ValueHandle.h"
42 #include "llvm/IR/ValueMap.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Support/GenericDomTree.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
49 #include "llvm/Transforms/Utils/Cloning.h"
50 #include "llvm/Transforms/Utils/LoopUtils.h"
51 #include "llvm/Transforms/Utils/UnrollLoop.h"
52 #include "llvm/Transforms/Utils/ValueMapper.h"
53 #include <assert.h>
54 #include <memory>
55 #include <type_traits>
56 #include <vector>
57 
58 using namespace llvm;
59 
60 #define DEBUG_TYPE "loop-unroll-and-jam"
61 
62 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
63 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
64 
65 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
66 
67 // Partition blocks in an outer/inner loop pair into blocks before and after
68 // the loop
69 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
70                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
71   Loop *SubLoop = L.getSubLoops()[0];
72   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
73 
74   for (BasicBlock *BB : L.blocks()) {
75     if (!SubLoop->contains(BB)) {
76       if (DT.dominates(SubLoopLatch, BB))
77         AftBlocks.insert(BB);
78       else
79         ForeBlocks.insert(BB);
80     }
81   }
82 
83   // Check that all blocks in ForeBlocks together dominate the subloop
84   // TODO: This might ideally be done better with a dominator/postdominators.
85   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
86   for (BasicBlock *BB : ForeBlocks) {
87     if (BB == SubLoopPreHeader)
88       continue;
89     Instruction *TI = BB->getTerminator();
90     for (BasicBlock *Succ : successors(TI))
91       if (!ForeBlocks.count(Succ))
92         return false;
93   }
94 
95   return true;
96 }
97 
98 /// Partition blocks in a loop nest into blocks before and after each inner
99 /// loop.
100 static bool partitionOuterLoopBlocks(
101     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
102     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
103     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
104   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
105 
106   for (Loop *L : Root.getLoopsInPreorder()) {
107     if (L == &JamLoop)
108       break;
109 
110     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
111       return false;
112   }
113 
114   return true;
115 }
116 
117 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
118 // than 2 levels loop.
119 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
120                                      BasicBlockSet &ForeBlocks,
121                                      BasicBlockSet &SubLoopBlocks,
122                                      BasicBlockSet &AftBlocks,
123                                      DominatorTree *DT) {
124   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
125   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
126 }
127 
128 // Looks at the phi nodes in Header for values coming from Latch. For these
129 // instructions and all their operands calls Visit on them, keeping going for
130 // all the operands in AftBlocks. Returns false if Visit returns false,
131 // otherwise returns true. This is used to process the instructions in the
132 // Aft blocks that need to be moved before the subloop. It is used in two
133 // places. One to check that the required set of instructions can be moved
134 // before the loop. Then to collect the instructions to actually move in
135 // moveHeaderPhiOperandsToForeBlocks.
136 template <typename T>
137 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
138                                      BasicBlockSet &AftBlocks, T Visit) {
139   SmallPtrSet<Instruction *, 8> VisitedInstr;
140 
141   std::function<bool(Instruction * I)> ProcessInstr = [&](Instruction *I) {
142     if (VisitedInstr.count(I))
143       return true;
144 
145     VisitedInstr.insert(I);
146 
147     if (AftBlocks.count(I->getParent()))
148       for (auto &U : I->operands())
149         if (Instruction *II = dyn_cast<Instruction>(U))
150           if (!ProcessInstr(II))
151             return false;
152 
153     return Visit(I);
154   };
155 
156   for (auto &Phi : Header->phis()) {
157     Value *V = Phi.getIncomingValueForBlock(Latch);
158     if (Instruction *I = dyn_cast<Instruction>(V))
159       if (!ProcessInstr(I))
160         return false;
161   }
162 
163   return true;
164 }
165 
166 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
167 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
168                                               BasicBlock *Latch,
169                                               Instruction *InsertLoc,
170                                               BasicBlockSet &AftBlocks) {
171   // We need to ensure we move the instructions in the correct order,
172   // starting with the earliest required instruction and moving forward.
173   processHeaderPhiOperands(Header, Latch, AftBlocks,
174                            [&AftBlocks, &InsertLoc](Instruction *I) {
175                              if (AftBlocks.count(I->getParent()))
176                                I->moveBefore(InsertLoc);
177                              return true;
178                            });
179 }
180 
181 /*
182   This method performs Unroll and Jam. For a simple loop like:
183   for (i = ..)
184     Fore(i)
185     for (j = ..)
186       SubLoop(i, j)
187     Aft(i)
188 
189   Instead of doing normal inner or outer unrolling, we do:
190   for (i = .., i+=2)
191     Fore(i)
192     Fore(i+1)
193     for (j = ..)
194       SubLoop(i, j)
195       SubLoop(i+1, j)
196     Aft(i)
197     Aft(i+1)
198 
199   So the outer loop is essetially unrolled and then the inner loops are fused
200   ("jammed") together into a single loop. This can increase speed when there
201   are loads in SubLoop that are invariant to i, as they become shared between
202   the now jammed inner loops.
203 
204   We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
205   Fore blocks are those before the inner loop, Aft are those after. Normal
206   Unroll code is used to copy each of these sets of blocks and the results are
207   combined together into the final form above.
208 
209   isSafeToUnrollAndJam should be used prior to calling this to make sure the
210   unrolling will be valid. Checking profitablility is also advisable.
211 
212   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
213   necessary to create one and not fully unrolled).
214 */
215 LoopUnrollResult
216 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
217                        unsigned TripMultiple, bool UnrollRemainder,
218                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
219                        AssumptionCache *AC, const TargetTransformInfo *TTI,
220                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
221 
222   // When we enter here we should have already checked that it is safe
223   BasicBlock *Header = L->getHeader();
224   assert(Header && "No header.");
225   assert(L->getSubLoops().size() == 1);
226   Loop *SubLoop = *L->begin();
227 
228   // Don't enter the unroll code if there is nothing to do.
229   if (TripCount == 0 && Count < 2) {
230     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
231     return LoopUnrollResult::Unmodified;
232   }
233 
234   assert(Count > 0);
235   assert(TripMultiple > 0);
236   assert(TripCount == 0 || TripCount % TripMultiple == 0);
237 
238   // Are we eliminating the loop control altogether?
239   bool CompletelyUnroll = (Count == TripCount);
240 
241   // We use the runtime remainder in cases where we don't know trip multiple
242   if (TripMultiple % Count != 0) {
243     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
244                                     /*UseEpilogRemainder*/ true,
245                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
246                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
247       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
248                            "generated when assuming runtime trip count\n");
249       return LoopUnrollResult::Unmodified;
250     }
251   }
252 
253   // Notify ScalarEvolution that the loop will be substantially changed,
254   // if not outright eliminated.
255   if (SE) {
256     SE->forgetLoop(L);
257     SE->forgetBlockAndLoopDispositions();
258   }
259 
260   using namespace ore;
261   // Report the unrolling decision.
262   if (CompletelyUnroll) {
263     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
264                       << Header->getName() << " with trip count " << TripCount
265                       << "!\n");
266     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
267                                  L->getHeader())
268               << "completely unroll and jammed loop with "
269               << NV("UnrollCount", TripCount) << " iterations");
270   } else {
271     auto DiagBuilder = [&]() {
272       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
273                               L->getHeader());
274       return Diag << "unroll and jammed loop by a factor of "
275                   << NV("UnrollCount", Count);
276     };
277 
278     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
279                       << " by " << Count);
280     if (TripMultiple != 1) {
281       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
282       ORE->emit([&]() {
283         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
284                              << " trips per branch";
285       });
286     } else {
287       LLVM_DEBUG(dbgs() << " with run-time trip count");
288       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
289     }
290     LLVM_DEBUG(dbgs() << "!\n");
291   }
292 
293   BasicBlock *Preheader = L->getLoopPreheader();
294   BasicBlock *LatchBlock = L->getLoopLatch();
295   assert(Preheader && "No preheader");
296   assert(LatchBlock && "No latch block");
297   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
298   assert(BI && !BI->isUnconditional());
299   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
300   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
301   bool SubLoopContinueOnTrue = SubLoop->contains(
302       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
303 
304   // Partition blocks in an outer/inner loop pair into blocks before and after
305   // the loop
306   BasicBlockSet SubLoopBlocks;
307   BasicBlockSet ForeBlocks;
308   BasicBlockSet AftBlocks;
309   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
310                            DT);
311 
312   // We keep track of the entering/first and exiting/last block of each of
313   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
314   // blocks easier.
315   std::vector<BasicBlock *> ForeBlocksFirst;
316   std::vector<BasicBlock *> ForeBlocksLast;
317   std::vector<BasicBlock *> SubLoopBlocksFirst;
318   std::vector<BasicBlock *> SubLoopBlocksLast;
319   std::vector<BasicBlock *> AftBlocksFirst;
320   std::vector<BasicBlock *> AftBlocksLast;
321   ForeBlocksFirst.push_back(Header);
322   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
323   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
324   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
325   AftBlocksFirst.push_back(SubLoop->getExitBlock());
326   AftBlocksLast.push_back(L->getExitingBlock());
327   // Maps Blocks[0] -> Blocks[It]
328   ValueToValueMapTy LastValueMap;
329 
330   // Move any instructions from fore phi operands from AftBlocks into Fore.
331   moveHeaderPhiOperandsToForeBlocks(
332       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
333 
334   // The current on-the-fly SSA update requires blocks to be processed in
335   // reverse postorder so that LastValueMap contains the correct value at each
336   // exit.
337   LoopBlocksDFS DFS(L);
338   DFS.perform(LI);
339   // Stash the DFS iterators before adding blocks to the loop.
340   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
341   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
342 
343   // When a FSDiscriminator is enabled, we don't need to add the multiply
344   // factors to the discriminators.
345   if (Header->getParent()->shouldEmitDebugInfoForProfiling() &&
346       !EnableFSDiscriminator)
347     for (BasicBlock *BB : L->getBlocks())
348       for (Instruction &I : *BB)
349         if (!I.isDebugOrPseudoInst())
350           if (const DILocation *DIL = I.getDebugLoc()) {
351             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
352             if (NewDIL)
353               I.setDebugLoc(*NewDIL);
354             else
355               LLVM_DEBUG(dbgs()
356                          << "Failed to create new discriminator: "
357                          << DIL->getFilename() << " Line: " << DIL->getLine());
358           }
359 
360   // Copy all blocks
361   for (unsigned It = 1; It != Count; ++It) {
362     SmallVector<BasicBlock *, 8> NewBlocks;
363     // Maps Blocks[It] -> Blocks[It-1]
364     DenseMap<Value *, Value *> PrevItValueMap;
365     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
366     NewLoops[L] = L;
367     NewLoops[SubLoop] = SubLoop;
368 
369     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
370       ValueToValueMapTy VMap;
371       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
372       Header->getParent()->insert(Header->getParent()->end(), New);
373 
374       // Tell LI about New.
375       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
376 
377       if (ForeBlocks.count(*BB)) {
378         if (*BB == ForeBlocksFirst[0])
379           ForeBlocksFirst.push_back(New);
380         if (*BB == ForeBlocksLast[0])
381           ForeBlocksLast.push_back(New);
382       } else if (SubLoopBlocks.count(*BB)) {
383         if (*BB == SubLoopBlocksFirst[0])
384           SubLoopBlocksFirst.push_back(New);
385         if (*BB == SubLoopBlocksLast[0])
386           SubLoopBlocksLast.push_back(New);
387       } else if (AftBlocks.count(*BB)) {
388         if (*BB == AftBlocksFirst[0])
389           AftBlocksFirst.push_back(New);
390         if (*BB == AftBlocksLast[0])
391           AftBlocksLast.push_back(New);
392       } else {
393         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
394       }
395 
396       // Update our running maps of newest clones
397       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
398       LastValueMap[*BB] = New;
399       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
400            VI != VE; ++VI) {
401         PrevItValueMap[VI->second] =
402             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
403         LastValueMap[VI->first] = VI->second;
404       }
405 
406       NewBlocks.push_back(New);
407 
408       // Update DomTree:
409       if (*BB == ForeBlocksFirst[0])
410         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
411       else if (*BB == SubLoopBlocksFirst[0])
412         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
413       else if (*BB == AftBlocksFirst[0])
414         DT->addNewBlock(New, AftBlocksLast[It - 1]);
415       else {
416         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
417         // structure.
418         auto BBDomNode = DT->getNode(*BB);
419         auto BBIDom = BBDomNode->getIDom();
420         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
421         assert(OriginalBBIDom);
422         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
423         DT->addNewBlock(
424             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
425       }
426     }
427 
428     // Remap all instructions in the most recent iteration
429     remapInstructionsInBlocks(NewBlocks, LastValueMap);
430     for (BasicBlock *NewBlock : NewBlocks) {
431       for (Instruction &I : *NewBlock) {
432         if (auto *II = dyn_cast<AssumeInst>(&I))
433           AC->registerAssumption(II);
434       }
435     }
436 
437     // Alter the ForeBlocks phi's, pointing them at the latest version of the
438     // value from the previous iteration's phis
439     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
440       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
441       assert(OldValue && "should have incoming edge from Aft[It]");
442       Value *NewValue = OldValue;
443       if (Value *PrevValue = PrevItValueMap[OldValue])
444         NewValue = PrevValue;
445 
446       assert(Phi.getNumOperands() == 2);
447       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
448       Phi.setIncomingValue(0, NewValue);
449       Phi.removeIncomingValue(1);
450     }
451   }
452 
453   // Now that all the basic blocks for the unrolled iterations are in place,
454   // finish up connecting the blocks and phi nodes. At this point LastValueMap
455   // is the last unrolled iterations values.
456 
457   // Update Phis in BB from OldBB to point to NewBB and use the latest value
458   // from LastValueMap
459   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
460                                      BasicBlock *NewBB,
461                                      ValueToValueMapTy &LastValueMap) {
462     for (PHINode &Phi : BB->phis()) {
463       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
464         if (Phi.getIncomingBlock(b) == OldBB) {
465           Value *OldValue = Phi.getIncomingValue(b);
466           if (Value *LastValue = LastValueMap[OldValue])
467             Phi.setIncomingValue(b, LastValue);
468           Phi.setIncomingBlock(b, NewBB);
469           break;
470         }
471       }
472     }
473   };
474   // Move all the phis from Src into Dest
475   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
476     BasicBlock::iterator insertPoint = Dest->getFirstNonPHIIt();
477     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
478       Phi->moveBefore(*Dest, insertPoint);
479   };
480 
481   // Update the PHI values outside the loop to point to the last block
482   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
483                            LastValueMap);
484 
485   // Update ForeBlocks successors and phi nodes
486   BranchInst *ForeTerm =
487       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
488   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
489   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
490 
491   if (CompletelyUnroll) {
492     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
493       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
494       Phi->eraseFromParent();
495     }
496   } else {
497     // Update the PHI values to point to the last aft block
498     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
499                              AftBlocksLast.back(), LastValueMap);
500   }
501 
502   for (unsigned It = 1; It != Count; It++) {
503     // Remap ForeBlock successors from previous iteration to this
504     BranchInst *ForeTerm =
505         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
506     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
507     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
508   }
509 
510   // Subloop successors and phis
511   BranchInst *SubTerm =
512       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
513   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
514   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
515   SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
516                                             ForeBlocksLast.back());
517   SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
518                                             SubLoopBlocksLast.back());
519 
520   for (unsigned It = 1; It != Count; It++) {
521     // Replace the conditional branch of the previous iteration subloop with an
522     // unconditional one to this one
523     BranchInst *SubTerm =
524         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
525     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm->getIterator());
526     SubTerm->eraseFromParent();
527 
528     SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
529                                                ForeBlocksLast.back());
530     SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
531                                                SubLoopBlocksLast.back());
532     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
533   }
534 
535   // Aft blocks successors and phis
536   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
537   if (CompletelyUnroll) {
538     BranchInst::Create(LoopExit, AftTerm->getIterator());
539     AftTerm->eraseFromParent();
540   } else {
541     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
542     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
543            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
544   }
545   AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
546                                         SubLoopBlocksLast.back());
547 
548   for (unsigned It = 1; It != Count; It++) {
549     // Replace the conditional branch of the previous iteration subloop with an
550     // unconditional one to this one
551     BranchInst *AftTerm =
552         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
553     BranchInst::Create(AftBlocksFirst[It], AftTerm->getIterator());
554     AftTerm->eraseFromParent();
555 
556     AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
557                                            SubLoopBlocksLast.back());
558     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
559   }
560 
561   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
562   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
563   // new ones required.
564   if (Count != 1) {
565     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
566     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
567                            SubLoopBlocksFirst[0]);
568     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
569                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
570 
571     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
572                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
573     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
574                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
575     DTU.applyUpdatesPermissive(DTUpdates);
576   }
577 
578   // Merge adjacent basic blocks, if possible.
579   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
580   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
581   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
582   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
583 
584   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
585 
586   // Apply updates to the DomTree.
587   DT = &DTU.getDomTree();
588 
589   // At this point, the code is well formed.  We now do a quick sweep over the
590   // inserted code, doing constant propagation and dead code elimination as we
591   // go.
592   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
593   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
594                           TTI);
595 
596   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
597   ++NumUnrolledAndJammed;
598 
599   // Update LoopInfo if the loop is completely removed.
600   if (CompletelyUnroll)
601     LI->erase(L);
602 
603 #ifndef NDEBUG
604   // We shouldn't have done anything to break loop simplify form or LCSSA.
605   Loop *OutestLoop = SubLoop->getParentLoop()
606                          ? SubLoop->getParentLoop()->getParentLoop()
607                                ? SubLoop->getParentLoop()->getParentLoop()
608                                : SubLoop->getParentLoop()
609                          : SubLoop;
610   assert(DT->verify());
611   LI->verify(*DT);
612   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
613   if (!CompletelyUnroll)
614     assert(L->isLoopSimplifyForm());
615   assert(SubLoop->isLoopSimplifyForm());
616   SE->verify();
617 #endif
618 
619   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
620                           : LoopUnrollResult::PartiallyUnrolled;
621 }
622 
623 static bool getLoadsAndStores(BasicBlockSet &Blocks,
624                               SmallVector<Instruction *, 4> &MemInstr) {
625   // Scan the BBs and collect legal loads and stores.
626   // Returns false if non-simple loads/stores are found.
627   for (BasicBlock *BB : Blocks) {
628     for (Instruction &I : *BB) {
629       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
630         if (!Ld->isSimple())
631           return false;
632         MemInstr.push_back(&I);
633       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
634         if (!St->isSimple())
635           return false;
636         MemInstr.push_back(&I);
637       } else if (I.mayReadOrWriteMemory()) {
638         return false;
639       }
640     }
641   }
642   return true;
643 }
644 
645 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
646                                        unsigned UnrollLevel, unsigned JamLevel,
647                                        bool Sequentialized, Dependence *D) {
648   // UnrollLevel might carry the dependency Src --> Dst
649   // Does a different loop after unrolling?
650   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
651        ++CurLoopDepth) {
652     auto JammedDir = D->getDirection(CurLoopDepth);
653     if (JammedDir == Dependence::DVEntry::LT)
654       return true;
655 
656     if (JammedDir & Dependence::DVEntry::GT)
657       return false;
658   }
659 
660   return true;
661 }
662 
663 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
664                                         unsigned UnrollLevel, unsigned JamLevel,
665                                         bool Sequentialized, Dependence *D) {
666   // UnrollLevel might carry the dependency Dst --> Src
667   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
668        ++CurLoopDepth) {
669     auto JammedDir = D->getDirection(CurLoopDepth);
670     if (JammedDir == Dependence::DVEntry::GT)
671       return true;
672 
673     if (JammedDir & Dependence::DVEntry::LT)
674       return false;
675   }
676 
677   // Backward dependencies are only preserved if not interleaved.
678   return Sequentialized;
679 }
680 
681 // Check whether it is semantically safe Src and Dst considering any potential
682 // dependency between them.
683 //
684 // @param UnrollLevel The level of the loop being unrolled
685 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
686 // different levels, the outermost common loop counts as jammed level
687 //
688 // @return true if is safe and false if there is a dependency violation.
689 static bool checkDependency(Instruction *Src, Instruction *Dst,
690                             unsigned UnrollLevel, unsigned JamLevel,
691                             bool Sequentialized, DependenceInfo &DI) {
692   assert(UnrollLevel <= JamLevel &&
693          "Expecting JamLevel to be at least UnrollLevel");
694 
695   if (Src == Dst)
696     return true;
697   // Ignore Input dependencies.
698   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
699     return true;
700 
701   // Check whether unroll-and-jam may violate a dependency.
702   // By construction, every dependency will be lexicographically non-negative
703   // (if it was, it would violate the current execution order), such as
704   //   (0,0,>,*,*)
705   // Unroll-and-jam changes the GT execution of two executions to the same
706   // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
707   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
708   //   (0,0,>=,*,*)
709   // Now, the dependency is not necessarily non-negative anymore, i.e.
710   // unroll-and-jam may violate correctness.
711   std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
712   if (!D)
713     return true;
714   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
715 
716   if (D->isConfused()) {
717     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
718                       << "  " << *Src << "\n"
719                       << "  " << *Dst << "\n");
720     return false;
721   }
722 
723   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
724   // non-equal direction, then the locations accessed in the inner levels cannot
725   // overlap in memory. We assumes the indexes never overlap into neighboring
726   // dimensions.
727   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
728     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
729       return true;
730 
731   auto UnrollDirection = D->getDirection(UnrollLevel);
732 
733   // If the distance carried by the unrolled loop is 0, then after unrolling
734   // that distance will become non-zero resulting in non-overlapping accesses in
735   // the inner loops.
736   if (UnrollDirection == Dependence::DVEntry::EQ)
737     return true;
738 
739   if (UnrollDirection & Dependence::DVEntry::LT &&
740       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
741                                   Sequentialized, D.get()))
742     return false;
743 
744   if (UnrollDirection & Dependence::DVEntry::GT &&
745       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
746                                    Sequentialized, D.get()))
747     return false;
748 
749   return true;
750 }
751 
752 static bool
753 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
754                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
755                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
756                   DependenceInfo &DI, LoopInfo &LI) {
757   SmallVector<BasicBlockSet, 8> AllBlocks;
758   for (Loop *L : Root.getLoopsInPreorder())
759     if (ForeBlocksMap.contains(L))
760       AllBlocks.push_back(ForeBlocksMap.lookup(L));
761   AllBlocks.push_back(SubLoopBlocks);
762   for (Loop *L : Root.getLoopsInPreorder())
763     if (AftBlocksMap.contains(L))
764       AllBlocks.push_back(AftBlocksMap.lookup(L));
765 
766   unsigned LoopDepth = Root.getLoopDepth();
767   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
768   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
769   for (BasicBlockSet &Blocks : AllBlocks) {
770     CurrentLoadsAndStores.clear();
771     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
772       return false;
773 
774     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
775     unsigned CurLoopDepth = CurLoop->getLoopDepth();
776 
777     for (auto *Earlier : EarlierLoadsAndStores) {
778       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
779       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
780       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
781       for (auto *Later : CurrentLoadsAndStores) {
782         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
783                              DI))
784           return false;
785       }
786     }
787 
788     size_t NumInsts = CurrentLoadsAndStores.size();
789     for (size_t I = 0; I < NumInsts; ++I) {
790       for (size_t J = I; J < NumInsts; ++J) {
791         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
792                              LoopDepth, CurLoopDepth, true, DI))
793           return false;
794       }
795     }
796 
797     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
798                                  CurrentLoadsAndStores.end());
799   }
800   return true;
801 }
802 
803 static bool isEligibleLoopForm(const Loop &Root) {
804   // Root must have a child.
805   if (Root.getSubLoops().size() != 1)
806     return false;
807 
808   const Loop *L = &Root;
809   do {
810     // All loops in Root need to be in simplify and rotated form.
811     if (!L->isLoopSimplifyForm())
812       return false;
813 
814     if (!L->isRotatedForm())
815       return false;
816 
817     if (L->getHeader()->hasAddressTaken()) {
818       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
819       return false;
820     }
821 
822     unsigned SubLoopsSize = L->getSubLoops().size();
823     if (SubLoopsSize == 0)
824       return true;
825 
826     // Only one child is allowed.
827     if (SubLoopsSize != 1)
828       return false;
829 
830     // Only loops with a single exit block can be unrolled and jammed.
831     // The function getExitBlock() is used for this check, rather than
832     // getUniqueExitBlock() to ensure loops with mulitple exit edges are
833     // disallowed.
834     if (!L->getExitBlock()) {
835       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
836                            "blocks can be unrolled and jammed.\n");
837       return false;
838     }
839 
840     // Only loops with a single exiting block can be unrolled and jammed.
841     if (!L->getExitingBlock()) {
842       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
843                            "exiting blocks can be unrolled and jammed.\n");
844       return false;
845     }
846 
847     L = L->getSubLoops()[0];
848   } while (L);
849 
850   return true;
851 }
852 
853 static Loop *getInnerMostLoop(Loop *L) {
854   while (!L->getSubLoops().empty())
855     L = L->getSubLoops()[0];
856   return L;
857 }
858 
859 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
860                                 DependenceInfo &DI, LoopInfo &LI) {
861   if (!isEligibleLoopForm(*L)) {
862     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
863     return false;
864   }
865 
866   /* We currently handle outer loops like this:
867         |
868     ForeFirst    <------\   }
869      Blocks             |   } ForeBlocks of L
870     ForeLast            |   }
871         |               |
872        ...              |
873         |               |
874     ForeFirst    <----\ |   }
875      Blocks           | |   } ForeBlocks of a inner loop of L
876     ForeLast          | |   }
877         |             | |
878     JamLoopFirst  <\  | |   }
879      Blocks        |  | |   } JamLoopBlocks of the innermost loop
880     JamLoopLast   -/  | |   }
881         |             | |
882     AftFirst          | |   }
883      Blocks           | |   } AftBlocks of a inner loop of L
884     AftLast     ------/ |   }
885         |               |
886        ...              |
887         |               |
888     AftFirst            |   }
889      Blocks             |   } AftBlocks of L
890     AftLast     --------/   }
891         |
892 
893     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
894     and AftBlocks, providing that there is one edge from Fores to SubLoops,
895     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
896     In practice we currently limit Aft blocks to a single block, and limit
897     things further in the profitablility checks of the unroll and jam pass.
898 
899     Because of the way we rearrange basic blocks, we also require that
900     the Fore blocks of L on all unrolled iterations are safe to move before the
901     blocks of the direct child of L of all iterations. So we require that the
902     phi node looping operands of ForeHeader can be moved to at least the end of
903     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
904     match up Phi's correctly.
905 
906     i.e. The old order of blocks used to be
907            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
908          It needs to be safe to transform this to
909            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
910 
911     There are then a number of checks along the lines of no calls, no
912     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
913     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
914     UnrollAndJamLoop if the trip count cannot be easily calculated.
915   */
916 
917   // Split blocks into Fore/SubLoop/Aft based on dominators
918   Loop *JamLoop = getInnerMostLoop(L);
919   BasicBlockSet SubLoopBlocks;
920   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
921   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
922   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
923                                 AftBlocksMap, DT)) {
924     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
925     return false;
926   }
927 
928   // Aft blocks may need to move instructions to fore blocks, which becomes more
929   // difficult if there are multiple (potentially conditionally executed)
930   // blocks. For now we just exclude loops with multiple aft blocks.
931   if (AftBlocksMap[L].size() != 1) {
932     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
933                          "multiple blocks after the loop\n");
934     return false;
935   }
936 
937   // Check inner loop backedge count is consistent on all iterations of the
938   // outer loop
939   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
940         return !hasIterationCountInvariantInParent(SubLoop, SE);
941       })) {
942     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
943                          "not consistent on each iteration\n");
944     return false;
945   }
946 
947   // Check the loop safety info for exceptions.
948   SimpleLoopSafetyInfo LSI;
949   LSI.computeLoopSafetyInfo(L);
950   if (LSI.anyBlockMayThrow()) {
951     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
952     return false;
953   }
954 
955   // We've ruled out the easy stuff and now need to check that there are no
956   // interdependencies which may prevent us from moving the:
957   //  ForeBlocks before Subloop and AftBlocks.
958   //  Subloop before AftBlocks.
959   //  ForeBlock phi operands before the subloop
960 
961   // Make sure we can move all instructions we need to before the subloop
962   BasicBlock *Header = L->getHeader();
963   BasicBlock *Latch = L->getLoopLatch();
964   BasicBlockSet AftBlocks = AftBlocksMap[L];
965   Loop *SubLoop = L->getSubLoops()[0];
966   if (!processHeaderPhiOperands(
967           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
968             if (SubLoop->contains(I->getParent()))
969               return false;
970             if (AftBlocks.count(I->getParent())) {
971               // If we hit a phi node in afts we know we are done (probably
972               // LCSSA)
973               if (isa<PHINode>(I))
974                 return false;
975               // Can't move instructions with side effects or memory
976               // reads/writes
977               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
978                 return false;
979             }
980             // Keep going
981             return true;
982           })) {
983     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
984                          "instructions after subloop to before it\n");
985     return false;
986   }
987 
988   // Check for memory dependencies which prohibit the unrolling we are doing.
989   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
990   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
991   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
992                          LI)) {
993     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
994     return false;
995   }
996 
997   return true;
998 }
999