Lines Matching full:matrix
1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
9 // Lower matrix intrinsics to vector operations.
53 #define DEBUG_TYPE "lower-matrix-intrinsics"
56 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57 cl::desc("Enable/disable fusing matrix instructions."));
60 "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
62 "Tile size for matrix instruction fusion using square-shaped tiles."));
63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
67 "force-fuse-matrix", cl::init(false), cl::Hidden,
68 cl::desc("Force matrix instruction fusion even if not profitable."));
70 "matrix-allow-contract", cl::init(false), cl::Hidden,
75 VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76 cl::desc("Enable/disable matrix shape verification."),
82 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83 cl::desc("Sets the default matrix layout"),
89 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
114 /// matrix with a scalar).
135 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
141 // (= number of rows of the matrix). For row-major matrixes, the function
143 // number of elements in a column (= number of columns of the matrix).
145 // Consider a 4x4 matrix in column-major layout like below
153 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
156 // of the sub-matrix.
299 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
301 /// Currently, the lowering for each matrix intrinsic is done as follows:
308 /// If not, split the operand vector containing an embedded matrix into
311 /// yields a set of column vectors containing result matrix. Note that we
316 /// column matrix when lowering the user. For other uses, we embed the
317 /// result matrix in a flat vector and update the use.
318 /// 2.4. Cache the result column matrix for the instruction we lowered
331 …stimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
333 /// Number of stores emitted to generate this matrix.
335 /// Number of loads emitted to generate this matrix.
337 /// Number of compute operations emitted to generate this matrix.
339 /// Most of the time transposes can be fused with matrix multiplies or can
353 /// Wrapper class representing a matrix as a set of vectors, either in row or
432 /// Embed the vectors of the matrix into a flat vector by concatenating
476 /// matrix is column-major, the result vector is extracted from a column
505 /// Map from instructions to their produced column matrix.
547 /// Return the set of vectors that a matrix value is lowered to.
549 /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
550 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
558 "The vector size must match the number of matrix elements"); in getMatrix()
561 // return the existing matrix, if it matches the requested shape in getMatrix()
567 // Return the found matrix, if its shape matches the requested shape in getMatrix()
604 "Matrix shape verification failed, compilation aborted!"); in setShapeInfo()
642 /// either based on the information provided by matrix intrinsics or known
720 // Nothing to do, no matrix input. in propagateShapeBackward()
953 // Initially only the shape of matrix intrinsics is known. in Visit()
973 // Avoid unnecessary work if there are no matrix intrinsics in the function. in Visit()
986 dbgs() << "Dump after matrix transpose optimization:\n"; in Visit()
997 // fusion (currently only matrix multiplies). in Visit()
1122 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
1145 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
1174 /// Lowers llvm.matrix.column.major.load.
1176 /// The intrinsic loads a matrix from memory using a stride between columns.
1187 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1203 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1227 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, in LowerStore() argument
1230 auto StoreVal = getMatrix(Matrix, Shape, Builder); in LowerStore()
1232 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, in LowerStore()
1237 /// Lowers llvm.matrix.column.major.store.
1239 /// The intrinsic stores a matrix back to memory using a stride between columns.
1243 Value *Matrix = Inst->getArgOperand(0); in LowerColumnMajorStore() local
1246 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, in LowerColumnMajorStore()
1307 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1309 /// cached value when they are lowered. For other users, \p Matrix is
1312 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, in finalizeLowering() argument
1314 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); in finalizeLowering()
1316 assert(inserted.second && "multiple matrix lowering mapping"); in finalizeLowering()
1323 Flattened = Matrix.embedInVector(Builder); in finalizeLowering()
1419 // between the flattened and matrix versions. in lowerDotProduct()
1508 // pack scalar back into a matrix and then replace matmul inst in lowerDotProduct()
1538 "operands must agree on matrix layout"); in emitMatrixMultiply()
1859 /// Try to lower matrix multiply chains by fusing operations.
2009 /// Lowers llvm.matrix.multiply.
2019 "Matrix multiply argument element types do not match."); in LowerMultiply()
2028 "Matrix multiply result element type does not match arguments."); in LowerMultiply()
2035 /// Lowers llvm.matrix.transpose.
2114 "operands must agree on matrix layout"); in VisitBinaryOperator()
2134 llvm_unreachable("Unsupported binary operator for matrix"); in VisitBinaryOperator()
2170 llvm_unreachable("Unsupported unary operator for matrix"); in VisitUnaryOperator()
2184 /// Helper to linearize a matrix expression tree into a string. Currently
2185 /// matrix expressions are linearized by starting at an expression leaf and
2195 /// matrix instructions.
2202 /// Set of matrix expressions in the scope of a given DISubprogram.
2252 /// Returns true if \p V is a matrix value in the given subprogram.
2255 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2268 /// Write the called function name. Handles calls to llvm.matrix.*
2276 if (!Name.starts_with("llvm.matrix")) { in writeFnName()
2282 .drop_front(StringRef("llvm.matrix.").size())); in writeFnName()
2333 /// either print the constant or "scalar"/"matrix" for other values.
2360 TmpStream << "matrix"; in write()
2407 // non-matrix ops. in linearizeExpr()
2408 write("matrix"); in linearizeExpr()
2443 /// Generate remarks for matrix operations in a function. To generate remarks
2444 /// for matrix expressions, the following approach is used:
2445 /// 1. Use the inlined-at debug information to group matrix operations to the
2447 /// 2. Collect leaves of matrix expressions (done in
2449 // mapping. Leaves are lowered matrix instructions without other matrix
2452 /// matrix expression. The expression is linearized by a recursive
2453 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2483 /// to all visited expressions in \p Shared. Limit the matrix operations to
2501 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2535 // Map matrix operations to their containing subprograms, by traversing in emitRemarks()
2581 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, in emitRemarks()