xref: /freebsd/contrib/llvm-project/llvm/include/llvm/CodeGen/RegAllocPBQP.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- RegAllocPBQP.h -------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the PBQPBuilder interface, for classes which build PBQP
10 // instances to represent register allocation problems, and the RegAllocPBQP
11 // interface.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_CODEGEN_REGALLOCPBQP_H
16 #define LLVM_CODEGEN_REGALLOCPBQP_H
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/CodeGen/PBQP/CostAllocator.h"
21 #include "llvm/CodeGen/PBQP/Graph.h"
22 #include "llvm/CodeGen/PBQP/Math.h"
23 #include "llvm/CodeGen/PBQP/ReductionRules.h"
24 #include "llvm/CodeGen/PBQP/Solution.h"
25 #include "llvm/CodeGen/Register.h"
26 #include "llvm/MC/MCRegister.h"
27 #include "llvm/Support/ErrorHandling.h"
28 #include <algorithm>
29 #include <cassert>
30 #include <cstddef>
31 #include <limits>
32 #include <memory>
33 #include <set>
34 #include <vector>
35 
36 namespace llvm {
37 
38 class FunctionPass;
39 class LiveIntervals;
40 class MachineBlockFrequencyInfo;
41 class MachineFunction;
42 class raw_ostream;
43 
44 namespace PBQP {
45 namespace RegAlloc {
46 
/// Index of the spill option.
///
/// \returns 0, the position used for the spill option.
inline unsigned getSpillOptionIdx() {
  constexpr unsigned SpillOptionIdx = 0;
  return SpillOptionIdx;
}
49 
50 /// Metadata to speed allocatability test.
51 ///
52 /// Keeps track of the number of infinities in each row and column.
53 class MatrixMetadata {
54 public:
MatrixMetadata(const Matrix & M)55   MatrixMetadata(const Matrix& M)
56     : UnsafeRows(new bool[M.getRows() - 1]()),
57       UnsafeCols(new bool[M.getCols() - 1]()) {
58     unsigned* ColCounts = new unsigned[M.getCols() - 1]();
59 
60     for (unsigned i = 1; i < M.getRows(); ++i) {
61       unsigned RowCount = 0;
62       for (unsigned j = 1; j < M.getCols(); ++j) {
63         if (M[i][j] == std::numeric_limits<PBQPNum>::infinity()) {
64           ++RowCount;
65           ++ColCounts[j - 1];
66           UnsafeRows[i - 1] = true;
67           UnsafeCols[j - 1] = true;
68         }
69       }
70       WorstRow = std::max(WorstRow, RowCount);
71     }
72     unsigned WorstColCountForCurRow =
73       *std::max_element(ColCounts, ColCounts + M.getCols() - 1);
74     WorstCol = std::max(WorstCol, WorstColCountForCurRow);
75     delete[] ColCounts;
76   }
77 
78   MatrixMetadata(const MatrixMetadata &) = delete;
79   MatrixMetadata &operator=(const MatrixMetadata &) = delete;
80 
getWorstRow()81   unsigned getWorstRow() const { return WorstRow; }
getWorstCol()82   unsigned getWorstCol() const { return WorstCol; }
getUnsafeRows()83   const bool* getUnsafeRows() const { return UnsafeRows.get(); }
getUnsafeCols()84   const bool* getUnsafeCols() const { return UnsafeCols.get(); }
85 
86 private:
87   unsigned WorstRow = 0;
88   unsigned WorstCol = 0;
89   std::unique_ptr<bool[]> UnsafeRows;
90   std::unique_ptr<bool[]> UnsafeCols;
91 };
92 
93 /// Holds a vector of the allowed physical regs for a vreg.
94 class AllowedRegVector {
95   friend hash_code hash_value(const AllowedRegVector &);
96 
97 public:
98   AllowedRegVector() = default;
99   AllowedRegVector(AllowedRegVector &&) = default;
100 
AllowedRegVector(const std::vector<MCRegister> & OptVec)101   AllowedRegVector(const std::vector<MCRegister> &OptVec)
102       : NumOpts(OptVec.size()), Opts(new MCRegister[NumOpts]) {
103     std::copy(OptVec.begin(), OptVec.end(), Opts.get());
104   }
105 
size()106   unsigned size() const { return NumOpts; }
107   MCRegister operator[](size_t I) const { return Opts[I]; }
108 
109   bool operator==(const AllowedRegVector &Other) const {
110     if (NumOpts != Other.NumOpts)
111       return false;
112     return std::equal(Opts.get(), Opts.get() + NumOpts, Other.Opts.get());
113   }
114 
115   bool operator!=(const AllowedRegVector &Other) const {
116     return !(*this == Other);
117   }
118 
119 private:
120   unsigned NumOpts = 0;
121   std::unique_ptr<MCRegister[]> Opts;
122 };
123 
hash_value(const AllowedRegVector & OptRegs)124 inline hash_code hash_value(const AllowedRegVector &OptRegs) {
125   MCRegister *OStart = OptRegs.Opts.get();
126   MCRegister *OEnd = OptRegs.Opts.get() + OptRegs.NumOpts;
127   return hash_combine(OptRegs.NumOpts,
128                       hash_combine_range(OStart, OEnd));
129 }
130 
131 /// Holds graph-level metadata relevant to PBQP RA problems.
132 class GraphMetadata {
133 private:
134   using AllowedRegVecPool = ValuePool<AllowedRegVector>;
135 
136 public:
137   using AllowedRegVecRef = AllowedRegVecPool::PoolRef;
138 
GraphMetadata(MachineFunction & MF,LiveIntervals & LIS,MachineBlockFrequencyInfo & MBFI)139   GraphMetadata(MachineFunction &MF,
140                 LiveIntervals &LIS,
141                 MachineBlockFrequencyInfo &MBFI)
142     : MF(MF), LIS(LIS), MBFI(MBFI) {}
143 
144   MachineFunction &MF;
145   LiveIntervals &LIS;
146   MachineBlockFrequencyInfo &MBFI;
147 
setNodeIdForVReg(Register VReg,GraphBase::NodeId NId)148   void setNodeIdForVReg(Register VReg, GraphBase::NodeId NId) {
149     VRegToNodeId[VReg.id()] = NId;
150   }
151 
getNodeIdForVReg(Register VReg)152   GraphBase::NodeId getNodeIdForVReg(Register VReg) const {
153     auto VRegItr = VRegToNodeId.find(VReg);
154     if (VRegItr == VRegToNodeId.end())
155       return GraphBase::invalidNodeId();
156     return VRegItr->second;
157   }
158 
getAllowedRegs(AllowedRegVector Allowed)159   AllowedRegVecRef getAllowedRegs(AllowedRegVector Allowed) {
160     return AllowedRegVecs.getValue(std::move(Allowed));
161   }
162 
163 private:
164   DenseMap<Register, GraphBase::NodeId> VRegToNodeId;
165   AllowedRegVecPool AllowedRegVecs;
166 };
167 
168 /// Holds solver state and other metadata relevant to each PBQP RA node.
169 class NodeMetadata {
170 public:
171   using AllowedRegVector = RegAlloc::AllowedRegVector;
172 
173   // The node's reduction state. The order in this enum is important,
174   // as it is assumed nodes can only progress up (i.e. towards being
175   // optimally reducible) when reducing the graph.
176   using ReductionState = enum {
177     Unprocessed,
178     NotProvablyAllocatable,
179     ConservativelyAllocatable,
180     OptimallyReducible
181   };
182 
183   NodeMetadata() = default;
184 
NodeMetadata(const NodeMetadata & Other)185   NodeMetadata(const NodeMetadata &Other)
186       : RS(Other.RS), NumOpts(Other.NumOpts), DeniedOpts(Other.DeniedOpts),
187         OptUnsafeEdges(new unsigned[NumOpts]), VReg(Other.VReg),
188         AllowedRegs(Other.AllowedRegs)
189 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
190         ,
191         everConservativelyAllocatable(Other.everConservativelyAllocatable)
192 #endif
193   {
194     if (NumOpts > 0) {
195       std::copy(&Other.OptUnsafeEdges[0], &Other.OptUnsafeEdges[NumOpts],
196                 &OptUnsafeEdges[0]);
197     }
198   }
199 
200   NodeMetadata(NodeMetadata &&) = default;
201   NodeMetadata& operator=(NodeMetadata &&) = default;
202 
setVReg(Register VReg)203   void setVReg(Register VReg) { this->VReg = VReg; }
getVReg()204   Register getVReg() const { return VReg; }
205 
setAllowedRegs(GraphMetadata::AllowedRegVecRef AllowedRegs)206   void setAllowedRegs(GraphMetadata::AllowedRegVecRef AllowedRegs) {
207     this->AllowedRegs = std::move(AllowedRegs);
208   }
getAllowedRegs()209   const AllowedRegVector& getAllowedRegs() const { return *AllowedRegs; }
210 
setup(const Vector & Costs)211   void setup(const Vector& Costs) {
212     NumOpts = Costs.getLength() - 1;
213     OptUnsafeEdges = std::unique_ptr<unsigned[]>(new unsigned[NumOpts]());
214   }
215 
getReductionState()216   ReductionState getReductionState() const { return RS; }
setReductionState(ReductionState RS)217   void setReductionState(ReductionState RS) {
218     assert(RS >= this->RS && "A node's reduction state can not be downgraded");
219     this->RS = RS;
220 
221 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
222     // Remember this state to assert later that a non-infinite register
223     // option was available.
224     if (RS == ConservativelyAllocatable)
225       everConservativelyAllocatable = true;
226 #endif
227   }
228 
handleAddEdge(const MatrixMetadata & MD,bool Transpose)229   void handleAddEdge(const MatrixMetadata& MD, bool Transpose) {
230     DeniedOpts += Transpose ? MD.getWorstRow() : MD.getWorstCol();
231     const bool* UnsafeOpts =
232       Transpose ? MD.getUnsafeCols() : MD.getUnsafeRows();
233     for (unsigned i = 0; i < NumOpts; ++i)
234       OptUnsafeEdges[i] += UnsafeOpts[i];
235   }
236 
handleRemoveEdge(const MatrixMetadata & MD,bool Transpose)237   void handleRemoveEdge(const MatrixMetadata& MD, bool Transpose) {
238     DeniedOpts -= Transpose ? MD.getWorstRow() : MD.getWorstCol();
239     const bool* UnsafeOpts =
240       Transpose ? MD.getUnsafeCols() : MD.getUnsafeRows();
241     for (unsigned i = 0; i < NumOpts; ++i)
242       OptUnsafeEdges[i] -= UnsafeOpts[i];
243   }
244 
isConservativelyAllocatable()245   bool isConservativelyAllocatable() const {
246     return (DeniedOpts < NumOpts) ||
247       (std::find(&OptUnsafeEdges[0], &OptUnsafeEdges[NumOpts], 0) !=
248        &OptUnsafeEdges[NumOpts]);
249   }
250 
251 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
wasConservativelyAllocatable()252   bool wasConservativelyAllocatable() const {
253     return everConservativelyAllocatable;
254   }
255 #endif
256 
257 private:
258   ReductionState RS = Unprocessed;
259   unsigned NumOpts = 0;
260   unsigned DeniedOpts = 0;
261   std::unique_ptr<unsigned[]> OptUnsafeEdges;
262   Register VReg;
263   GraphMetadata::AllowedRegVecRef AllowedRegs;
264 
265 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
266   bool everConservativelyAllocatable = false;
267 #endif
268 };
269 
270 class RegAllocSolverImpl {
271 private:
272   using RAMatrix = MDMatrix<MatrixMetadata>;
273 
274 public:
275   using RawVector = PBQP::Vector;
276   using RawMatrix = PBQP::Matrix;
277   using Vector = PBQP::Vector;
278   using Matrix = RAMatrix;
279   using CostAllocator = PBQP::PoolCostAllocator<Vector, Matrix>;
280 
281   using NodeId = GraphBase::NodeId;
282   using EdgeId = GraphBase::EdgeId;
283 
284   using NodeMetadata = RegAlloc::NodeMetadata;
285   struct EdgeMetadata {};
286   using GraphMetadata = RegAlloc::GraphMetadata;
287 
288   using Graph = PBQP::Graph<RegAllocSolverImpl>;
289 
RegAllocSolverImpl(Graph & G)290   RegAllocSolverImpl(Graph &G) : G(G) {}
291 
solve()292   Solution solve() {
293     G.setSolver(*this);
294     Solution S;
295     setup();
296     S = backpropagate(G, reduce());
297     G.unsetSolver();
298     return S;
299   }
300 
handleAddNode(NodeId NId)301   void handleAddNode(NodeId NId) {
302     assert(G.getNodeCosts(NId).getLength() > 1 &&
303            "PBQP Graph should not contain single or zero-option nodes");
304     G.getNodeMetadata(NId).setup(G.getNodeCosts(NId));
305   }
306 
handleRemoveNode(NodeId NId)307   void handleRemoveNode(NodeId NId) {}
handleSetNodeCosts(NodeId NId,const Vector & newCosts)308   void handleSetNodeCosts(NodeId NId, const Vector& newCosts) {}
309 
handleAddEdge(EdgeId EId)310   void handleAddEdge(EdgeId EId) {
311     handleReconnectEdge(EId, G.getEdgeNode1Id(EId));
312     handleReconnectEdge(EId, G.getEdgeNode2Id(EId));
313   }
314 
handleDisconnectEdge(EdgeId EId,NodeId NId)315   void handleDisconnectEdge(EdgeId EId, NodeId NId) {
316     NodeMetadata& NMd = G.getNodeMetadata(NId);
317     const MatrixMetadata& MMd = G.getEdgeCosts(EId).getMetadata();
318     NMd.handleRemoveEdge(MMd, NId == G.getEdgeNode2Id(EId));
319     promote(NId, NMd);
320   }
321 
handleReconnectEdge(EdgeId EId,NodeId NId)322   void handleReconnectEdge(EdgeId EId, NodeId NId) {
323     NodeMetadata& NMd = G.getNodeMetadata(NId);
324     const MatrixMetadata& MMd = G.getEdgeCosts(EId).getMetadata();
325     NMd.handleAddEdge(MMd, NId == G.getEdgeNode2Id(EId));
326   }
327 
handleUpdateCosts(EdgeId EId,const Matrix & NewCosts)328   void handleUpdateCosts(EdgeId EId, const Matrix& NewCosts) {
329     NodeId N1Id = G.getEdgeNode1Id(EId);
330     NodeId N2Id = G.getEdgeNode2Id(EId);
331     NodeMetadata& N1Md = G.getNodeMetadata(N1Id);
332     NodeMetadata& N2Md = G.getNodeMetadata(N2Id);
333     bool Transpose = N1Id != G.getEdgeNode1Id(EId);
334 
335     // Metadata are computed incrementally. First, update them
336     // by removing the old cost.
337     const MatrixMetadata& OldMMd = G.getEdgeCosts(EId).getMetadata();
338     N1Md.handleRemoveEdge(OldMMd, Transpose);
339     N2Md.handleRemoveEdge(OldMMd, !Transpose);
340 
341     // And update now the metadata with the new cost.
342     const MatrixMetadata& MMd = NewCosts.getMetadata();
343     N1Md.handleAddEdge(MMd, Transpose);
344     N2Md.handleAddEdge(MMd, !Transpose);
345 
346     // As the metadata may have changed with the update, the nodes may have
347     // become ConservativelyAllocatable or OptimallyReducible.
348     promote(N1Id, N1Md);
349     promote(N2Id, N2Md);
350   }
351 
352 private:
promote(NodeId NId,NodeMetadata & NMd)353   void promote(NodeId NId, NodeMetadata& NMd) {
354     if (G.getNodeDegree(NId) == 3) {
355       // This node is becoming optimally reducible.
356       moveToOptimallyReducibleNodes(NId);
357     } else if (NMd.getReductionState() ==
358                NodeMetadata::NotProvablyAllocatable &&
359                NMd.isConservativelyAllocatable()) {
360       // This node just became conservatively allocatable.
361       moveToConservativelyAllocatableNodes(NId);
362     }
363   }
364 
removeFromCurrentSet(NodeId NId)365   void removeFromCurrentSet(NodeId NId) {
366     switch (G.getNodeMetadata(NId).getReductionState()) {
367     case NodeMetadata::Unprocessed: break;
368     case NodeMetadata::OptimallyReducible:
369       assert(OptimallyReducibleNodes.find(NId) !=
370              OptimallyReducibleNodes.end() &&
371              "Node not in optimally reducible set.");
372       OptimallyReducibleNodes.erase(NId);
373       break;
374     case NodeMetadata::ConservativelyAllocatable:
375       assert(ConservativelyAllocatableNodes.find(NId) !=
376              ConservativelyAllocatableNodes.end() &&
377              "Node not in conservatively allocatable set.");
378       ConservativelyAllocatableNodes.erase(NId);
379       break;
380     case NodeMetadata::NotProvablyAllocatable:
381       assert(NotProvablyAllocatableNodes.find(NId) !=
382              NotProvablyAllocatableNodes.end() &&
383              "Node not in not-provably-allocatable set.");
384       NotProvablyAllocatableNodes.erase(NId);
385       break;
386     }
387   }
388 
moveToOptimallyReducibleNodes(NodeId NId)389   void moveToOptimallyReducibleNodes(NodeId NId) {
390     removeFromCurrentSet(NId);
391     OptimallyReducibleNodes.insert(NId);
392     G.getNodeMetadata(NId).setReductionState(
393       NodeMetadata::OptimallyReducible);
394   }
395 
moveToConservativelyAllocatableNodes(NodeId NId)396   void moveToConservativelyAllocatableNodes(NodeId NId) {
397     removeFromCurrentSet(NId);
398     ConservativelyAllocatableNodes.insert(NId);
399     G.getNodeMetadata(NId).setReductionState(
400       NodeMetadata::ConservativelyAllocatable);
401   }
402 
moveToNotProvablyAllocatableNodes(NodeId NId)403   void moveToNotProvablyAllocatableNodes(NodeId NId) {
404     removeFromCurrentSet(NId);
405     NotProvablyAllocatableNodes.insert(NId);
406     G.getNodeMetadata(NId).setReductionState(
407       NodeMetadata::NotProvablyAllocatable);
408   }
409 
setup()410   void setup() {
411     // Set up worklists.
412     for (auto NId : G.nodeIds()) {
413       if (G.getNodeDegree(NId) < 3)
414         moveToOptimallyReducibleNodes(NId);
415       else if (G.getNodeMetadata(NId).isConservativelyAllocatable())
416         moveToConservativelyAllocatableNodes(NId);
417       else
418         moveToNotProvablyAllocatableNodes(NId);
419     }
420   }
421 
422   // Compute a reduction order for the graph by iteratively applying PBQP
423   // reduction rules. Locally optimal rules are applied whenever possible (R0,
424   // R1, R2). If no locally-optimal rules apply then any conservatively
425   // allocatable node is reduced. Finally, if no conservatively allocatable
426   // node exists then the node with the lowest spill-cost:degree ratio is
427   // selected.
reduce()428   std::vector<GraphBase::NodeId> reduce() {
429     assert(!G.empty() && "Cannot reduce empty graph.");
430 
431     using NodeId = GraphBase::NodeId;
432     std::vector<NodeId> NodeStack;
433 
434     // Consume worklists.
435     while (true) {
436       if (!OptimallyReducibleNodes.empty()) {
437         NodeSet::iterator NItr = OptimallyReducibleNodes.begin();
438         NodeId NId = *NItr;
439         OptimallyReducibleNodes.erase(NItr);
440         NodeStack.push_back(NId);
441         switch (G.getNodeDegree(NId)) {
442         case 0:
443           break;
444         case 1:
445           applyR1(G, NId);
446           break;
447         case 2:
448           applyR2(G, NId);
449           break;
450         default: llvm_unreachable("Not an optimally reducible node.");
451         }
452       } else if (!ConservativelyAllocatableNodes.empty()) {
453         // Conservatively allocatable nodes will never spill. For now just
454         // take the first node in the set and push it on the stack. When we
455         // start optimizing more heavily for register preferencing, it may
456         // would be better to push nodes with lower 'expected' or worst-case
457         // register costs first (since early nodes are the most
458         // constrained).
459         NodeSet::iterator NItr = ConservativelyAllocatableNodes.begin();
460         NodeId NId = *NItr;
461         ConservativelyAllocatableNodes.erase(NItr);
462         NodeStack.push_back(NId);
463         G.disconnectAllNeighborsFromNode(NId);
464       } else if (!NotProvablyAllocatableNodes.empty()) {
465         NodeSet::iterator NItr = llvm::min_element(NotProvablyAllocatableNodes,
466                                                    SpillCostComparator(G));
467         NodeId NId = *NItr;
468         NotProvablyAllocatableNodes.erase(NItr);
469         NodeStack.push_back(NId);
470         G.disconnectAllNeighborsFromNode(NId);
471       } else
472         break;
473     }
474 
475     return NodeStack;
476   }
477 
478   class SpillCostComparator {
479   public:
SpillCostComparator(const Graph & G)480     SpillCostComparator(const Graph& G) : G(G) {}
481 
operator()482     bool operator()(NodeId N1Id, NodeId N2Id) {
483       PBQPNum N1SC = G.getNodeCosts(N1Id)[0];
484       PBQPNum N2SC = G.getNodeCosts(N2Id)[0];
485       if (N1SC == N2SC)
486         return G.getNodeDegree(N1Id) < G.getNodeDegree(N2Id);
487       return N1SC < N2SC;
488     }
489 
490   private:
491     const Graph& G;
492   };
493 
494   Graph& G;
495   using NodeSet = std::set<NodeId>;
496   NodeSet OptimallyReducibleNodes;
497   NodeSet ConservativelyAllocatableNodes;
498   NodeSet NotProvablyAllocatableNodes;
499 };
500 
501 class PBQPRAGraph : public PBQP::Graph<RegAllocSolverImpl> {
502 private:
503   using BaseT = PBQP::Graph<RegAllocSolverImpl>;
504 
505 public:
PBQPRAGraph(GraphMetadata Metadata)506   PBQPRAGraph(GraphMetadata Metadata) : BaseT(std::move(Metadata)) {}
507 
508   /// Dump this graph to dbgs().
509   void dump() const;
510 
511   /// Dump this graph to an output stream.
512   /// @param OS Output stream to print on.
513   void dump(raw_ostream &OS) const;
514 
515   /// Print a representation of this graph in DOT format.
516   /// @param OS Output stream to print on.
517   void printDot(raw_ostream &OS) const;
518 };
519 
solve(PBQPRAGraph & G)520 inline Solution solve(PBQPRAGraph& G) {
521   if (G.empty())
522     return Solution();
523   RegAllocSolverImpl RegAllocSolver(G);
524   return RegAllocSolver.solve();
525 }
526 
527 } // end namespace RegAlloc
528 } // end namespace PBQP
529 
/// Create a PBQP register allocator instance.
///
/// Declaration only; the definition lives outside this header.
///
/// \param customPassID Optional, defaults to nullptr. NOTE(review): how a
///        non-null value is used is not visible from this header -- confirm
///        against the definition.
/// \returns A pointer to the created pass.
FunctionPass *
createPBQPRegisterAllocator(char *customPassID = nullptr);
533 
534 } // end namespace llvm
535 
536 #endif // LLVM_CODEGEN_REGALLOCPBQP_H
537