xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DependencyGraph.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Utils.h"
13 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
14 
15 namespace llvm::sandboxir {
16 
skipBadIt(User::op_iterator OpIt,User::op_iterator OpItE,const DependencyGraph & DAG)17 User::op_iterator PredIterator::skipBadIt(User::op_iterator OpIt,
18                                           User::op_iterator OpItE,
19                                           const DependencyGraph &DAG) {
20   auto Skip = [&DAG](auto OpIt) {
21     auto *I = dyn_cast<Instruction>((*OpIt).get());
22     return I == nullptr || DAG.getNode(I) == nullptr;
23   };
24   while (OpIt != OpItE && Skip(OpIt))
25     ++OpIt;
26   return OpIt;
27 }
28 
operator *()29 PredIterator::value_type PredIterator::operator*() {
30   // If it's a DGNode then we dereference the operand iterator.
31   if (!isa<MemDGNode>(N)) {
32     assert(OpIt != OpItE && "Can't dereference end iterator!");
33     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
34   }
35   // It's a MemDGNode, so we check if we return either the use-def operand,
36   // or a mem predecessor.
37   if (OpIt != OpItE)
38     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
39   // It's a MemDGNode with OpIt == end, so we need to use MemIt.
40   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
41          "Cant' dereference end iterator!");
42   return *MemIt;
43 }
44 
operator ++()45 PredIterator &PredIterator::operator++() {
46   // If it's a DGNode then we increment the use-def iterator.
47   if (!isa<MemDGNode>(N)) {
48     assert(OpIt != OpItE && "Already at end!");
49     ++OpIt;
50     // Skip operands that are not instructions or are outside the DAG.
51     OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
52     return *this;
53   }
54   // It's a MemDGNode, so if we are not at the end of the use-def iterator we
55   // need to first increment that.
56   if (OpIt != OpItE) {
57     ++OpIt;
58     // Skip operands that are not instructions or are outside the DAG.
59     OpIt = PredIterator::skipBadIt(OpIt, OpItE, *DAG);
60     return *this;
61   }
62   // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
63   assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
64   ++MemIt;
65   return *this;
66 }
67 
operator ==(const PredIterator & Other) const68 bool PredIterator::operator==(const PredIterator &Other) const {
69   assert(DAG == Other.DAG && "Iterators of different DAGs!");
70   assert(N == Other.N && "Iterators of different nodes!");
71   return OpIt == Other.OpIt && MemIt == Other.MemIt;
72 }
73 
setSchedBundle(SchedBundle & SB)74 void DGNode::setSchedBundle(SchedBundle &SB) {
75   if (this->SB != nullptr)
76     this->SB->eraseFromBundle(this);
77   this->SB = &SB;
78 }
79 
~DGNode()80 DGNode::~DGNode() {
81   if (SB == nullptr)
82     return;
83   SB->eraseFromBundle(this);
84 }
85 
86 #ifndef NDEBUG
print(raw_ostream & OS,bool PrintDeps) const87 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
88   OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
89 }
dump() const90 void DGNode::dump() const { print(dbgs()); }
print(raw_ostream & OS,bool PrintDeps) const91 void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
92   DGNode::print(OS, false);
93   if (PrintDeps) {
94     // Print memory preds.
95     static constexpr const unsigned Indent = 4;
96     for (auto *Pred : MemPreds)
97       OS.indent(Indent) << "<-" << *Pred->getInstruction() << "\n";
98   }
99 }
100 #endif // NDEBUG
101 
102 MemDGNode *
getTopMemDGNode(const Interval<Instruction> & Intvl,const DependencyGraph & DAG)103 MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
104                                           const DependencyGraph &DAG) {
105   Instruction *I = Intvl.top();
106   Instruction *BeforeI = Intvl.bottom();
107   // Walk down the chain looking for a mem-dep candidate instruction.
108   while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
109     I = I->getNextNode();
110   if (!DGNode::isMemDepNodeCandidate(I))
111     return nullptr;
112   return cast<MemDGNode>(DAG.getNode(I));
113 }
114 
115 MemDGNode *
getBotMemDGNode(const Interval<Instruction> & Intvl,const DependencyGraph & DAG)116 MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
117                                           const DependencyGraph &DAG) {
118   Instruction *I = Intvl.bottom();
119   Instruction *AfterI = Intvl.top();
120   // Walk up the chain looking for a mem-dep candidate instruction.
121   while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
122     I = I->getPrevNode();
123   if (!DGNode::isMemDepNodeCandidate(I))
124     return nullptr;
125   return cast<MemDGNode>(DAG.getNode(I));
126 }
127 
128 Interval<MemDGNode>
make(const Interval<Instruction> & Instrs,DependencyGraph & DAG)129 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
130                                DependencyGraph &DAG) {
131   if (Instrs.empty())
132     return {};
133   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
134   // If we couldn't find a mem node in range TopN - BotN then it's empty.
135   if (TopMemN == nullptr)
136     return {};
137   auto *BotMemN = getBotMemDGNode(Instrs, DAG);
138   assert(BotMemN != nullptr && "TopMemN should be null too!");
139   // Now that we have the mem-dep nodes, create and return the range.
140   return Interval<MemDGNode>(TopMemN, BotMemN);
141 }
142 
143 DependencyGraph::DependencyType
getRoughDepType(Instruction * FromI,Instruction * ToI)144 DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) {
145   // TODO: Perhaps compile-time improvement by skipping if neither is mem?
146   if (FromI->mayWriteToMemory()) {
147     if (ToI->mayReadFromMemory())
148       return DependencyType::ReadAfterWrite;
149     if (ToI->mayWriteToMemory())
150       return DependencyType::WriteAfterWrite;
151   } else if (FromI->mayReadFromMemory()) {
152     if (ToI->mayWriteToMemory())
153       return DependencyType::WriteAfterRead;
154   }
155   if (isa<sandboxir::PHINode>(FromI) || isa<sandboxir::PHINode>(ToI))
156     return DependencyType::Control;
157   if (ToI->isTerminator())
158     return DependencyType::Control;
159   if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) ||
160       DGNode::isStackSaveOrRestoreIntrinsic(ToI))
161     return DependencyType::Other;
162   return DependencyType::None;
163 }
164 
isOrdered(Instruction * I)165 static bool isOrdered(Instruction *I) {
166   auto IsOrdered = [](Instruction *I) {
167     if (auto *LI = dyn_cast<LoadInst>(I))
168       return !LI->isUnordered();
169     if (auto *SI = dyn_cast<StoreInst>(I))
170       return !SI->isUnordered();
171     if (DGNode::isFenceLike(I))
172       return true;
173     return false;
174   };
175   bool Is = IsOrdered(I);
176   assert((!Is || DGNode::isMemDepCandidate(I)) &&
177          "An ordered instruction must be a MemDepCandidate!");
178   return Is;
179 }
180 
alias(Instruction * SrcI,Instruction * DstI,DependencyType DepType)181 bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI,
182                             DependencyType DepType) {
183   std::optional<MemoryLocation> DstLocOpt =
184       Utils::memoryLocationGetOrNone(DstI);
185   if (!DstLocOpt)
186     return true;
187   // Check aliasing.
188   assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) &&
189          "Expected a mem instr");
190   // TODO: Check AABudget
191   ModRefInfo SrcModRef =
192       isOrdered(SrcI)
193           ? ModRefInfo::ModRef
194           : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt);
195   switch (DepType) {
196   case DependencyType::ReadAfterWrite:
197   case DependencyType::WriteAfterWrite:
198     return isModSet(SrcModRef);
199   case DependencyType::WriteAfterRead:
200     return isRefSet(SrcModRef);
201   default:
202     llvm_unreachable("Expected only RAW, WAW and WAR!");
203   }
204 }
205 
hasDep(Instruction * SrcI,Instruction * DstI)206 bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
207   DependencyType RoughDepType = getRoughDepType(SrcI, DstI);
208   switch (RoughDepType) {
209   case DependencyType::ReadAfterWrite:
210   case DependencyType::WriteAfterWrite:
211   case DependencyType::WriteAfterRead:
212     return alias(SrcI, DstI, RoughDepType);
213   case DependencyType::Control:
214     // Adding actual dep edges from PHIs/to terminator would just create too
215     // many edges, which would be bad for compile-time.
216     // So we ignore them in the DAG formation but handle them in the
217     // scheduler, while sorting the ready list.
218     return false;
219   case DependencyType::Other:
220     return true;
221   case DependencyType::None:
222     return false;
223   }
224   llvm_unreachable("Unknown DependencyType enum");
225 }
226 
scanAndAddDeps(MemDGNode & DstN,const Interval<MemDGNode> & SrcScanRange)227 void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
228                                      const Interval<MemDGNode> &SrcScanRange) {
229   assert(isa<MemDGNode>(DstN) &&
230          "DstN is the mem dep destination, so it must be mem");
231   Instruction *DstI = DstN.getInstruction();
232   // Walk up the instruction chain from ScanRange bottom to top, looking for
233   // memory instrs that may alias.
234   for (MemDGNode &SrcN : reverse(SrcScanRange)) {
235     Instruction *SrcI = SrcN.getInstruction();
236     if (hasDep(SrcI, DstI))
237       DstN.addMemPred(&SrcN);
238   }
239 }
240 
setDefUseUnscheduledSuccs(const Interval<Instruction> & NewInterval)241 void DependencyGraph::setDefUseUnscheduledSuccs(
242     const Interval<Instruction> &NewInterval) {
243   // +---+
244   // |   |  Def
245   // |   |   |
246   // |   |   v
247   // |   |  Use
248   // +---+
249   // Set the intra-interval counters in NewInterval.
250   for (Instruction &I : NewInterval) {
251     for (Value *Op : I.operands()) {
252       auto *OpI = dyn_cast<Instruction>(Op);
253       if (OpI == nullptr)
254         continue;
255       // TODO: For now don't cross BBs.
256       if (OpI->getParent() != I.getParent())
257         continue;
258       if (!NewInterval.contains(OpI))
259         continue;
260       auto *OpN = getNode(OpI);
261       if (OpN == nullptr)
262         continue;
263       ++OpN->UnscheduledSuccs;
264     }
265   }
266 
267   // Now handle the cross-interval edges.
268   bool NewIsAbove = DAGInterval.empty() || NewInterval.comesBefore(DAGInterval);
269   const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
270   const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
271   // +---+
272   // |Top|
273   // |   |  Def
274   // +---+   |
275   // |   |   v
276   // |Bot|  Use
277   // |   |
278   // +---+
279   // Walk over all instructions in "BotInterval" and update the counter
280   // of operands that are in "TopInterval".
281   for (Instruction &BotI : BotInterval) {
282     auto *BotN = getNode(&BotI);
283     // Skip scheduled nodes.
284     if (BotN->scheduled())
285       continue;
286     for (Value *Op : BotI.operands()) {
287       auto *OpI = dyn_cast<Instruction>(Op);
288       if (OpI == nullptr)
289         continue;
290       auto *OpN = getNode(OpI);
291       if (OpN == nullptr)
292         continue;
293       if (!TopInterval.contains(OpI))
294         continue;
295       ++OpN->UnscheduledSuccs;
296     }
297   }
298 }
299 
createNewNodes(const Interval<Instruction> & NewInterval)300 void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
301   // Create Nodes only for the new sections of the DAG.
302   DGNode *LastN = getOrCreateNode(NewInterval.top());
303   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
304   for (Instruction &I : drop_begin(NewInterval)) {
305     auto *N = getOrCreateNode(&I);
306     // Build the Mem node chain.
307     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
308       MemN->setPrevNode(LastMemN);
309       LastMemN = MemN;
310     }
311   }
312   // Link new MemDGNode chain with the old one, if any.
313   if (!DAGInterval.empty()) {
314     bool NewIsAbove = NewInterval.comesBefore(DAGInterval);
315     const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
316     const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
317     MemDGNode *LinkTopN =
318         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
319     MemDGNode *LinkBotN =
320         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
321     assert((LinkTopN == nullptr || LinkBotN == nullptr ||
322             LinkTopN->comesBefore(LinkBotN)) &&
323            "Wrong order!");
324     if (LinkTopN != nullptr && LinkBotN != nullptr) {
325       LinkTopN->setNextNode(LinkBotN);
326     }
327 #ifndef NDEBUG
328     // TODO: Remove this once we've done enough testing.
329     // Check that the chain is well formed.
330     auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
331     MemDGNode *ChainTopN =
332         MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
333     MemDGNode *ChainBotN =
334         MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
335     if (ChainTopN != nullptr && ChainBotN != nullptr) {
336       for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
337            LastN = N, N = N->getNextNode()) {
338         assert(N == LastN->getNextNode() && "Bad chain!");
339         assert(N->getPrevNode() == LastN && "Bad chain!");
340       }
341     }
342 #endif // NDEBUG
343   }
344 
345   setDefUseUnscheduledSuccs(NewInterval);
346 }
347 
getMemDGNodeBefore(DGNode * N,bool IncludingN,MemDGNode * SkipN) const348 MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
349                                                MemDGNode *SkipN) const {
350   auto *I = N->getInstruction();
351   for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
352        PrevI = PrevI->getPrevNode()) {
353     auto *PrevN = getNodeOrNull(PrevI);
354     if (PrevN == nullptr)
355       return nullptr;
356     auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
357     if (PrevMemN != nullptr && PrevMemN != SkipN)
358       return PrevMemN;
359   }
360   return nullptr;
361 }
362 
getMemDGNodeAfter(DGNode * N,bool IncludingN,MemDGNode * SkipN) const363 MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
364                                               MemDGNode *SkipN) const {
365   auto *I = N->getInstruction();
366   for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
367        NextI = NextI->getNextNode()) {
368     auto *NextN = getNodeOrNull(NextI);
369     if (NextN == nullptr)
370       return nullptr;
371     auto *NextMemN = dyn_cast<MemDGNode>(NextN);
372     if (NextMemN != nullptr && NextMemN != SkipN)
373       return NextMemN;
374   }
375   return nullptr;
376 }
377 
notifyCreateInstr(Instruction * I)378 void DependencyGraph::notifyCreateInstr(Instruction *I) {
379   if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting)
380     // We don't maintain the DAG while reverting.
381     return;
382   // Nothing to do if the node is not in the focus range of the DAG.
383   if (!(DAGInterval.contains(I) || DAGInterval.touches(I)))
384     return;
385   // Include `I` into the interval.
386   DAGInterval = DAGInterval.getUnionInterval({I, I});
387   auto *N = getOrCreateNode(I);
388   auto *MemN = dyn_cast<MemDGNode>(N);
389 
390   // Update the MemDGNode chain if this is a memory node.
391   if (MemN != nullptr) {
392     if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) {
393       PrevMemN->NextMemN = MemN;
394       MemN->PrevMemN = PrevMemN;
395     }
396     if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) {
397       NextMemN->PrevMemN = MemN;
398       MemN->NextMemN = NextMemN;
399     }
400 
401     // Add Mem dependencies.
402     // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
403     if (DAGInterval.top()->comesBefore(I)) {
404       Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
405       auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this);
406       scanAndAddDeps(*MemN, SrcInterval);
407     }
408     // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
409     if (I->comesBefore(DAGInterval.bottom())) {
410       Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
411       for (MemDGNode &BelowN :
412            MemDGNodeIntervalBuilder::make(BelowIntvl, *this))
413         scanAndAddDeps(BelowN, Interval<MemDGNode>(MemN, MemN));
414     }
415   }
416 }
417 
notifyMoveInstr(Instruction * I,const BBIterator & To)418 void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
419   if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting)
420     // We don't maintain the DAG while reverting.
421     return;
422   // NOTE: This function runs before `I` moves to its new destination.
423   BasicBlock *BB = To.getNodeParent();
424   assert(!(To != BB->end() && &*To == I->getNextNode()) &&
425          !(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
426          "Should not have been called if destination is same as origin.");
427 
428   // TODO: We can only handle fully internal movements within DAGInterval or at
429   // the borders, i.e., right before the top or right after the bottom.
430   assert(To.getNodeParent() == I->getParent() &&
431          "TODO: We don't support movement across BBs!");
432   assert(
433       (To == std::next(DAGInterval.bottom()->getIterator()) ||
434        (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
435        (To != BB->end() && DAGInterval.contains(&*To))) &&
436       "TODO: To should be either within the DAGInterval or right "
437       "before/after it.");
438 
439   // Make a copy of the DAGInterval before we update it.
440   auto OrigDAGInterval = DAGInterval;
441 
442   // Maintain the DAGInterval.
443   DAGInterval.notifyMoveInstr(I, To);
444 
445   // TODO: Perhaps check if this is legal by checking the dependencies?
446 
447   // Update the MemDGNode chain to reflect the instr movement if necessary.
448   DGNode *N = getNodeOrNull(I);
449   if (N == nullptr)
450     return;
451   MemDGNode *MemN = dyn_cast<MemDGNode>(N);
452   if (MemN == nullptr)
453     return;
454 
455   // First safely detach it from the existing chain.
456   MemN->detachFromChain();
457 
458   // Now insert it back into the chain at the new location.
459   //
460   // We won't always have a DGNode to insert before it. If `To` is BB->end() or
461   // if it points to an instr after DAGInterval.bottom() then we will have to
462   // find a node to insert *after*.
463   //
464   // BB:                              BB:
465   //  I1                               I1 ^
466   //  I2                               I2 | DAGInteval [I1 to I3]
467   //  I3                               I3 V
468   //  I4                               I4   <- `To` == right after DAGInterval
469   //    <- `To` == BB->end()
470   //
471   if (To == BB->end() ||
472       To == std::next(OrigDAGInterval.bottom()->getIterator())) {
473     // If we don't have a node to insert before, find a node to insert after and
474     // update the chain.
475     DGNode *InsertAfterN = getNode(&*std::prev(To));
476     MemN->setPrevNode(
477         getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
478   } else {
479     // We have a node to insert before, so update the chain.
480     DGNode *BeforeToN = getNode(&*To);
481     MemN->setPrevNode(
482         getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
483     MemN->setNextNode(
484         getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
485   }
486 }
487 
notifyEraseInstr(Instruction * I)488 void DependencyGraph::notifyEraseInstr(Instruction *I) {
489   if (Ctx->getTracker().getState() == Tracker::TrackerState::Reverting)
490     // We don't maintain the DAG while reverting.
491     return;
492   auto *N = getNode(I);
493   if (N == nullptr)
494     // Early return if there is no DAG node for `I`.
495     return;
496   if (auto *MemN = dyn_cast<MemDGNode>(getNode(I))) {
497     // Update the MemDGNode chain if this is a memory node.
498     auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false);
499     auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false);
500     if (PrevMemN != nullptr)
501       PrevMemN->NextMemN = NextMemN;
502     if (NextMemN != nullptr)
503       NextMemN->PrevMemN = PrevMemN;
504 
505     // Drop the memory dependencies from both predecessors and successors.
506     while (!MemN->memPreds().empty()) {
507       auto *PredN = *MemN->memPreds().begin();
508       MemN->removeMemPred(PredN);
509     }
510     while (!MemN->memSuccs().empty()) {
511       auto *SuccN = *MemN->memSuccs().begin();
512       SuccN->removeMemPred(MemN);
513     }
514     // NOTE: The unscheduled succs for MemNodes get updated be setMemPred().
515   } else {
516     // If this is a non-mem node we only need to update UnscheduledSuccs.
517     if (!N->scheduled())
518       for (auto *PredN : N->preds(*this))
519         PredN->decrUnscheduledSuccs();
520   }
521   // Finally erase the Node.
522   InstrToNodeMap.erase(I);
523 }
524 
notifySetUse(const Use & U,Value * NewSrc)525 void DependencyGraph::notifySetUse(const Use &U, Value *NewSrc) {
526   // Update the UnscheduledSuccs counter for both the current source and NewSrc
527   // if needed.
528   if (auto *CurrSrcI = dyn_cast<Instruction>(U.get())) {
529     if (auto *CurrSrcN = getNode(CurrSrcI)) {
530       CurrSrcN->decrUnscheduledSuccs();
531     }
532   }
533   if (auto *NewSrcI = dyn_cast<Instruction>(NewSrc)) {
534     if (auto *NewSrcN = getNode(NewSrcI)) {
535       ++NewSrcN->UnscheduledSuccs;
536     }
537   }
538 }
539 
extend(ArrayRef<Instruction * > Instrs)540 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
541   if (Instrs.empty())
542     return {};
543 
544   Interval<Instruction> InstrsInterval(Instrs);
545   Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
546   auto NewInterval = Union.getSingleDiff(DAGInterval);
547   if (NewInterval.empty())
548     return {};
549 
550   createNewNodes(NewInterval);
551 
552   // Create the dependencies.
553   //
554   // 1. This is a new DAG, DAGInterval is empty. Fully scan the whole interval.
555   // +---+       -             -
556   // |   | SrcN  |             |
557   // |   |  |    | SrcRange    |
558   // |New|  v    |             | DstRange
559   // |   | DstN  -             |
560   // |   |                     |
561   // +---+                     -
562   // We are scanning for deps with destination in NewInterval and sources in
563   // NewInterval until DstN, for each DstN.
564   auto FullScan = [this](const Interval<Instruction> Intvl) {
565     auto DstRange = MemDGNodeIntervalBuilder::make(Intvl, *this);
566     if (!DstRange.empty()) {
567       for (MemDGNode &DstN : drop_begin(DstRange)) {
568         auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
569         scanAndAddDeps(DstN, SrcRange);
570       }
571     }
572   };
573   auto MemDAGInterval = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
574   if (MemDAGInterval.empty()) {
575     FullScan(NewInterval);
576   }
577   // 2. The new section is below the old section.
578   // +---+       -
579   // |   |       |
580   // |Old| SrcN  |
581   // |   |  |    |
582   // +---+  |    | SrcRange
583   // +---+  |    |             -
584   // |   |  |    |             |
585   // |New|  v    |             | DstRange
586   // |   | DstN  -             |
587   // |   |                     |
588   // +---+                     -
589   // We are scanning for deps with destination in NewInterval because the deps
590   // in DAGInterval have already been computed. We consider sources in the whole
591   // range including both NewInterval and DAGInterval until DstN, for each DstN.
592   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
593     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
594     auto SrcRangeFull = MemDAGInterval.getUnionInterval(DstRange);
595     for (MemDGNode &DstN : DstRange) {
596       auto SrcRange =
597           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
598       scanAndAddDeps(DstN, SrcRange);
599     }
600   }
601   // 3. The new section is above the old section.
602   else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
603     // +---+       -             -
604     // |   | SrcN  |             |
605     // |New|  |    | SrcRange    | DstRange
606     // |   |  v    |             |
607     // |   | DstN  -             |
608     // |   |                     |
609     // +---+                     -
610     // +---+
611     // |Old|
612     // |   |
613     // +---+
614     // When scanning for deps with destination in NewInterval we need to fully
615     // scan the interval. This is the same as the scanning for a new DAG.
616     FullScan(NewInterval);
617 
618     // +---+       -
619     // |   |       |
620     // |New| SrcN  | SrcRange
621     // |   |  |    |
622     // |   |  |    |
623     // |   |  |    |
624     // +---+  |    -
625     // +---+  |                  -
626     // |Old|  v                  | DstRange
627     // |   | DstN                |
628     // +---+                     -
629     // When scanning for deps with destination in DAGInterval we need to
630     // consider sources from the NewInterval only, because all intra-DAGInterval
631     // dependencies have already been created.
632     auto DstRangeOld = MemDAGInterval;
633     auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
634     for (MemDGNode &DstN : DstRangeOld)
635       scanAndAddDeps(DstN, SrcRange);
636   } else {
637     llvm_unreachable("We don't expect extending in both directions!");
638   }
639 
640   DAGInterval = Union;
641   return NewInterval;
642 }
643 
644 #ifndef NDEBUG
print(raw_ostream & OS) const645 void DependencyGraph::print(raw_ostream &OS) const {
646   // InstrToNodeMap is unordered so we need to create an ordered vector.
647   SmallVector<DGNode *> Nodes;
648   Nodes.reserve(InstrToNodeMap.size());
649   for (const auto &Pair : InstrToNodeMap)
650     Nodes.push_back(Pair.second.get());
651   // Sort them based on which one comes first in the BB.
652   sort(Nodes, [](DGNode *N1, DGNode *N2) {
653     return N1->getInstruction()->comesBefore(N2->getInstruction());
654   });
655   for (auto *N : Nodes)
656     N->print(OS, /*PrintDeps=*/true);
657 }
658 
dump() const659 void DependencyGraph::dump() const {
660   print(dbgs());
661   dbgs() << "\n";
662 }
663 #endif // NDEBUG
664 
665 } // namespace llvm::sandboxir
666