//===- RegAllocPBQP.h -------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the PBQPBuilder interface, for classes which build PBQP
// instances to represent register allocation problems, and the RegAllocPBQP
// interface.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CODEGEN_REGALLOCPBQP_H
#define LLVM_CODEGEN_REGALLOCPBQP_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/CodeGen/PBQP/CostAllocator.h"
#include "llvm/CodeGen/PBQP/Graph.h"
#include "llvm/CodeGen/PBQP/Math.h"
#include "llvm/CodeGen/PBQP/ReductionRules.h"
#include "llvm/CodeGen/PBQP/Solution.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/MC/MCRegister.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <memory>
#include <set>
#include <vector>

namespace llvm {

class FunctionPass;
class LiveIntervals;
class MachineBlockFrequencyInfo;
class MachineFunction;
class raw_ostream;

namespace PBQP {
namespace RegAlloc {

/// Spill option index.
inline unsigned getSpillOptionIdx() { return 0; }

/// Metadata to speed allocatability test.
///
/// Keeps track of the number of infinities in each row and column.
class MatrixMetadata {
public:
  MatrixMetadata(const Matrix& M)
      : UnsafeRows(new bool[M.getRows() - 1]()),
        UnsafeCols(new bool[M.getCols() - 1]()) {
    unsigned* ColCounts = new unsigned[M.getCols() - 1]();

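    // Only the register options are scanned: row/column 0 corresponds to the
    // spill option (see getSpillOptionIdx()), which is why the Unsafe arrays
    // are sized getRows()-1/getCols()-1 and the loops below start at index 1.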
    for (unsigned i = 1; i < M.getRows(); ++i) {
      unsigned RowCount = 0;
      for (unsigned j = 1; j < M.getCols(); ++j) {
        if (M[i][j] == std::numeric_limits<PBQPNum>::infinity()) {
          ++RowCount;
          ++ColCounts[j - 1];
          UnsafeRows[i - 1] = true;
          UnsafeCols[j - 1] = true;
        }
      }
      WorstRow = std::max(WorstRow, RowCount);
    }
    unsigned WorstColCountForCurRow =
        *std::max_element(ColCounts, ColCounts + M.getCols() - 1);
    WorstCol = std::max(WorstCol, WorstColCountForCurRow);
    delete[] ColCounts;
  }

  MatrixMetadata(const MatrixMetadata &) = delete;
  MatrixMetadata &operator=(const MatrixMetadata &) = delete;

  unsigned getWorstRow() const { return WorstRow; }
  unsigned getWorstCol() const { return WorstCol; }
  const bool* getUnsafeRows() const { return UnsafeRows.get(); }
  const bool* getUnsafeCols() const { return UnsafeCols.get(); }

private:
  unsigned WorstRow = 0;
  unsigned WorstCol = 0;
  std::unique_ptr<bool[]> UnsafeRows;
  std::unique_ptr<bool[]> UnsafeCols;
};

/// Holds a vector of the allowed physical regs for a vreg.
class AllowedRegVector {
  friend hash_code hash_value(const AllowedRegVector &);

public:
  AllowedRegVector() = default;
  AllowedRegVector(AllowedRegVector &&) = default;

  AllowedRegVector(const std::vector<MCRegister> &OptVec)
      : NumOpts(OptVec.size()), Opts(new MCRegister[NumOpts]) {
    std::copy(OptVec.begin(), OptVec.end(), Opts.get());
  }

  unsigned size() const { return NumOpts; }
  MCRegister operator[](size_t I) const { return Opts[I]; }

  bool operator==(const AllowedRegVector &Other) const {
    if (NumOpts != Other.NumOpts)
      return false;
    return std::equal(Opts.get(), Opts.get() + NumOpts, Other.Opts.get());
  }

  bool operator!=(const AllowedRegVector &Other) const {
    return !(*this == Other);
  }

private:
  unsigned NumOpts = 0;
  std::unique_ptr<MCRegister[]> Opts;
};

inline hash_code hash_value(const AllowedRegVector &OptRegs) {
  MCRegister *OStart = OptRegs.Opts.get();
  MCRegister *OEnd = OptRegs.Opts.get() + OptRegs.NumOpts;
  return hash_combine(OptRegs.NumOpts,
                      hash_combine_range(OStart, OEnd));
}

/// Holds graph-level metadata relevant to PBQP RA problems.
class GraphMetadata {
private:
  using AllowedRegVecPool = ValuePool<AllowedRegVector>;

public:
  using AllowedRegVecRef = AllowedRegVecPool::PoolRef;

  GraphMetadata(MachineFunction &MF,
                LiveIntervals &LIS,
                MachineBlockFrequencyInfo &MBFI)
      : MF(MF), LIS(LIS), MBFI(MBFI) {}

  MachineFunction &MF;
  LiveIntervals &LIS;
  MachineBlockFrequencyInfo &MBFI;

  void setNodeIdForVReg(Register VReg, GraphBase::NodeId NId) {
    VRegToNodeId[VReg.id()] = NId;
  }

  GraphBase::NodeId getNodeIdForVReg(Register VReg) const {
    auto VRegItr = VRegToNodeId.find(VReg);
    if (VRegItr == VRegToNodeId.end())
      return GraphBase::invalidNodeId();
    return VRegItr->second;
  }

  AllowedRegVecRef getAllowedRegs(AllowedRegVector Allowed) {
    return AllowedRegVecs.getValue(std::move(Allowed));
  }

private:
  DenseMap<Register, GraphBase::NodeId> VRegToNodeId;
  AllowedRegVecPool AllowedRegVecs;
};

/// Holds solver state and other metadata relevant to each PBQP RA node.
class NodeMetadata {
public:
  using AllowedRegVector = RegAlloc::AllowedRegVector;

  // The node's reduction state. The order in this enum is important,
  // as it is assumed nodes can only progress up (i.e. towards being
  // optimally reducible) when reducing the graph.
  using ReductionState = enum {
    Unprocessed,
    NotProvablyAllocatable,
    ConservativelyAllocatable,
    OptimallyReducible
  };

  NodeMetadata() = default;

  NodeMetadata(const NodeMetadata &Other)
      : RS(Other.RS), NumOpts(Other.NumOpts), DeniedOpts(Other.DeniedOpts),
        OptUnsafeEdges(new unsigned[NumOpts]), VReg(Other.VReg),
        AllowedRegs(Other.AllowedRegs)
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
        ,
        everConservativelyAllocatable(Other.everConservativelyAllocatable)
#endif
  {
    if (NumOpts > 0) {
      std::copy(&Other.OptUnsafeEdges[0], &Other.OptUnsafeEdges[NumOpts],
                &OptUnsafeEdges[0]);
    }
  }

  NodeMetadata(NodeMetadata &&) = default;
  NodeMetadata& operator=(NodeMetadata &&) = default;

  void setVReg(Register VReg) { this->VReg = VReg; }
  Register getVReg() const { return VReg; }

  void setAllowedRegs(GraphMetadata::AllowedRegVecRef AllowedRegs) {
    this->AllowedRegs = std::move(AllowedRegs);
  }
  const AllowedRegVector& getAllowedRegs() const { return *AllowedRegs; }

  void setup(const Vector& Costs) {
    NumOpts = Costs.getLength() - 1;
    OptUnsafeEdges = std::unique_ptr<unsigned[]>(new unsigned[NumOpts]());
  }

  ReductionState getReductionState() const { return RS; }
  void setReductionState(ReductionState RS) {
    assert(RS >= this->RS && "A node's reduction state can not be downgraded");
    this->RS = RS;

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
    // Remember this state to assert later that a non-infinite register
    // option was available.
    if (RS == ConservativelyAllocatable)
      everConservativelyAllocatable = true;
#endif
  }

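  // Incorporate (or retract) an edge's matrix metadata into this node's
  // counters. Transpose is true when this node is the edge's second endpoint,
  // in which case the roles of the matrix's rows and columns are swapped.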
  void handleAddEdge(const MatrixMetadata& MD, bool Transpose) {
    DeniedOpts += Transpose ? MD.getWorstRow() : MD.getWorstCol();
    const bool* UnsafeOpts =
        Transpose ? MD.getUnsafeCols() : MD.getUnsafeRows();
    for (unsigned i = 0; i < NumOpts; ++i)
      OptUnsafeEdges[i] += UnsafeOpts[i];
  }

  void handleRemoveEdge(const MatrixMetadata& MD, bool Transpose) {
    DeniedOpts -= Transpose ? MD.getWorstRow() : MD.getWorstCol();
    const bool* UnsafeOpts =
        Transpose ? MD.getUnsafeCols() : MD.getUnsafeRows();
    for (unsigned i = 0; i < NumOpts; ++i)
      OptUnsafeEdges[i] -= UnsafeOpts[i];
  }

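  // A node is conservatively allocatable if its neighbors, even in the worst
  // case, cannot deny all of its register options (DeniedOpts < NumOpts), or
  // if at least one option has no unsafe (infinite-cost) edges counted
  // against it.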
  bool isConservativelyAllocatable() const {
    return (DeniedOpts < NumOpts) ||
           (std::find(&OptUnsafeEdges[0], &OptUnsafeEdges[NumOpts], 0) !=
            &OptUnsafeEdges[NumOpts]);
  }

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  bool wasConservativelyAllocatable() const {
    return everConservativelyAllocatable;
  }
#endif

private:
  ReductionState RS = Unprocessed;
  unsigned NumOpts = 0;
  unsigned DeniedOpts = 0;
  std::unique_ptr<unsigned[]> OptUnsafeEdges;
  Register VReg;
  GraphMetadata::AllowedRegVecRef AllowedRegs;

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  bool everConservativelyAllocatable = false;
#endif
};

class RegAllocSolverImpl {
private:
  using RAMatrix = MDMatrix<MatrixMetadata>;

public:
  using RawVector = PBQP::Vector;
  using RawMatrix = PBQP::Matrix;
  using Vector = PBQP::Vector;
  using Matrix = RAMatrix;
  using CostAllocator = PBQP::PoolCostAllocator<Vector, Matrix>;

  using NodeId = GraphBase::NodeId;
  using EdgeId = GraphBase::EdgeId;

  using NodeMetadata = RegAlloc::NodeMetadata;
  struct EdgeMetadata {};
  using GraphMetadata = RegAlloc::GraphMetadata;

  using Graph = PBQP::Graph<RegAllocSolverImpl>;

  RegAllocSolverImpl(Graph &G) : G(G) {}

  Solution solve() {
    G.setSolver(*this);
    Solution S;
    setup();
    S = backpropagate(G, reduce());
    G.unsetSolver();
    return S;
  }

  void handleAddNode(NodeId NId) {
    assert(G.getNodeCosts(NId).getLength() > 1 &&
           "PBQP Graph should not contain single or zero-option nodes");
    G.getNodeMetadata(NId).setup(G.getNodeCosts(NId));
  }

  void handleRemoveNode(NodeId NId) {}
  void handleSetNodeCosts(NodeId NId, const Vector& newCosts) {}

  void handleAddEdge(EdgeId EId) {
    handleReconnectEdge(EId, G.getEdgeNode1Id(EId));
    handleReconnectEdge(EId, G.getEdgeNode2Id(EId));
  }

  void handleDisconnectEdge(EdgeId EId, NodeId NId) {
    NodeMetadata& NMd = G.getNodeMetadata(NId);
    const MatrixMetadata& MMd = G.getEdgeCosts(EId).getMetadata();
    NMd.handleRemoveEdge(MMd, NId == G.getEdgeNode2Id(EId));
    promote(NId, NMd);
  }

  void handleReconnectEdge(EdgeId EId, NodeId NId) {
    NodeMetadata& NMd = G.getNodeMetadata(NId);
    const MatrixMetadata& MMd = G.getEdgeCosts(EId).getMetadata();
    NMd.handleAddEdge(MMd, NId == G.getEdgeNode2Id(EId));
  }

  void handleUpdateCosts(EdgeId EId, const Matrix& NewCosts) {
    NodeId N1Id = G.getEdgeNode1Id(EId);
    NodeId N2Id = G.getEdgeNode2Id(EId);
    NodeMetadata& N1Md = G.getNodeMetadata(N1Id);
    NodeMetadata& N2Md = G.getNodeMetadata(N2Id);
    bool Transpose = N1Id != G.getEdgeNode1Id(EId);

    // Metadata are computed incrementally. First, update them
    // by removing the old cost.
    const MatrixMetadata& OldMMd = G.getEdgeCosts(EId).getMetadata();
    N1Md.handleRemoveEdge(OldMMd, Transpose);
    N2Md.handleRemoveEdge(OldMMd, !Transpose);

    // Then update the metadata with the new cost.
    const MatrixMetadata& MMd = NewCosts.getMetadata();
    N1Md.handleAddEdge(MMd, Transpose);
    N2Md.handleAddEdge(MMd, !Transpose);

    // As the metadata may have changed with the update, the nodes may have
    // become ConservativelyAllocatable or OptimallyReducible.
    promote(N1Id, N1Md);
    promote(N2Id, N2Md);
  }

private:
  void promote(NodeId NId, NodeMetadata& NMd) {
    if (G.getNodeDegree(NId) == 3) {
      // This node is becoming optimally reducible.
      moveToOptimallyReducibleNodes(NId);
    } else if (NMd.getReductionState() ==
                   NodeMetadata::NotProvablyAllocatable &&
               NMd.isConservativelyAllocatable()) {
      // This node just became conservatively allocatable.
      moveToConservativelyAllocatableNodes(NId);
    }
  }

  void removeFromCurrentSet(NodeId NId) {
    switch (G.getNodeMetadata(NId).getReductionState()) {
    case NodeMetadata::Unprocessed: break;
    case NodeMetadata::OptimallyReducible:
      assert(OptimallyReducibleNodes.find(NId) !=
                 OptimallyReducibleNodes.end() &&
             "Node not in optimally reducible set.");
      OptimallyReducibleNodes.erase(NId);
      break;
    case NodeMetadata::ConservativelyAllocatable:
      assert(ConservativelyAllocatableNodes.find(NId) !=
                 ConservativelyAllocatableNodes.end() &&
             "Node not in conservatively allocatable set.");
      ConservativelyAllocatableNodes.erase(NId);
      break;
    case NodeMetadata::NotProvablyAllocatable:
      assert(NotProvablyAllocatableNodes.find(NId) !=
                 NotProvablyAllocatableNodes.end() &&
             "Node not in not-provably-allocatable set.");
      NotProvablyAllocatableNodes.erase(NId);
      break;
    }
  }

  void moveToOptimallyReducibleNodes(NodeId NId) {
    removeFromCurrentSet(NId);
    OptimallyReducibleNodes.insert(NId);
    G.getNodeMetadata(NId).setReductionState(
        NodeMetadata::OptimallyReducible);
  }

  void moveToConservativelyAllocatableNodes(NodeId NId) {
    removeFromCurrentSet(NId);
    ConservativelyAllocatableNodes.insert(NId);
    G.getNodeMetadata(NId).setReductionState(
        NodeMetadata::ConservativelyAllocatable);
  }

  void moveToNotProvablyAllocatableNodes(NodeId NId) {
    removeFromCurrentSet(NId);
    NotProvablyAllocatableNodes.insert(NId);
    G.getNodeMetadata(NId).setReductionState(
        NodeMetadata::NotProvablyAllocatable);
  }

  void setup() {
    // Set up worklists.
    for (auto NId : G.nodeIds()) {
      if (G.getNodeDegree(NId) < 3)
        moveToOptimallyReducibleNodes(NId);
      else if (G.getNodeMetadata(NId).isConservativelyAllocatable())
        moveToConservativelyAllocatableNodes(NId);
      else
        moveToNotProvablyAllocatableNodes(NId);
    }
  }

  // Compute a reduction order for the graph by iteratively applying PBQP
  // reduction rules. Locally optimal rules are applied whenever possible (R0,
  // R1, R2). If no locally-optimal rules apply then any conservatively
  // allocatable node is reduced. Finally, if no conservatively allocatable
  // node exists then the node with the lowest spill cost is selected (ties
  // are broken in favor of lower degree; see SpillCostComparator below).
  std::vector<GraphBase::NodeId> reduce() {
    assert(!G.empty() && "Cannot reduce empty graph.");

    using NodeId = GraphBase::NodeId;
    std::vector<NodeId> NodeStack;

    // Consume worklists.
    while (true) {
      if (!OptimallyReducibleNodes.empty()) {
        NodeSet::iterator NItr = OptimallyReducibleNodes.begin();
        NodeId NId = *NItr;
        OptimallyReducibleNodes.erase(NItr);
        NodeStack.push_back(NId);
        switch (G.getNodeDegree(NId)) {
        case 0:
          break;
        case 1:
          applyR1(G, NId);
          break;
        case 2:
          applyR2(G, NId);
          break;
        default: llvm_unreachable("Not an optimally reducible node.");
        }
      } else if (!ConservativelyAllocatableNodes.empty()) {
        // Conservatively allocatable nodes will never spill. For now just
        // take the first node in the set and push it on the stack. When we
        // start optimizing more heavily for register preferences, it may be
        // better to push nodes with lower 'expected' or worst-case register
        // costs first (since early nodes are the most constrained).
        NodeSet::iterator NItr = ConservativelyAllocatableNodes.begin();
        NodeId NId = *NItr;
        ConservativelyAllocatableNodes.erase(NItr);
        NodeStack.push_back(NId);
        G.disconnectAllNeighborsFromNode(NId);
      } else if (!NotProvablyAllocatableNodes.empty()) {
        NodeSet::iterator NItr = llvm::min_element(NotProvablyAllocatableNodes,
                                                   SpillCostComparator(G));
        NodeId NId = *NItr;
        NotProvablyAllocatableNodes.erase(NItr);
        NodeStack.push_back(NId);
        G.disconnectAllNeighborsFromNode(NId);
      } else
        break;
    }

    return NodeStack;
  }

  class SpillCostComparator {
  public:
    SpillCostComparator(const Graph& G) : G(G) {}

    bool operator()(NodeId N1Id, NodeId N2Id) {
      PBQPNum N1SC = G.getNodeCosts(N1Id)[0];
      PBQPNum N2SC = G.getNodeCosts(N2Id)[0];
      if (N1SC == N2SC)
        return G.getNodeDegree(N1Id) < G.getNodeDegree(N2Id);
      return N1SC < N2SC;
    }

  private:
    const Graph& G;
  };

  Graph& G;
  using NodeSet = std::set<NodeId>;
  NodeSet OptimallyReducibleNodes;
  NodeSet ConservativelyAllocatableNodes;
  NodeSet NotProvablyAllocatableNodes;
};

class PBQPRAGraph : public PBQP::Graph<RegAllocSolverImpl> {
private:
  using BaseT = PBQP::Graph<RegAllocSolverImpl>;

public:
  PBQPRAGraph(GraphMetadata Metadata) : BaseT(std::move(Metadata)) {}

  /// Dump this graph to dbgs().
  void dump() const;

  /// Dump this graph to an output stream.
  /// @param OS Output stream to print on.
  void dump(raw_ostream &OS) const;

  /// Print a representation of this graph in DOT format.
  /// @param OS Output stream to print on.
  void printDot(raw_ostream &OS) const;
};

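/// Reduce and solve the given PBQP RA graph, returning the chosen option for
/// each node (an empty graph yields an empty Solution).
///
/// A minimal usage sketch (illustrative only; NumRegOpts and VReg below are
/// hypothetical, and a real builder would also set the allowed registers and
/// add interference edges):
/// \code
///   PBQPRAGraph G(PBQPRAGraph::GraphMetadata(MF, LIS, MBFI));
///   // One node per virtual register: option 0 is the spill option
///   // (getSpillOptionIdx()), options 1..N are the allowed physregs.
///   auto NId = G.addNode(PBQPRAGraph::RawVector(NumRegOpts + 1, 0));
///   G.getNodeMetadata(NId).setVReg(VReg);
///   G.getMetadata().setNodeIdForVReg(VReg, NId);
///   Solution S = solve(G);
///   unsigned Selection = S.getSelection(NId);
/// \endcode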
inline Solution solve(PBQPRAGraph& G) {
  if (G.empty())
    return Solution();
  RegAllocSolverImpl RegAllocSolver(G);
  return RegAllocSolver.solve();
}

} // end namespace RegAlloc
} // end namespace PBQP

/// Create a PBQP register allocator instance.
FunctionPass *
createPBQPRegisterAllocator(char *customPassID = nullptr);

} // end namespace llvm

#endif // LLVM_CODEGEN_REGALLOCPBQP_H