//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//  * Improve cost-modeling, e.g. choose a different number of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if \p V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
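// As an illustration (assuming an i64 column index %vec.idx, double elements,
// and a constant stride of 4; %base is a hypothetical base pointer), the IR
// emitted by computeVectorAddr is roughly:
//   %vec.start = mul i64 %vec.idx, 4
//   %vec.gep = getelementptr double, ptr %base, i64 %vec.start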

namespace {
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
} // namespace
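// For example, with the default column-major layout, ShapeInfo(4, 2) has
// getStride() == 4 (elements between the starts of consecutive columns) and
// getNumVectors() == 2 (one vector per column); with
// -matrix-default-layout=row-major the same shape yields getStride() == 2
// and getNumVectors() == 4.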

static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}
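// For example, "%sum = fadd <8 x double> %a, %b" has uniform shape: the result
// has the same shape as its operands, so a 4x2 shape known for %a can be
// assumed for %sum as well.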

/// Return the ShapeInfo for the result of \p I, if it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N)))) {
    // Flip dimensions.
    return ShapeInfo(N, M);
  }
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}
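// As a sketch, for the (hypothetical) call
//   %c = call <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(
//            <6 x double> %a, <4 x double> %b, i32 3, i32 2, i32 2)
// the computed result shape is 3x2 (M x K), independent of the ShapeMap
// contents.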

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major
///    layout). The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications. This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 &&
               "Cannot call getNumColumns without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };
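  // For illustration: with the default column-major layout, a 4x2 MatrixTy
  // holds two <4 x double> vectors, one per column, and embedInVector
  // concatenates them into a single flat <8 x double> value.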

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well. A ValueMap is used so that when
  /// sub-passes like optimizeTransposes perform RAUW, the map stays
  /// up-to-date.
  ValueMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction if shape information is available; those
  /// users need to be removed after we have finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT), LI(LI),
        ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines?
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
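  // For example, assuming 256-bit fixed-width vector registers, an operation
  // on a 4x4 matrix of double (16 * 64 = 1024 bits) is estimated to take
  // ceil(1024 / 256) = 4 vector ops.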

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
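  // As a sketch, splitting a flat <8 x double> %m with a column-major 4x2
  // shape emits two shuffles selecting elements 0..3 and 4..7, roughly:
  //   %split = shufflevector <8 x double> %m, <8 x double> poison,
  //                          <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  //   %split1 = shufflevector <8 x double> %m, <8 x double> poison,
  //                           <4 x i32> <i32 4, i32 5, i32 6, i32 7>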

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands: if their shape
    // derives from the result shape and is unknown, set it and add them to
    // the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V,
                       m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
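  // For example, popping %c = llvm.matrix.multiply(%a, %b, 3, 2, 2) assigns
  // %a the 3x2 shape {M, N} and %b the 2x2 shape {N, K}, and pushes both so
  // their own operands can be visited in turn.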

  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
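  // For illustration: rewriting (A * B)^T in sinkTranspose below calls
  // distributeTransposes(B, {K, C}, A, {R, K}, ...), so the callback receives
  // B^T (CxK) and A^T (KxR) and emits the CxR product B^T * A^T.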

  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo, so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      // Copy the shape before erasing; the iterator is invalid afterwards.
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }

  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I,
                             BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  RxC          RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    // RxC RxC      CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }

  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose.
    // If the shape of the second transpose is different, there's a shape
    // conflict which gets resolved by picking the shape of the first operand.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      setShapeInfo(Add, {R, C});
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
                 computeShapeInfoForInst(&I, ShapeMap) &&
             "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
                 ShapeMap[Add] &&
             "Shape of updated addition doesn't match cached shape.");
    }
  }
926bdd1243dSDimitry Andric
927bdd1243dSDimitry Andric /// Try moving transposes in order to fold them away or into multiplies.
928bdd1243dSDimitry Andric void optimizeTransposes() {
929bdd1243dSDimitry Andric // First sink all transposes inside matmuls and adds, hoping that we end up
930bdd1243dSDimitry Andric // with NN, NT or TN variants.
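// E.g. t(A) * B (a TN variant) keeps the transpose directly on a multiply
// operand, where the multiply lowering can fold it into the scalar-extract
// side (see emitMatrixMultiply), instead of materializing t(A) separately
// via shuffles.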
931bdd1243dSDimitry Andric for (BasicBlock &BB : reverse(Func)) {
932bdd1243dSDimitry Andric for (auto II = BB.rbegin(); II != BB.rend();) {
933bdd1243dSDimitry Andric Instruction &I = *II;
934bdd1243dSDimitry Andric // We may remove II. By default continue on the next/prev instruction.
935bdd1243dSDimitry Andric ++II;
936bdd1243dSDimitry Andric if (Instruction *NewInst = sinkTranspose(I, II))
937bdd1243dSDimitry Andric II = std::next(BasicBlock::reverse_iterator(NewInst));
938bdd1243dSDimitry Andric }
939bdd1243dSDimitry Andric }
940bdd1243dSDimitry Andric
941bdd1243dSDimitry Andric // If we have a TT matmul or a TT add, lift the transpose. We may be able
942bdd1243dSDimitry Andric // to fold into consuming multiply or add.
943bdd1243dSDimitry Andric for (BasicBlock &BB : Func) {
944bdd1243dSDimitry Andric for (Instruction &I : llvm::make_early_inc_range(BB)) {
945bdd1243dSDimitry Andric liftTranspose(I);
946fe6060f1SDimitry Andric }
947fe6060f1SDimitry Andric }
948fe6060f1SDimitry Andric }
949fe6060f1SDimitry Andric
950480093f4SDimitry Andric bool Visit() {
951480093f4SDimitry Andric SmallVector<Instruction *, 32> WorkList;
952480093f4SDimitry Andric
953480093f4SDimitry Andric // Initially only the shape of matrix intrinsics is known.
954480093f4SDimitry Andric // Initialize the work list with ops carrying shape information.
955480093f4SDimitry Andric for (BasicBlock &BB : Func)
956480093f4SDimitry Andric for (Instruction &Inst : BB) {
957480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
958480093f4SDimitry Andric if (!II)
959480093f4SDimitry Andric continue;
960480093f4SDimitry Andric
961480093f4SDimitry Andric switch (II->getIntrinsicID()) {
962480093f4SDimitry Andric case Intrinsic::matrix_multiply:
963480093f4SDimitry Andric case Intrinsic::matrix_transpose:
9645ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load:
9655ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store:
966480093f4SDimitry Andric WorkList.push_back(&Inst);
967480093f4SDimitry Andric break;
968480093f4SDimitry Andric default:
969480093f4SDimitry Andric break;
970480093f4SDimitry Andric }
971480093f4SDimitry Andric }
972fe6060f1SDimitry Andric
973fe6060f1SDimitry Andric // Avoid unnecessary work if there are no matrix intrinsics in the function.
974fe6060f1SDimitry Andric if (WorkList.empty())
975fe6060f1SDimitry Andric return false;
976fe6060f1SDimitry Andric
977480093f4SDimitry Andric // Propagate shapes until nothing changes any longer.
978480093f4SDimitry Andric while (!WorkList.empty()) {
979480093f4SDimitry Andric WorkList = propagateShapeForward(WorkList);
980480093f4SDimitry Andric WorkList = propagateShapeBackward(WorkList);
981480093f4SDimitry Andric }
982fe6060f1SDimitry Andric
983fe6060f1SDimitry Andric if (!isMinimal()) {
984fe6060f1SDimitry Andric optimizeTransposes();
985bdd1243dSDimitry Andric if (PrintAfterTransposeOpt) {
986fe6060f1SDimitry Andric dbgs() << "Dump after matrix transpose optimization:\n";
987bdd1243dSDimitry Andric Func.print(dbgs());
988bdd1243dSDimitry Andric }
989480093f4SDimitry Andric }
990480093f4SDimitry Andric
991480093f4SDimitry Andric bool Changed = false;
9925ffd83dbSDimitry Andric SmallVector<CallInst *, 16> MaybeFusableInsts;
9935ffd83dbSDimitry Andric SmallVector<Instruction *, 16> MatrixInsts;
994*0fca6ea1SDimitry Andric SmallVector<IntrinsicInst *, 16> LifetimeEnds;
995480093f4SDimitry Andric
9965ffd83dbSDimitry Andric // First, collect all instructions with shape information and candidates for
9975ffd83dbSDimitry Andric // fusion (currently only matrix multiplies).
9985ffd83dbSDimitry Andric ReversePostOrderTraversal<Function *> RPOT(&Func);
9995ffd83dbSDimitry Andric for (auto *BB : RPOT)
10005ffd83dbSDimitry Andric for (Instruction &I : *BB) {
1001*0fca6ea1SDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1002*0fca6ea1SDimitry Andric LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
10035ffd83dbSDimitry Andric if (ShapeMap.find(&I) == ShapeMap.end())
10045ffd83dbSDimitry Andric continue;
10055ffd83dbSDimitry Andric if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
10065ffd83dbSDimitry Andric MaybeFusableInsts.push_back(cast<CallInst>(&I));
10075ffd83dbSDimitry Andric MatrixInsts.push_back(&I);
10085ffd83dbSDimitry Andric }
10095ffd83dbSDimitry Andric
101006c3fb27SDimitry Andric // Second, try to lower any dot products.
10115ffd83dbSDimitry Andric SmallPtrSet<Instruction *, 16> FusedInsts;
10125ffd83dbSDimitry Andric for (CallInst *CI : MaybeFusableInsts)
101306c3fb27SDimitry Andric lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
101406c3fb27SDimitry Andric
101506c3fb27SDimitry Andric // Third, try to fuse candidates.
101606c3fb27SDimitry Andric for (CallInst *CI : MaybeFusableInsts)
1017*0fca6ea1SDimitry Andric LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
101806c3fb27SDimitry Andric
10195ffd83dbSDimitry Andric Changed = !FusedInsts.empty();
10205ffd83dbSDimitry Andric
102106c3fb27SDimitry Andric // Fourth, lower remaining instructions with shape information.
10225ffd83dbSDimitry Andric for (Instruction *Inst : MatrixInsts) {
10235ffd83dbSDimitry Andric if (FusedInsts.count(Inst))
10245ffd83dbSDimitry Andric continue;
10255ffd83dbSDimitry Andric
10265ffd83dbSDimitry Andric IRBuilder<> Builder(Inst);
10275ffd83dbSDimitry Andric
10285ffd83dbSDimitry Andric if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1029480093f4SDimitry Andric Changed |= VisitCallInst(CInst);
1030480093f4SDimitry Andric
1031480093f4SDimitry Andric Value *Op1;
1032480093f4SDimitry Andric Value *Op2;
10335ffd83dbSDimitry Andric if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1034480093f4SDimitry Andric Changed |= VisitBinaryOperator(BinOp);
1035e8d8bef9SDimitry Andric if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1036e8d8bef9SDimitry Andric Changed |= VisitUnaryOperator(UnOp);
10375ffd83dbSDimitry Andric if (match(Inst, m_Load(m_Value(Op1))))
10385ffd83dbSDimitry Andric Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
10395ffd83dbSDimitry Andric else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
10405ffd83dbSDimitry Andric Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
1041480093f4SDimitry Andric }
10425ffd83dbSDimitry Andric
1043e8d8bef9SDimitry Andric if (ORE) {
1044e8d8bef9SDimitry Andric RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
10455ffd83dbSDimitry Andric RemarkGen.emitRemarks();
1046e8d8bef9SDimitry Andric }
1047480093f4SDimitry Andric
1048fe6060f1SDimitry Andric // Delete the instructions backwards, as this reduces the likelihood of
1049fe6060f1SDimitry Andric // having to update as many def-use and use-def chains.
1050fe6060f1SDimitry Andric //
1051fe6060f1SDimitry Andric // Because we add to ToRemove during fusion we can't guarantee that defs
105281ad6265SDimitry Andric // are before uses. Change uses to poison temporarily as these should get
1053fe6060f1SDimitry Andric // removed as well.
1054fe6060f1SDimitry Andric //
105581ad6265SDimitry Andric // For verification, we keep track of where we changed uses to poison in
105681ad6265SDimitry Andric // PoisonedInsts and then check that we in fact remove them.
105781ad6265SDimitry Andric SmallSet<Instruction *, 16> PoisonedInsts;
1058fe6060f1SDimitry Andric for (auto *Inst : reverse(ToRemove)) {
1059349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
106081ad6265SDimitry Andric if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
106181ad6265SDimitry Andric PoisonedInsts.insert(Poisoned);
106281ad6265SDimitry Andric U.set(PoisonValue::get(Inst->getType()));
1063fe6060f1SDimitry Andric }
1064480093f4SDimitry Andric Inst->eraseFromParent();
106581ad6265SDimitry Andric PoisonedInsts.erase(Inst);
1066fe6060f1SDimitry Andric }
106781ad6265SDimitry Andric if (!PoisonedInsts.empty()) {
106881ad6265SDimitry Andric // If we didn't remove all poisoned instructions, it's a hard error.
106981ad6265SDimitry Andric dbgs() << "Poisoned but present instructions:\n";
107081ad6265SDimitry Andric for (auto *I : PoisonedInsts)
1071fe6060f1SDimitry Andric dbgs() << *I << "\n";
107281ad6265SDimitry Andric llvm_unreachable("Poisoned but instruction not removed");
1073fe6060f1SDimitry Andric }
1074480093f4SDimitry Andric
1075480093f4SDimitry Andric return Changed;
1076480093f4SDimitry Andric }
1077480093f4SDimitry Andric
1078480093f4SDimitry Andric /// Replace intrinsic calls
1079480093f4SDimitry Andric bool VisitCallInst(CallInst *Inst) {
1080480093f4SDimitry Andric if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
1081480093f4SDimitry Andric return false;
1082480093f4SDimitry Andric
1083480093f4SDimitry Andric switch (Inst->getCalledFunction()->getIntrinsicID()) {
1084480093f4SDimitry Andric case Intrinsic::matrix_multiply:
1085480093f4SDimitry Andric LowerMultiply(Inst);
1086480093f4SDimitry Andric break;
1087480093f4SDimitry Andric case Intrinsic::matrix_transpose:
1088480093f4SDimitry Andric LowerTranspose(Inst);
1089480093f4SDimitry Andric break;
10905ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_load:
10915ffd83dbSDimitry Andric LowerColumnMajorLoad(Inst);
1092480093f4SDimitry Andric break;
10935ffd83dbSDimitry Andric case Intrinsic::matrix_column_major_store:
10945ffd83dbSDimitry Andric LowerColumnMajorStore(Inst);
1095480093f4SDimitry Andric break;
1096480093f4SDimitry Andric default:
1097480093f4SDimitry Andric return false;
1098480093f4SDimitry Andric }
1099480093f4SDimitry Andric return true;
1100480093f4SDimitry Andric }
1101480093f4SDimitry Andric
11025ffd83dbSDimitry Andric /// Compute the alignment for a column/row \p Idx with \p Stride between them.
11035ffd83dbSDimitry Andric /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
11045ffd83dbSDimitry Andric /// ConstantInt, reduce the initial alignment based on the byte offset. For
11055ffd83dbSDimitry Andric /// non-ConstantInt strides, return the common alignment of the initial
11065ffd83dbSDimitry Andric /// alignment and the element size in bytes.
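/// Worked example (illustrative): for double elements (8 bytes), an initial
/// alignment of 16 and a constant stride of 3 elements (24 bytes), the
/// vector at Idx = 1 starts at byte offset 24, so it gets
/// commonAlignment(16, 24) = 8.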
11075ffd83dbSDimitry Andric Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
11085ffd83dbSDimitry Andric MaybeAlign A) const {
11095ffd83dbSDimitry Andric Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
11105ffd83dbSDimitry Andric if (Idx == 0)
11115ffd83dbSDimitry Andric return InitialAlign;
11125ffd83dbSDimitry Andric
11135ffd83dbSDimitry Andric TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
11145ffd83dbSDimitry Andric if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
11155ffd83dbSDimitry Andric uint64_t StrideInBytes =
11165ffd83dbSDimitry Andric ConstStride->getZExtValue() * ElementSizeInBits / 8;
11175ffd83dbSDimitry Andric return commonAlignment(InitialAlign, Idx * StrideInBytes);
11185ffd83dbSDimitry Andric }
11195ffd83dbSDimitry Andric return commonAlignment(InitialAlign, ElementSizeInBits / 8);
11205ffd83dbSDimitry Andric }
11215ffd83dbSDimitry Andric
11225ffd83dbSDimitry Andric /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
11235ffd83dbSDimitry Andric /// vectors.
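/// E.g. a column-major 4x3 matrix of floats with a stride of 4 elements is
/// loaded as three <4 x float> vectors, where vector I starts at element
/// I * Stride and its alignment is derived via getAlignForIndex above.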
11245ffd83dbSDimitry Andric MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
11255ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1126fe6060f1SDimitry Andric auto *VType = cast<VectorType>(Ty);
1127fe6060f1SDimitry Andric Type *EltTy = VType->getElementType();
1128fe6060f1SDimitry Andric Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
11295f757f3fSDimitry Andric Value *EltPtr = Ptr;
11305ffd83dbSDimitry Andric MatrixTy Result;
11315ffd83dbSDimitry Andric for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
1132349cc55cSDimitry Andric Value *GEP = computeVectorAddr(
1133349cc55cSDimitry Andric EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
1134349cc55cSDimitry Andric Stride, Shape.getStride(), EltTy, Builder);
11355ffd83dbSDimitry Andric Value *Vector = Builder.CreateAlignedLoad(
1136fe6060f1SDimitry Andric VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
11375ffd83dbSDimitry Andric IsVolatile, "col.load");
11385ffd83dbSDimitry Andric
11395ffd83dbSDimitry Andric Result.addVector(Vector);
11405ffd83dbSDimitry Andric }
11415ffd83dbSDimitry Andric return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
11425ffd83dbSDimitry Andric Result.getNumVectors());
1143480093f4SDimitry Andric }
1144480093f4SDimitry Andric
11455ffd83dbSDimitry Andric /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
11465ffd83dbSDimitry Andric /// starting at \p MatrixPtr[I][J].
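/// The tile start is computed column-major as Offset = J * Stride + I, e.g.
/// for an 8x8 matrix (illustrative), the 4x4 tile at I = 4, J = 0 starts at
/// element 4 and is loaded using the outer matrix's stride of 8.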
11475ffd83dbSDimitry Andric MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
11485ffd83dbSDimitry Andric ShapeInfo MatrixShape, Value *I, Value *J,
11495ffd83dbSDimitry Andric ShapeInfo ResultShape, Type *EltTy,
11505ffd83dbSDimitry Andric IRBuilder<> &Builder) {
11515ffd83dbSDimitry Andric
11525ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd(
11535ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
11545ffd83dbSDimitry Andric
11555f757f3fSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
11565ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
11575ffd83dbSDimitry Andric ResultShape.NumColumns);
11585ffd83dbSDimitry Andric
11595f757f3fSDimitry Andric return loadMatrix(TileTy, TileStart, Align,
11605ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile,
11615ffd83dbSDimitry Andric ResultShape, Builder);
1162480093f4SDimitry Andric }
1163480093f4SDimitry Andric
11645ffd83dbSDimitry Andric /// Lower a load instruction with shape information.
11655ffd83dbSDimitry Andric void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
11665ffd83dbSDimitry Andric bool IsVolatile, ShapeInfo Shape) {
11675ffd83dbSDimitry Andric IRBuilder<> Builder(Inst);
11685ffd83dbSDimitry Andric finalizeLowering(Inst,
11695ffd83dbSDimitry Andric loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
11705ffd83dbSDimitry Andric Shape, Builder),
11715ffd83dbSDimitry Andric Builder);
11725ffd83dbSDimitry Andric }
11735ffd83dbSDimitry Andric
11745ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.load.
1175480093f4SDimitry Andric ///
1176480093f4SDimitry Andric /// The intrinsic loads a matrix from memory using a stride between columns.
11775ffd83dbSDimitry Andric void LowerColumnMajorLoad(CallInst *Inst) {
11785ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
11795ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!");
1180480093f4SDimitry Andric Value *Ptr = Inst->getArgOperand(0);
1181480093f4SDimitry Andric Value *Stride = Inst->getArgOperand(1);
11825ffd83dbSDimitry Andric LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
11835ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1184480093f4SDimitry Andric {Inst->getArgOperand(3), Inst->getArgOperand(4)});
1185480093f4SDimitry Andric }
1186480093f4SDimitry Andric
11875ffd83dbSDimitry Andric /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
11885ffd83dbSDimitry Andric /// MatrixPtr[I][J].
11895ffd83dbSDimitry Andric void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
11905ffd83dbSDimitry Andric MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
11915ffd83dbSDimitry Andric Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
11925ffd83dbSDimitry Andric Value *Offset = Builder.CreateAdd(
11935ffd83dbSDimitry Andric Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
11945ffd83dbSDimitry Andric
11955f757f3fSDimitry Andric Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
11965ffd83dbSDimitry Andric auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
11975ffd83dbSDimitry Andric StoreVal.getNumColumns());
11985ffd83dbSDimitry Andric
11995f757f3fSDimitry Andric storeMatrix(TileTy, StoreVal, TileStart, MAlign,
12005ffd83dbSDimitry Andric Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
12015ffd83dbSDimitry Andric }
12025ffd83dbSDimitry Andric
12035ffd83dbSDimitry Andric /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
12045ffd83dbSDimitry Andric /// vectors.
12055ffd83dbSDimitry Andric MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
12065ffd83dbSDimitry Andric MaybeAlign MAlign, Value *Stride, bool IsVolatile,
12075ffd83dbSDimitry Andric IRBuilder<> &Builder) {
12085ffd83dbSDimitry Andric auto VType = cast<VectorType>(Ty);
12095f757f3fSDimitry Andric Value *EltPtr = Ptr;
12105ffd83dbSDimitry Andric for (auto Vec : enumerate(StoreVal.vectors())) {
1211349cc55cSDimitry Andric Value *GEP = computeVectorAddr(
1212349cc55cSDimitry Andric EltPtr,
1213349cc55cSDimitry Andric Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1214349cc55cSDimitry Andric Vec.index()),
1215349cc55cSDimitry Andric Stride, StoreVal.getStride(), VType->getElementType(), Builder);
12165ffd83dbSDimitry Andric Builder.CreateAlignedStore(Vec.value(), GEP,
12175ffd83dbSDimitry Andric getAlignForIndex(Vec.index(), Stride,
12185ffd83dbSDimitry Andric VType->getElementType(),
12195ffd83dbSDimitry Andric MAlign),
12205ffd83dbSDimitry Andric IsVolatile);
12215ffd83dbSDimitry Andric }
12225ffd83dbSDimitry Andric return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
12235ffd83dbSDimitry Andric StoreVal.getNumVectors());
12245ffd83dbSDimitry Andric }
12255ffd83dbSDimitry Andric
12265ffd83dbSDimitry Andric /// Lower a store instruction with shape information.
12275ffd83dbSDimitry Andric void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
12285ffd83dbSDimitry Andric Value *Stride, bool IsVolatile, ShapeInfo Shape) {
12295ffd83dbSDimitry Andric IRBuilder<> Builder(Inst);
12305ffd83dbSDimitry Andric auto StoreVal = getMatrix(Matrix, Shape, Builder);
12315ffd83dbSDimitry Andric finalizeLowering(Inst,
12325ffd83dbSDimitry Andric storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
12335ffd83dbSDimitry Andric IsVolatile, Builder),
12345ffd83dbSDimitry Andric Builder);
12355ffd83dbSDimitry Andric }
12365ffd83dbSDimitry Andric
12375ffd83dbSDimitry Andric /// Lowers llvm.matrix.column.major.store.
12385ffd83dbSDimitry Andric ///
12395ffd83dbSDimitry Andric /// The intrinsic stores a matrix back to memory using a stride between columns.
12405ffd83dbSDimitry Andric void LowerColumnMajorStore(CallInst *Inst) {
12415ffd83dbSDimitry Andric assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
12425ffd83dbSDimitry Andric "Intrinsic only supports column-major layout!");
12435ffd83dbSDimitry Andric Value *Matrix = Inst->getArgOperand(0);
12445ffd83dbSDimitry Andric Value *Ptr = Inst->getArgOperand(1);
12455ffd83dbSDimitry Andric Value *Stride = Inst->getArgOperand(2);
12465ffd83dbSDimitry Andric LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
12475ffd83dbSDimitry Andric cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
12485ffd83dbSDimitry Andric {Inst->getArgOperand(4), Inst->getArgOperand(5)});
1249480093f4SDimitry Andric }
1250480093f4SDimitry Andric
1251480093f4SDimitry Andric // Set elements I..I+NumElts-1 to Block
1252480093f4SDimitry Andric Value *insertVector(Value *Col, unsigned I, Value *Block,
12535ffd83dbSDimitry Andric IRBuilder<> &Builder) {
1254480093f4SDimitry Andric
1255480093f4SDimitry Andric // First, bring Block to the same size as Col
1256480093f4SDimitry Andric unsigned BlockNumElts =
12575ffd83dbSDimitry Andric cast<FixedVectorType>(Block->getType())->getNumElements();
12585ffd83dbSDimitry Andric unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1259480093f4SDimitry Andric assert(NumElts >= BlockNumElts && "Too few elements for current block");
1260480093f4SDimitry Andric
12615ffd83dbSDimitry Andric Block = Builder.CreateShuffleVector(
1262e8d8bef9SDimitry Andric Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1263480093f4SDimitry Andric
1264480093f4SDimitry Andric // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1265480093f4SDimitry Andric // 8, 4, 5, 6
12665ffd83dbSDimitry Andric SmallVector<int, 16> Mask;
1267480093f4SDimitry Andric unsigned i;
1268480093f4SDimitry Andric for (i = 0; i < I; i++)
12695ffd83dbSDimitry Andric Mask.push_back(i);
1270480093f4SDimitry Andric
12715ffd83dbSDimitry Andric unsigned VecNumElts =
12725ffd83dbSDimitry Andric cast<FixedVectorType>(Col->getType())->getNumElements();
1273480093f4SDimitry Andric for (; i < I + BlockNumElts; i++)
12745ffd83dbSDimitry Andric Mask.push_back(i - I + VecNumElts);
1275480093f4SDimitry Andric
1276480093f4SDimitry Andric for (; i < VecNumElts; i++)
12775ffd83dbSDimitry Andric Mask.push_back(i);
1278480093f4SDimitry Andric
12795ffd83dbSDimitry Andric return Builder.CreateShuffleVector(Col, Block, Mask);
1280480093f4SDimitry Andric }
1281480093f4SDimitry Andric
1282480093f4SDimitry Andric Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
12835ffd83dbSDimitry Andric IRBuilder<> &Builder, bool AllowContraction,
12845ffd83dbSDimitry Andric unsigned &NumComputeOps) {
12855ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType());
1286480093f4SDimitry Andric if (!Sum)
1287480093f4SDimitry Andric return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1288480093f4SDimitry Andric
1289480093f4SDimitry Andric if (UseFPOp) {
1290480093f4SDimitry Andric if (AllowContraction) {
1291480093f4SDimitry Andric // Use fmuladd for floating point operations and let the backend decide
1292480093f4SDimitry Andric // if that's profitable.
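// E.g. for <4 x float> operands this emits (with illustrative value names)
//   call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a, <4 x float> %b,
//                                        <4 x float> %sum)
// which targets can lower to a single fused multiply-add.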
12935ffd83dbSDimitry Andric Function *FMulAdd = Intrinsic::getDeclaration(
1294480093f4SDimitry Andric Func.getParent(), Intrinsic::fmuladd, A->getType());
1295480093f4SDimitry Andric return Builder.CreateCall(FMulAdd, {A, B, Sum});
1296480093f4SDimitry Andric }
12975ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType());
1298480093f4SDimitry Andric Value *Mul = Builder.CreateFMul(A, B);
1299480093f4SDimitry Andric return Builder.CreateFAdd(Sum, Mul);
1300480093f4SDimitry Andric }
1301480093f4SDimitry Andric
13025ffd83dbSDimitry Andric NumComputeOps += getNumOps(A->getType());
1303480093f4SDimitry Andric Value *Mul = Builder.CreateMul(A, B);
1304480093f4SDimitry Andric return Builder.CreateAdd(Sum, Mul);
1305480093f4SDimitry Andric }
1306480093f4SDimitry Andric
1307480093f4SDimitry Andric /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1308fe6060f1SDimitry Andric /// users with shape information, there's nothing to do: they will use the
1309480093f4SDimitry Andric /// cached value when they are lowered. For other users, \p Matrix is
1310480093f4SDimitry Andric /// flattened and the uses are updated to use it. Also marks \p Inst for
1311480093f4SDimitry Andric /// deletion.
13125ffd83dbSDimitry Andric void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1313480093f4SDimitry Andric IRBuilder<> &Builder) {
1314fe6060f1SDimitry Andric auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1315fe6060f1SDimitry Andric (void)inserted;
1316fe6060f1SDimitry Andric assert(inserted.second && "multiple matrix lowering mapping");
1317480093f4SDimitry Andric
1318480093f4SDimitry Andric ToRemove.push_back(Inst);
1319480093f4SDimitry Andric Value *Flattened = nullptr;
1320fe6060f1SDimitry Andric for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1321480093f4SDimitry Andric if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
1322480093f4SDimitry Andric if (!Flattened)
1323480093f4SDimitry Andric Flattened = Matrix.embedInVector(Builder);
1324480093f4SDimitry Andric U.set(Flattened);
1325480093f4SDimitry Andric }
1326480093f4SDimitry Andric }
1327480093f4SDimitry Andric }
1328480093f4SDimitry Andric
132906c3fb27SDimitry Andric /// Special case for MatMul lowering. Prevents scalar loads of row-major
133006c3fb27SDimitry Andric /// vectors. Lowers to a vector reduction add instead of a sequential add if
133106c3fb27SDimitry Andric /// reassociation is enabled.
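/// E.g. a 1xN times Nx1 multiply with reassociation allowed becomes,
/// roughly (illustrative IR):
///   %m = fmul <N x float> %lhs, %rhs
///   %r = call float @llvm.vector.reduce.fadd(float 0.0, <N x float> %m)
/// instead of N scalar multiplies and N - 1 sequential scalar adds.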
133206c3fb27SDimitry Andric void lowerDotProduct(CallInst *MatMul,
133306c3fb27SDimitry Andric SmallPtrSet<Instruction *, 16> &FusedInsts,
133406c3fb27SDimitry Andric FastMathFlags FMF) {
133506c3fb27SDimitry Andric if (FusedInsts.contains(MatMul) ||
133606c3fb27SDimitry Andric MatrixLayout != MatrixLayoutTy::ColumnMajor)
133706c3fb27SDimitry Andric return;
133806c3fb27SDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
133906c3fb27SDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
134006c3fb27SDimitry Andric
134106c3fb27SDimitry Andric if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
134206c3fb27SDimitry Andric return;
134306c3fb27SDimitry Andric
134406c3fb27SDimitry Andric Value *LHS = MatMul->getArgOperand(0);
134506c3fb27SDimitry Andric Value *RHS = MatMul->getArgOperand(1);
134606c3fb27SDimitry Andric
134706c3fb27SDimitry Andric Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
134806c3fb27SDimitry Andric bool IsIntVec = ElementType->isIntegerTy();
134906c3fb27SDimitry Andric
135006c3fb27SDimitry Andric // Floating point reductions require reassociation.
135106c3fb27SDimitry Andric if (!IsIntVec && !FMF.allowReassoc())
135206c3fb27SDimitry Andric return;
135306c3fb27SDimitry Andric
1354*0fca6ea1SDimitry Andric auto CanBeFlattened = [](Value *Op) {
1355*0fca6ea1SDimitry Andric if (match(Op, m_BinOp()))
135606c3fb27SDimitry Andric return true;
135706c3fb27SDimitry Andric return match(
135806c3fb27SDimitry Andric Op, m_OneUse(m_CombineOr(
135906c3fb27SDimitry Andric m_Load(m_Value()),
136006c3fb27SDimitry Andric m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
136106c3fb27SDimitry Andric m_Intrinsic<Intrinsic::matrix_column_major_load>(
136206c3fb27SDimitry Andric m_Value(), m_SpecificInt(1))))));
136306c3fb27SDimitry Andric };
136406c3fb27SDimitry Andric // Returns the cost benefit of using \p Op with the dot product lowering. If
136506c3fb27SDimitry Andric // the returned cost is < 0, the argument is cheaper to use in the
136606c3fb27SDimitry Andric // dot-product lowering.
136706c3fb27SDimitry Andric auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1368*0fca6ea1SDimitry Andric if (ShapeMap.find(Op) == ShapeMap.end())
1369*0fca6ea1SDimitry Andric return InstructionCost::getInvalid();
1370*0fca6ea1SDimitry Andric
137106c3fb27SDimitry Andric if (!isa<Instruction>(Op))
137206c3fb27SDimitry Andric return InstructionCost(0);
137306c3fb27SDimitry Andric
137406c3fb27SDimitry Andric FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
137506c3fb27SDimitry Andric Type *EltTy = VecTy->getElementType();
137606c3fb27SDimitry Andric
137706c3fb27SDimitry Andric if (!CanBeFlattened(Op)) {
137806c3fb27SDimitry Andric InstructionCost EmbedCost(0);
137906c3fb27SDimitry Andric // Roughly estimate the cost for embedding the columns into a vector.
138006c3fb27SDimitry Andric for (unsigned I = 1; I < N; ++I)
1381*0fca6ea1SDimitry Andric EmbedCost +=
138206c3fb27SDimitry Andric TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
138306c3fb27SDimitry Andric std::nullopt, TTI::TCK_RecipThroughput);
138406c3fb27SDimitry Andric return EmbedCost;
138506c3fb27SDimitry Andric }
138606c3fb27SDimitry Andric
138706c3fb27SDimitry Andric if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
138806c3fb27SDimitry Andric InstructionCost OriginalCost =
138906c3fb27SDimitry Andric TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
139006c3fb27SDimitry Andric EltTy) *
139106c3fb27SDimitry Andric N;
139206c3fb27SDimitry Andric InstructionCost NewCost = TTI.getArithmeticInstrCost(
139306c3fb27SDimitry Andric cast<Instruction>(Op)->getOpcode(), VecTy);
139406c3fb27SDimitry Andric return NewCost - OriginalCost;
139506c3fb27SDimitry Andric }
139606c3fb27SDimitry Andric
139706c3fb27SDimitry Andric if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
139806c3fb27SDimitry Andric // The transpose can be skipped for the dot product lowering, roughly
139906c3fb27SDimitry Andric // estimate the savings as the cost of embedding the columns in a
140006c3fb27SDimitry Andric // vector.
140106c3fb27SDimitry Andric InstructionCost EmbedCost(0);
140206c3fb27SDimitry Andric for (unsigned I = 1; I < N; ++I)
1403*0fca6ea1SDimitry Andric EmbedCost -=
140406c3fb27SDimitry Andric TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
140506c3fb27SDimitry Andric std::nullopt, TTI::TCK_RecipThroughput);
140606c3fb27SDimitry Andric return EmbedCost;
140706c3fb27SDimitry Andric }
140806c3fb27SDimitry Andric
140906c3fb27SDimitry Andric // Costs for loads.
141006c3fb27SDimitry Andric if (N == 1)
141106c3fb27SDimitry Andric return InstructionCost(0);
141206c3fb27SDimitry Andric
141306c3fb27SDimitry Andric return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
141406c3fb27SDimitry Andric N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
141506c3fb27SDimitry Andric };
1416*0fca6ea1SDimitry Andric
1417*0fca6ea1SDimitry Andric // Iterate over LHS and operations feeding LHS and check if it is profitable
1418*0fca6ea1SDimitry Andric // to flatten the visited ops. For each op, we compute the difference
1419*0fca6ea1SDimitry Andric // between the flattened and matrix versions.
1420*0fca6ea1SDimitry Andric SmallPtrSet<Value *, 4> Seen;
1421*0fca6ea1SDimitry Andric SmallVector<Value *> WorkList;
1422*0fca6ea1SDimitry Andric SmallVector<Value *> ToFlatten;
1423*0fca6ea1SDimitry Andric WorkList.push_back(LHS);
1424*0fca6ea1SDimitry Andric InstructionCost LHSCost(0);
1425*0fca6ea1SDimitry Andric while (!WorkList.empty()) {
1426*0fca6ea1SDimitry Andric Value *Op = WorkList.pop_back_val();
1427*0fca6ea1SDimitry Andric if (!Seen.insert(Op).second)
1428*0fca6ea1SDimitry Andric continue;
1429*0fca6ea1SDimitry Andric
1430*0fca6ea1SDimitry Andric InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1431*0fca6ea1SDimitry Andric if (OpCost + LHSCost >= LHSCost)
1432*0fca6ea1SDimitry Andric continue;
1433*0fca6ea1SDimitry Andric
1434*0fca6ea1SDimitry Andric LHSCost += OpCost;
1435*0fca6ea1SDimitry Andric ToFlatten.push_back(Op);
1436*0fca6ea1SDimitry Andric if (auto *I = dyn_cast<Instruction>(Op))
1437*0fca6ea1SDimitry Andric WorkList.append(I->op_begin(), I->op_end());
1438*0fca6ea1SDimitry Andric }
143906c3fb27SDimitry Andric
144006c3fb27SDimitry Andric // We compare the costs of a vector.reduce.add to sequential add.
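// Illustrative comparison for N = LShape.NumColumns: the reduction path
// costs one vector multiply plus one reduction, while the sequential path
// costs N scalar multiplies plus N - 1 scalar adds; LHSCost carries the
// cost (or saving) of flattening the LHS operand chain computed above.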
144106c3fb27SDimitry Andric int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
144206c3fb27SDimitry Andric int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
144306c3fb27SDimitry Andric InstructionCost ReductionCost =
144406c3fb27SDimitry Andric TTI.getArithmeticReductionCost(
144506c3fb27SDimitry Andric AddOpCode, cast<VectorType>(LHS->getType()),
144606c3fb27SDimitry Andric IsIntVec ? std::nullopt : std::optional(FMF)) +
144706c3fb27SDimitry Andric TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
144806c3fb27SDimitry Andric InstructionCost SequentialAddCost =
144906c3fb27SDimitry Andric TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
145006c3fb27SDimitry Andric (LShape.NumColumns - 1) +
145106c3fb27SDimitry Andric TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
145206c3fb27SDimitry Andric (LShape.NumColumns);
145306c3fb27SDimitry Andric if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
145406c3fb27SDimitry Andric return;
145506c3fb27SDimitry Andric
145606c3fb27SDimitry Andric FusedInsts.insert(MatMul);
145706c3fb27SDimitry Andric IRBuilder<> Builder(MatMul);
145806c3fb27SDimitry Andric auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1459*0fca6ea1SDimitry Andric this](Value *Op) {
146006c3fb27SDimitry Andric // Matmul must be the only user of loads because we don't use LowerLoad
146106c3fb27SDimitry Andric // for row vectors (LowerLoad results in scalar loads and shufflevectors
146206c3fb27SDimitry Andric // instead of a single vector load).
146306c3fb27SDimitry Andric if (!CanBeFlattened(Op))
1464*0fca6ea1SDimitry Andric return;
146506c3fb27SDimitry Andric
146606c3fb27SDimitry Andric if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
146706c3fb27SDimitry Andric ShapeMap[Op] = ShapeMap[Op].t();
1468*0fca6ea1SDimitry Andric return;
146906c3fb27SDimitry Andric }
147006c3fb27SDimitry Andric
147106c3fb27SDimitry Andric FusedInsts.insert(cast<Instruction>(Op));
147206c3fb27SDimitry Andric // If vector uses the builtin load, lower to a LoadInst
147306c3fb27SDimitry Andric Value *Arg;
147406c3fb27SDimitry Andric if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
147506c3fb27SDimitry Andric m_Value(Arg)))) {
147606c3fb27SDimitry Andric auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
147706c3fb27SDimitry Andric Op->replaceAllUsesWith(NewLoad);
147806c3fb27SDimitry Andric cast<Instruction>(Op)->eraseFromParent();
1479*0fca6ea1SDimitry Andric return;
148006c3fb27SDimitry Andric } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
148106c3fb27SDimitry Andric m_Value(Arg)))) {
148206c3fb27SDimitry Andric ToRemove.push_back(cast<Instruction>(Op));
1483*0fca6ea1SDimitry Andric Op->replaceAllUsesWith(Arg);
1484*0fca6ea1SDimitry Andric return;
148506c3fb27SDimitry Andric }
148606c3fb27SDimitry Andric };
1487*0fca6ea1SDimitry Andric
1488*0fca6ea1SDimitry Andric for (auto *V : ToFlatten)
1489*0fca6ea1SDimitry Andric FlattenArg(V);
1490*0fca6ea1SDimitry Andric
1491*0fca6ea1SDimitry Andric LHS = MatMul->getArgOperand(0);
149206c3fb27SDimitry Andric
149306c3fb27SDimitry Andric // Insert mul/fmul and llvm.vector.reduce.fadd
149406c3fb27SDimitry Andric Value *Mul =
149506c3fb27SDimitry Andric IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
149606c3fb27SDimitry Andric
149706c3fb27SDimitry Andric Value *Result;
149806c3fb27SDimitry Andric if (IsIntVec)
149906c3fb27SDimitry Andric Result = Builder.CreateAddReduce(Mul);
150006c3fb27SDimitry Andric else {
150106c3fb27SDimitry Andric Result = Builder.CreateFAddReduce(
150206c3fb27SDimitry Andric ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
150306c3fb27SDimitry Andric 0.0),
150406c3fb27SDimitry Andric Mul);
150506c3fb27SDimitry Andric cast<Instruction>(Result)->setFastMathFlags(FMF);
150606c3fb27SDimitry Andric }
150706c3fb27SDimitry Andric
150806c3fb27SDimitry Andric // pack scalar back into a matrix and then replace matmul inst
150906c3fb27SDimitry Andric Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
151006c3fb27SDimitry Andric Result, uint64_t(0));
151106c3fb27SDimitry Andric MatMul->replaceAllUsesWith(Result);
151206c3fb27SDimitry Andric FusedInsts.insert(MatMul);
151306c3fb27SDimitry Andric ToRemove.push_back(MatMul);
151406c3fb27SDimitry Andric }
151506c3fb27SDimitry Andric
15165ffd83dbSDimitry Andric /// Compute \p Result += \p A * \p B for input matrices with left-associating
15175ffd83dbSDimitry Andric /// addition.
1518fe6060f1SDimitry Andric ///
1519fe6060f1SDimitry Andric /// We can fold a transpose into the operand that is used to extract scalars.
1520fe6060f1SDimitry Andric /// This is the first operand with row-major and the second with
1521fe6060f1SDimitry Andric /// column-major layout. If \p IsScalarMatrixTransposed we assume the appropriate
1522fe6060f1SDimitry Andric /// operand is transposed.
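/// In the column-major case below, this computes for each result column J
///   Result(:, J) += A(:, K) * splat(B(K, J))   for K = 0 .. M-1
/// in BlockSize-wide row chunks, keeping the adds vectorized without
/// requiring reassociation.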
15235ffd83dbSDimitry Andric void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1524fe6060f1SDimitry Andric const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1525fe6060f1SDimitry Andric bool IsScalarMatrixTransposed, FastMathFlags FMF) {
15265ffd83dbSDimitry Andric const unsigned VF = std::max<unsigned>(
1527fe6060f1SDimitry Andric TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1528bdd1243dSDimitry Andric .getFixedValue() /
1529bdd1243dSDimitry Andric Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
15305ffd83dbSDimitry Andric 1U);
15315ffd83dbSDimitry Andric unsigned R = Result.getNumRows();
15325ffd83dbSDimitry Andric unsigned C = Result.getNumColumns();
15335ffd83dbSDimitry Andric unsigned M = A.getNumColumns();
15345ffd83dbSDimitry Andric
15355ffd83dbSDimitry Andric bool IsFP = Result.getElementType()->isFloatingPointTy();
15365ffd83dbSDimitry Andric assert(A.isColumnMajor() == B.isColumnMajor() &&
15375ffd83dbSDimitry Andric Result.isColumnMajor() == A.isColumnMajor() &&
15385ffd83dbSDimitry Andric "operands must agree on matrix layout");
15395ffd83dbSDimitry Andric unsigned NumComputeOps = 0;
1540fe6060f1SDimitry Andric
1541fe6060f1SDimitry Andric Builder.setFastMathFlags(FMF);
1542fe6060f1SDimitry Andric
15435ffd83dbSDimitry Andric if (A.isColumnMajor()) {
15445ffd83dbSDimitry Andric // Multiply columns from the first operand with scalars from the second
15455ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the columns. With
15465ffd83dbSDimitry Andric // this the adds can be vectorized without reassociation.
15475ffd83dbSDimitry Andric for (unsigned J = 0; J < C; ++J) {
15485ffd83dbSDimitry Andric unsigned BlockSize = VF;
15495ffd83dbSDimitry Andric // If Result is zero, we don't need to accumulate in the K==0 iteration.
15505ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
15515ffd83dbSDimitry Andric
15525ffd83dbSDimitry Andric for (unsigned I = 0; I < R; I += BlockSize) {
15535ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder.
15545ffd83dbSDimitry Andric while (I + BlockSize > R)
15555ffd83dbSDimitry Andric BlockSize /= 2;
15565ffd83dbSDimitry Andric
1557fe6060f1SDimitry Andric Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
15585ffd83dbSDimitry Andric : nullptr;
15595ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) {
15605ffd83dbSDimitry Andric Value *L = A.extractVector(I, K, BlockSize, Builder);
1561fe6060f1SDimitry Andric Value *RH = Builder.CreateExtractElement(
1562fe6060f1SDimitry Andric B.getColumn(IsScalarMatrixTransposed ? K : J),
1563fe6060f1SDimitry Andric IsScalarMatrixTransposed ? J : K);
15645ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1565fe6060f1SDimitry Andric Sum =
1566fe6060f1SDimitry Andric createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1567fe6060f1SDimitry Andric IsFP, Builder, FMF.allowContract(), NumComputeOps);
15685ffd83dbSDimitry Andric }
15695ffd83dbSDimitry Andric Result.setVector(J,
15705ffd83dbSDimitry Andric insertVector(Result.getVector(J), I, Sum, Builder));
15715ffd83dbSDimitry Andric }
15725ffd83dbSDimitry Andric }
15735ffd83dbSDimitry Andric } else {
15745ffd83dbSDimitry Andric // Multiply rows from the second operand with scalars from the first
15755ffd83dbSDimitry Andric // operand. Then move along the K axes and accumulate the rows. With this
15765ffd83dbSDimitry Andric // the adds can be vectorized without reassociation.
15775ffd83dbSDimitry Andric for (unsigned I = 0; I < R; ++I) {
15785ffd83dbSDimitry Andric unsigned BlockSize = VF;
15795ffd83dbSDimitry Andric bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
15805ffd83dbSDimitry Andric for (unsigned J = 0; J < C; J += BlockSize) {
15815ffd83dbSDimitry Andric // Gradually lower the vectorization factor to cover the remainder.
15825ffd83dbSDimitry Andric while (J + BlockSize > C)
15835ffd83dbSDimitry Andric BlockSize /= 2;
15845ffd83dbSDimitry Andric
15855ffd83dbSDimitry Andric Value *Sum = nullptr;
15865ffd83dbSDimitry Andric for (unsigned K = 0; K < M; ++K) {
15875ffd83dbSDimitry Andric Value *R = B.extractVector(K, J, BlockSize, Builder);
1588fe6060f1SDimitry Andric Value *LH = Builder.CreateExtractElement(
1589fe6060f1SDimitry Andric A.getVector(IsScalarMatrixTransposed ? K : I),
1590fe6060f1SDimitry Andric IsScalarMatrixTransposed ? I : K);
15915ffd83dbSDimitry Andric Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1592fe6060f1SDimitry Andric Sum =
1593fe6060f1SDimitry Andric createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1594fe6060f1SDimitry Andric IsFP, Builder, FMF.allowContract(), NumComputeOps);
15955ffd83dbSDimitry Andric }
15965ffd83dbSDimitry Andric Result.setVector(I,
15975ffd83dbSDimitry Andric insertVector(Result.getVector(I), J, Sum, Builder));
15985ffd83dbSDimitry Andric }
15995ffd83dbSDimitry Andric }
16005ffd83dbSDimitry Andric }
16015ffd83dbSDimitry Andric Result.addNumComputeOps(NumComputeOps);
16025ffd83dbSDimitry Andric }
16035ffd83dbSDimitry Andric
16045ffd83dbSDimitry Andric /// Ensure that the memory in \p Load does not alias \p Store by potentially
16055ffd83dbSDimitry Andric /// copying it to a new location. Returns the new location if a copy was
16065ffd83dbSDimitry Andric /// needed, and the original location otherwise.
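/// Sketch of the emitted runtime check:
///   check0:     if (load.begin < store.end) goto alias_cont; else no_alias
///   alias_cont: if (store.begin < load.end) goto copy; else no_alias
///   copy:       copy the loaded operand into a fresh alloca
///   no_alias:   phi of the original pointer (from check0/alias_cont) and
///               the alloca (from copy)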
16075ffd83dbSDimitry Andric Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
16085ffd83dbSDimitry Andric CallInst *MatMul) {
16095ffd83dbSDimitry Andric MemoryLocation StoreLoc = MemoryLocation::get(Store);
16105ffd83dbSDimitry Andric MemoryLocation LoadLoc = MemoryLocation::get(Load);
16115ffd83dbSDimitry Andric
16125ffd83dbSDimitry Andric // If we can statically determine noalias we're good.
1613fe6060f1SDimitry Andric if (AA->isNoAlias(LoadLoc, StoreLoc))
16145ffd83dbSDimitry Andric return Load->getPointerOperand();
16155ffd83dbSDimitry Andric
16165ffd83dbSDimitry Andric // Create code to check if the memory locations of the Load and Store
16175ffd83dbSDimitry Andric // overlap and if they do, copy Load's operand to a new buffer.
16185ffd83dbSDimitry Andric
16195ffd83dbSDimitry Andric // First, create new blocks for the second part of the check and the copy.
16205ffd83dbSDimitry Andric BasicBlock *Check0 = MatMul->getParent();
16215ffd83dbSDimitry Andric // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
16225ffd83dbSDimitry Andric // DT. Manually collect dominator tree updates, to avoid unnecessary work,
16235ffd83dbSDimitry Andric // as we adjust Check0 and Check1's branches.
16245ffd83dbSDimitry Andric SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
16255ffd83dbSDimitry Andric for (BasicBlock *Succ : successors(Check0))
1626e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Delete, Check0, Succ});
16275ffd83dbSDimitry Andric
1628e8d8bef9SDimitry Andric BasicBlock *Check1 =
1629e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
16305ffd83dbSDimitry Andric nullptr, "alias_cont");
16315ffd83dbSDimitry Andric BasicBlock *Copy =
1632e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1633e8d8bef9SDimitry Andric nullptr, "copy");
1634e8d8bef9SDimitry Andric BasicBlock *Fusion =
1635e8d8bef9SDimitry Andric SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
16365ffd83dbSDimitry Andric nullptr, "no_alias");
16375ffd83dbSDimitry Andric
16385ffd83dbSDimitry Andric // Check if the loaded memory location begins before the end of the store
16395ffd83dbSDimitry Andric // location. If the condition holds, they might overlap, otherwise they are
16405ffd83dbSDimitry Andric // guaranteed to not overlap.
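// The two half-open intervals [LoadBegin, LoadEnd) and [StoreBegin,
// StoreEnd) overlap iff LoadBegin < StoreEnd and StoreBegin < LoadEnd;
// the first condition is tested here, the second in alias_cont below.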
16415ffd83dbSDimitry Andric IRBuilder<> Builder(MatMul);
16425ffd83dbSDimitry Andric Check0->getTerminator()->eraseFromParent();
16435ffd83dbSDimitry Andric Builder.SetInsertPoint(Check0);
1644*0fca6ea1SDimitry Andric Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
16455ffd83dbSDimitry Andric Value *StoreBegin = Builder.CreatePtrToInt(
16465ffd83dbSDimitry Andric const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
16475ffd83dbSDimitry Andric Value *StoreEnd = Builder.CreateAdd(
16485ffd83dbSDimitry Andric StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
16495ffd83dbSDimitry Andric "store.end", true, true);
16505ffd83dbSDimitry Andric Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
16515ffd83dbSDimitry Andric IntPtrTy, "load.begin");
16525ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
16535ffd83dbSDimitry Andric Fusion);
16545ffd83dbSDimitry Andric
16555ffd83dbSDimitry Andric // Check if the store begins before the end of the load location. If the
16565ffd83dbSDimitry Andric // condition holds, they alias, otherwise they are guaranteed to not
16575ffd83dbSDimitry Andric // overlap.
16585ffd83dbSDimitry Andric Check1->getTerminator()->eraseFromParent();
16595ffd83dbSDimitry Andric Builder.SetInsertPoint(Check1, Check1->begin());
16605ffd83dbSDimitry Andric Value *LoadEnd = Builder.CreateAdd(
16615ffd83dbSDimitry Andric LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
16625ffd83dbSDimitry Andric "load.end", true, true);
16635ffd83dbSDimitry Andric Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
16645ffd83dbSDimitry Andric Fusion);
16655ffd83dbSDimitry Andric
16665ffd83dbSDimitry Andric // Copy load operand to new alloca.
16675ffd83dbSDimitry Andric Builder.SetInsertPoint(Copy, Copy->begin());
16681fd87a68SDimitry Andric auto *VT = cast<FixedVectorType>(Load->getType());
16691fd87a68SDimitry Andric // Use an array type for the alloca, to avoid potentially huge alignment
16701fd87a68SDimitry Andric // requirements for large vector types.
16711fd87a68SDimitry Andric auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
16721fd87a68SDimitry Andric AllocaInst *Alloca =
16731fd87a68SDimitry Andric Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
16741fd87a68SDimitry Andric
167506c3fb27SDimitry Andric Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
16761fd87a68SDimitry Andric Load->getAlign(), LoadLoc.Size.getValue());
16775ffd83dbSDimitry Andric Builder.SetInsertPoint(Fusion, Fusion->begin());
16785ffd83dbSDimitry Andric PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
16795ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check0);
16805ffd83dbSDimitry Andric PHI->addIncoming(Load->getPointerOperand(), Check1);
168106c3fb27SDimitry Andric PHI->addIncoming(Alloca, Copy);
16825ffd83dbSDimitry Andric
16835ffd83dbSDimitry Andric // Adjust DT.
1684e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check0, Check1});
1685e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check0, Fusion});
1686e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check1, Copy});
1687e8d8bef9SDimitry Andric DTUpdates.push_back({DT->Insert, Check1, Fusion});
1688e8d8bef9SDimitry Andric DT->applyUpdates(DTUpdates);
16895ffd83dbSDimitry Andric return PHI;
16905ffd83dbSDimitry Andric }
16915ffd83dbSDimitry Andric
16925ffd83dbSDimitry Andric bool isFusionProfitable(CallInst *MatMul) {
16935ffd83dbSDimitry Andric if (ForceFusion)
16945ffd83dbSDimitry Andric return true;
16955ffd83dbSDimitry Andric
16965ffd83dbSDimitry Andric ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
16975ffd83dbSDimitry Andric ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
16985ffd83dbSDimitry Andric
16995ffd83dbSDimitry Andric const unsigned R = LShape.NumRows;
17005ffd83dbSDimitry Andric const unsigned C = RShape.NumColumns;
17015ffd83dbSDimitry Andric const unsigned M = LShape.NumColumns;
17025ffd83dbSDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
17035ffd83dbSDimitry Andric
1704fe6060f1SDimitry Andric const unsigned VF = std::max<unsigned>(
1705fe6060f1SDimitry Andric TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1706bdd1243dSDimitry Andric .getFixedValue() /
1707bdd1243dSDimitry Andric EltType->getPrimitiveSizeInBits().getFixedValue(),
17085ffd83dbSDimitry Andric 1U);
17095ffd83dbSDimitry Andric
17105ffd83dbSDimitry Andric // Cost model for tiling
17115ffd83dbSDimitry Andric //
17125ffd83dbSDimitry Andric // For tiling to be beneficial, we need reuse either along the R or
17135ffd83dbSDimitry Andric // the C axis. We vectorize along the R axis so that means at least
17145ffd83dbSDimitry Andric // 3 elements.
17155ffd83dbSDimitry Andric // TODO: Also consider cost of copying if operands alias.
17165ffd83dbSDimitry Andric if (R <= VF && C == 1)
17175ffd83dbSDimitry Andric return false;
17185ffd83dbSDimitry Andric // Then we need enough elements to exceed the number of vector
17195ffd83dbSDimitry Andric // registers we have. Note that this is an oversimplification since
17205ffd83dbSDimitry Andric // fusing also takes some extra loads which may exceed the number of
17215ffd83dbSDimitry Andric // reloads necessary.
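// Worked example (illustrative): for R = M = C = 16 and VF = 4,
// Op0Regs = (16 + 3) / 4 * 16 = 64, and likewise Op1Regs = 64; together
// they exceed e.g. a 32-register vector file, so tiling is considered
// profitable.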
17225ffd83dbSDimitry Andric unsigned Op0Regs = (R + VF - 1) / VF * M;
17235ffd83dbSDimitry Andric unsigned Op1Regs = (M + VF - 1) / VF * C;
172404eeddc0SDimitry Andric return Op0Regs + Op1Regs >
172504eeddc0SDimitry Andric TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
17265ffd83dbSDimitry Andric }
17275ffd83dbSDimitry Andric
17285ffd83dbSDimitry Andric MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
17295ffd83dbSDimitry Andric MatrixTy Res;
17305ffd83dbSDimitry Andric auto *ColumnType = FixedVectorType::get(EltType, R);
17315ffd83dbSDimitry Andric for (unsigned I = 0; I < C; ++I)
17325ffd83dbSDimitry Andric Res.addVector(ConstantAggregateZero::get(ColumnType));
17335ffd83dbSDimitry Andric return Res;
17345ffd83dbSDimitry Andric }
17355ffd83dbSDimitry Andric
1736e8d8bef9SDimitry Andric void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1737fe6060f1SDimitry Andric Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1738e8d8bef9SDimitry Andric auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1739e8d8bef9SDimitry Andric
1740e8d8bef9SDimitry Andric // Create the main tiling loop nest.
1741e8d8bef9SDimitry Andric TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1742e8d8bef9SDimitry Andric DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1743e8d8bef9SDimitry Andric Instruction *InsertI = cast<Instruction>(MatMul);
1744e8d8bef9SDimitry Andric BasicBlock *Start = InsertI->getParent();
1745e8d8bef9SDimitry Andric BasicBlock *End =
1746e8d8bef9SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1747e8d8bef9SDimitry Andric IRBuilder<> Builder(MatMul);
1748e8d8bef9SDimitry Andric BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1749e8d8bef9SDimitry Andric
1750e8d8bef9SDimitry Andric Type *TileVecTy =
1751e8d8bef9SDimitry Andric FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1752e8d8bef9SDimitry Andric MatrixTy TileResult;
1753e8d8bef9SDimitry Andric // Insert in the inner loop header.
1754972a253aSDimitry Andric Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1755e8d8bef9SDimitry Andric // Create PHI nodes for the result columns to accumulate across iterations.
1756e8d8bef9SDimitry Andric SmallVector<PHINode *, 4> ColumnPhis;
1757e8d8bef9SDimitry Andric for (unsigned I = 0; I < TileSize; I++) {
1758e8d8bef9SDimitry Andric auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1759e8d8bef9SDimitry Andric Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1760972a253aSDimitry Andric TI.RowLoop.Header->getSingleSuccessor());
1761e8d8bef9SDimitry Andric TileResult.addVector(Phi);
1762e8d8bef9SDimitry Andric ColumnPhis.push_back(Phi);
1763e8d8bef9SDimitry Andric }

    // Insert in the inner loop body, which computes
    // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }

  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
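    // LShape is R x M and RShape is M x C, so the fused result is R x C.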
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
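          // Clamp the tile extents at the matrix borders, so remainder tiles
          // are handled when R, C or M are not multiples of TileSize.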
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                      Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void
  LowerMatrixMultiplyFused(CallInst *MatMul,
                           SmallPtrSetImpl<Instruction *> &FusedInsts,
                           SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch scalars.
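    // E.g. in column-major layout, multiply(A, transpose(T)) can fetch the
    // scalars straight from T, so the transposed operand never needs to be
    // materialized (signalled to emitMatrixMultiply via its transposed-operand
    // flag below). Illustrative description of the fold.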
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

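      // Hoist in dominance order, so every hoisted instruction is placed
      // after the operands it depends on.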
      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      // Deal with lifetime.end calls that might be between Load0/Load1 and the
      // store. To avoid introducing loads to dead objects (i.e. after the
      // lifetime has been terminated by @llvm.lifetime.end), either sink them
      // after the store if in the same block, or remove the lifetime.end marker
      // otherwise. This might pessimize further optimizations, by extending the
      // lifetime of the object until the function returns, but should be
      // conservatively correct.
      MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
      MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
      BasicBlock *StoreParent = Store->getParent();
      bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
                                   LoadOp1->getParent() == StoreParent;
      for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
        IntrinsicInst *End = LifetimeEnds[Idx];
        auto Inc = make_scope_exit([&Idx]() { Idx++; });
        // If the lifetime.end is guaranteed to be before the loads or after the
        // store, it won't interfere with fusion.
        if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
          continue;
        if (DT->dominates(Store, End))
          continue;
        // If all fusable ops are in the same block and the lifetime.end is in a
        // different block, it won't interfere with fusion.
        if (FusableOpsInSameBlock && End->getParent() != StoreParent)
          continue;

        // If the loads don't alias the lifetime.end, it won't interfere with
        // fusion.
        MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
        if (!EndLoc.Ptr)
          continue;
        if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
          continue;

        // If both lifetime.end and the store are in the same block, extend the
        // lifetime until after the store, so the new lifetime covers the loads
        // we introduce later.
        if (End->getParent() == StoreParent) {
          End->moveAfter(Store);
          continue;
        }

        // Otherwise remove the conflicting lifetime.end marker.
        ToRemove.push_back(End);
        std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
        LifetimeEnds.pop_back();
        Inc.release();
      }

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows &&
           "Inner matrix dimensions must match.");

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

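    // Illustrative example (column-major): a 2x3 input is held as 3 column
    // vectors of 2 elements; the transpose produces 2 vectors of 3 elements,
    // with element (Row, Col) of the input ending up at (Col, Row).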
    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

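    // Both operands share Shape, so their vectors pair up 1:1 and the scalar
    // opcode can simply be applied vector-by-vector.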
    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
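  ///
  /// A linearized expression may look like the following (illustrative):
  ///   multiply.2x3.3x2.double(
  ///    load(addr %A),
  ///    load(addr %B))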
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.starts_with("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

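    /// Number of trailing arguments of a matrix intrinsic that only describe
    /// the operation (shape dimensions, and volatility for the load/store
    /// intrinsics); they are dropped when printing the operand list.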
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to an (function) external address or a stack address; for other values
    /// we print the constant, or "scalar"/"matrix" otherwise.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
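  ///
  /// An emitted remark may look like the following (illustrative):
  ///   Lowered with 6 stores, 6 loads, 24 compute ops, 0 exposed transposes
  ///   store(
  ///    multiply.2x6.6x2.double(addr %A, addr %B),
  ///    stack addr %C)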
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

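  // In minimal mode the extra analyses stay null; multiply fusion (which
  // needs DT) and optimization remarks (which need ORE) are then effectively
  // disabled.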
  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

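// With the Minimal option set, the printed pipeline element reads e.g.
// "lower-matrix-intrinsics<minimal>" (illustrative).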
void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}