1 //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef LLVM_CODEGEN_PBQP_MATH_H
10 #define LLVM_CODEGEN_PBQP_MATH_H
11
12 #include "llvm/ADT/ArrayRef.h"
13 #include "llvm/ADT/Hashing.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Support/InterleavedRange.h"
16 #include <algorithm>
17 #include <cassert>
18 #include <functional>
19 #include <memory>
20
21 namespace llvm {
22 namespace PBQP {
23
/// Element type stored in PBQP Vector and Matrix objects.
using PBQPNum = float;
25
26 /// PBQP Vector class.
27 class Vector {
28 public:
29 /// Construct a PBQP vector of the given size.
Vector(unsigned Length)30 explicit Vector(unsigned Length) : Data(Length) {}
31
32 /// Construct a PBQP vector with initializer.
Vector(unsigned Length,PBQPNum InitVal)33 Vector(unsigned Length, PBQPNum InitVal) : Data(Length) {
34 std::fill(begin(), end(), InitVal);
35 }
36
37 /// Copy construct a PBQP vector.
Vector(const Vector & V)38 Vector(const Vector &V) : Data(ArrayRef<PBQPNum>(V.Data)) {}
39
40 /// Move construct a PBQP vector.
Vector(Vector && V)41 Vector(Vector &&V) : Data(std::move(V.Data)) {}
42
43 // Iterator-based access.
begin()44 const PBQPNum *begin() const { return Data.begin(); }
end()45 const PBQPNum *end() const { return Data.end(); }
begin()46 PBQPNum *begin() { return Data.begin(); }
end()47 PBQPNum *end() { return Data.end(); }
48
49 /// Comparison operator.
50 bool operator==(const Vector &V) const {
51 assert(!Data.empty() && "Invalid vector");
52 return llvm::equal(*this, V);
53 }
54
55 /// Return the length of the vector
getLength()56 unsigned getLength() const {
57 assert(!Data.empty() && "Invalid vector");
58 return Data.size();
59 }
60
61 /// Element access.
62 PBQPNum& operator[](unsigned Index) {
63 assert(!Data.empty() && "Invalid vector");
64 assert(Index < Data.size() && "Vector element access out of bounds.");
65 return Data[Index];
66 }
67
68 /// Const element access.
69 const PBQPNum& operator[](unsigned Index) const {
70 assert(!Data.empty() && "Invalid vector");
71 assert(Index < Data.size() && "Vector element access out of bounds.");
72 return Data[Index];
73 }
74
75 /// Add another vector to this one.
76 Vector& operator+=(const Vector &V) {
77 assert(!Data.empty() && "Invalid vector");
78 assert(Data.size() == V.Data.size() && "Vector length mismatch.");
79 std::transform(begin(), end(), V.begin(), begin(), std::plus<PBQPNum>());
80 return *this;
81 }
82
83 /// Returns the index of the minimum value in this vector
minIndex()84 unsigned minIndex() const {
85 assert(!Data.empty() && "Invalid vector");
86 return llvm::min_element(*this) - begin();
87 }
88
89 private:
90 OwningArrayRef<PBQPNum> Data;
91 };
92
93 /// Return a hash_value for the given vector.
hash_value(const Vector & V)94 inline hash_code hash_value(const Vector &V) {
95 const unsigned *VBegin = reinterpret_cast<const unsigned *>(V.begin());
96 const unsigned *VEnd = reinterpret_cast<const unsigned *>(V.end());
97 return hash_combine(V.getLength(), hash_combine_range(VBegin, VEnd));
98 }
99
100 /// Output a textual representation of the given vector on the given
101 /// output stream.
102 template <typename OStream>
103 OStream& operator<<(OStream &OS, const Vector &V) {
104 assert((V.getLength() != 0) && "Zero-length vector badness.");
105 OS << "[ " << llvm::interleaved(V) << " ]";
106 return OS;
107 }
108
109 /// PBQP Matrix class
110 class Matrix {
111 private:
112 friend hash_code hash_value(const Matrix &);
113
114 public:
115 /// Construct a PBQP Matrix with the given dimensions.
Matrix(unsigned Rows,unsigned Cols)116 Matrix(unsigned Rows, unsigned Cols) :
117 Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
118 }
119
120 /// Construct a PBQP Matrix with the given dimensions and initial
121 /// value.
Matrix(unsigned Rows,unsigned Cols,PBQPNum InitVal)122 Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
123 : Rows(Rows), Cols(Cols),
124 Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
125 std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal);
126 }
127
128 /// Copy construct a PBQP matrix.
Matrix(const Matrix & M)129 Matrix(const Matrix &M)
130 : Rows(M.Rows), Cols(M.Cols),
131 Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
132 std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get());
133 }
134
135 /// Move construct a PBQP matrix.
Matrix(Matrix && M)136 Matrix(Matrix &&M)
137 : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) {
138 M.Rows = M.Cols = 0;
139 }
140
141 /// Comparison operator.
142 bool operator==(const Matrix &M) const {
143 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
144 if (Rows != M.Rows || Cols != M.Cols)
145 return false;
146 return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get());
147 }
148
149 /// Return the number of rows in this matrix.
getRows()150 unsigned getRows() const {
151 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
152 return Rows;
153 }
154
155 /// Return the number of cols in this matrix.
getCols()156 unsigned getCols() const {
157 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
158 return Cols;
159 }
160
161 /// Matrix element access.
162 PBQPNum* operator[](unsigned R) {
163 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
164 assert(R < Rows && "Row out of bounds.");
165 return Data.get() + (R * Cols);
166 }
167
168 /// Matrix element access.
169 const PBQPNum* operator[](unsigned R) const {
170 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
171 assert(R < Rows && "Row out of bounds.");
172 return Data.get() + (R * Cols);
173 }
174
175 /// Returns the given row as a vector.
getRowAsVector(unsigned R)176 Vector getRowAsVector(unsigned R) const {
177 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
178 Vector V(Cols);
179 for (unsigned C = 0; C < Cols; ++C)
180 V[C] = (*this)[R][C];
181 return V;
182 }
183
184 /// Returns the given column as a vector.
getColAsVector(unsigned C)185 Vector getColAsVector(unsigned C) const {
186 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
187 Vector V(Rows);
188 for (unsigned R = 0; R < Rows; ++R)
189 V[R] = (*this)[R][C];
190 return V;
191 }
192
193 /// Matrix transpose.
transpose()194 Matrix transpose() const {
195 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
196 Matrix M(Cols, Rows);
197 for (unsigned r = 0; r < Rows; ++r)
198 for (unsigned c = 0; c < Cols; ++c)
199 M[c][r] = (*this)[r][c];
200 return M;
201 }
202
203 /// Add the given matrix to this one.
204 Matrix& operator+=(const Matrix &M) {
205 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
206 assert(Rows == M.Rows && Cols == M.Cols &&
207 "Matrix dimensions mismatch.");
208 std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(),
209 Data.get(), std::plus<PBQPNum>());
210 return *this;
211 }
212
213 Matrix operator+(const Matrix &M) {
214 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
215 Matrix Tmp(*this);
216 Tmp += M;
217 return Tmp;
218 }
219
220 private:
221 unsigned Rows, Cols;
222 std::unique_ptr<PBQPNum []> Data;
223 };
224
225 /// Return a hash_code for the given matrix.
hash_value(const Matrix & M)226 inline hash_code hash_value(const Matrix &M) {
227 unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get());
228 unsigned *MEnd =
229 reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols));
230 return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd));
231 }
232
233 /// Output a textual representation of the given matrix on the given
234 /// output stream.
235 template <typename OStream>
236 OStream& operator<<(OStream &OS, const Matrix &M) {
237 assert((M.getRows() != 0) && "Zero-row matrix badness.");
238 for (unsigned i = 0; i < M.getRows(); ++i)
239 OS << M.getRowAsVector(i) << "\n";
240 return OS;
241 }
242
243 template <typename Metadata>
244 class MDVector : public Vector {
245 public:
MDVector(const Vector & v)246 MDVector(const Vector &v) : Vector(v), md(*this) {}
MDVector(Vector && v)247 MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { }
248
getMetadata()249 const Metadata& getMetadata() const { return md; }
250
251 private:
252 Metadata md;
253 };
254
255 template <typename Metadata>
hash_value(const MDVector<Metadata> & V)256 inline hash_code hash_value(const MDVector<Metadata> &V) {
257 return hash_value(static_cast<const Vector&>(V));
258 }
259
260 template <typename Metadata>
261 class MDMatrix : public Matrix {
262 public:
MDMatrix(const Matrix & m)263 MDMatrix(const Matrix &m) : Matrix(m), md(*this) {}
MDMatrix(Matrix && m)264 MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { }
265
getMetadata()266 const Metadata& getMetadata() const { return md; }
267
268 private:
269 Metadata md;
270 };
271
272 template <typename Metadata>
hash_value(const MDMatrix<Metadata> & M)273 inline hash_code hash_value(const MDMatrix<Metadata> &M) {
274 return hash_value(static_cast<const Matrix&>(M));
275 }
276
277 } // end namespace PBQP
278 } // end namespace llvm
279
280 #endif // LLVM_CODEGEN_PBQP_MATH_H
281