//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different number of rows/columns
//     for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

STATISTIC(FlattenedMatrices, "Number of matrix flattenings");
STATISTIC(ReshapedMatrices, "Number of matrix reshapes");
STATISTIC(SplitMatrices, "Number of matrix splits");

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results due to reduced rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
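/// E.g., this matches a scalar k broadcast into a vector, as in
///   %ins   = insertelement <4 x double> poison, double %k, i64 0
///   %splat = shufflevector <4 x double> %ins, <4 x double> poison,
///                          <4 x i32> zeroinitializer
/// where the shuffle's all-zero mask makes it a zero-element splat.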
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {
  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}

namespace {
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }

  friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI);

  LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; }
};

raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
  return OS << SI.NumRows << 'x' << SI.NumColumns;
}

} // namespace

static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  if (I->isBinaryOp())
    return true;

  if (auto *Cast = dyn_cast<CastInst>(V)) {
    switch (Cast->getOpcode()) {
    case llvm::Instruction::Trunc:
    case llvm::Instruction::ZExt:
    case llvm::Instruction::SExt:
    case llvm::Instruction::FPToUI:
    case llvm::Instruction::FPToSI:
    case llvm::Instruction::UIToFP:
    case llvm::Instruction::SIToFP:
    case llvm::Instruction::FPTrunc:
    case llvm::Instruction::FPExt:
      return true;
    case llvm::Instruction::AddrSpaceCast:
    case CastInst::PtrToInt:
    case CastInst::IntToPtr:
      return false;
    case CastInst::BitCast: {
      if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
        if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
          return SrcVTy->getNumElements() == DestVTy->getNumElements();
      return false;
    }
    case llvm::Instruction::CastOpsEnd:
      llvm_unreachable("not an actual cast op");
    }
    llvm_unreachable("unhandled cast opcode");
  }

  if (auto *II = dyn_cast<IntrinsicInst>(V))
    switch (II->getIntrinsicID()) {
    case Intrinsic::abs:
    case Intrinsic::fabs:
      return true;
    default:
      return false;
    }

  switch (I->getOpcode()) {
  case Instruction::PHI:
  case Instruction::FNeg:
    return true;
  default:
    return false;
  }
}

/// Return the ShapeInfo for the result of \p I, if it can be determined.
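/// E.g., for %r = call @llvm.matrix.multiply(%A, %B, i32 2, i32 3, i32 4),
/// the result shape is 2x4 (an MxN times NxK multiply yields MxK).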
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const DenseMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N)))) {
    // Flip dimensions.
    return ShapeInfo(N, M);
  }
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I) || isa<SelectInst>(I)) {
    auto Ops = I->operands();
    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
    // Find the first operand that has a known shape and use that.
    for (auto &Op : ShapedOps) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
/// instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///  The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
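/// As an illustrative sketch (not the exact IR this pass emits), lowering a
/// column-major 2x2 transpose
///   %t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %m,
///                                                       i32 2, i32 2)
/// boils down to shuffles that gather each result column and then
/// re-concatenate them for users without shape information:
///   %col0 = shufflevector <4 x double> %m, <4 x double> poison,
///                         <2 x i32> <i32 0, i32 2>
///   %col1 = shufflevector <4 x double> %m, <4 x double> poison,
///                         <2 x i32> <i32 1, i32 3>
///   %flat = shufflevector <2 x double> %col0, <2 x double> %col1,
///                         <4 x i32> <i32 0, i32 1, i32 2, i32 3>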
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  FunctionAnalysisManager *AM;
  AliasAnalysis *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;
  OptimizationRemarkEmitter *ORE = nullptr;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications.  This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
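  /// E.g., a column-major 2x3 matrix is held as 3 column vectors with 2
  /// elements each; in row-major layout it is held as 2 row vectors with 3
  /// elements each.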
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows");
        return getVectorTy()->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return getVectorTy()->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    FixedVectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    FixedVectorType *getVectorTy() const {
      return cast<FixedVectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
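    /// E.g., on a column-major 4x4 matrix, extractVector(1, 0, 2, ...)
    /// returns a 2-element vector holding elements (1,0) and (2,0) of the
    /// matrix, i.e. elements 1 and 2 of column 0.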
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information describes the shape of the value being stored. Note
  /// that extra care is needed when erasing or RAUW'ing a value that is
  /// present in ShapeMap. If the replacement is also a matrix operation, use
  /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
  /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
  /// want to add shape information for a replacement instruction. When
  /// directly erasing a value with an entry in ShapeMap, use
  /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also
  /// updated accordingly.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// users need to be removed after we have finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        FunctionAnalysisManager *AM)
      : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<FixedVectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines?
  bool isMinimal() const {
    return !DT;
  }

  /// Return the estimated number of vector ops required for an operation with
  /// scalar element type \p ST and \p N total elements.
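  /// E.g., assuming 128-bit vector registers, an operation on 4 doubles
  /// (256 bits) is estimated as ceil(256 / 128) = 2 vector ops.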
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                   TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mismatch, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    if (Instruction *Inst = dyn_cast<Instruction>(MatrixVal)) {
      if (Found != Inst2ColumnMatrix.end()) {
        // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
        // that embedInVector created.
        LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
                          << " to " << SI << " using at least "
                          << SplitVecs.size() << " shuffles on behalf of:\n"
                          << *Inst << '\n');
        ReshapedMatrices++;
      } else if (!ShapeMap.contains(MatrixVal)) {
        LLVM_DEBUG(
            dbgs()
            << "splitting a " << SI << " matrix with " << SplitVecs.size()
            << " shuffles because we do not have a shape-aware lowering for "
               "its def:\n"
            << *Inst << '\n');
        (void)Inst;
        SplitMatrices++;
      } else {
        // The ShapeMap has it, so it's a case where we're being lowered
        // before the def, and we expect that InstCombine will clean things up
        // afterward.
      }
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape, return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return isUniformShape(II);
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
           isa<SelectInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes.  Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape.  Traverse the operands; if their shape
    // derives from the result shape and is unknown, set it and add them to
    // the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do.  We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
        auto Ops = cast<Instruction>(V)->operands();
        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : ShapedOps) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape propagation; add shapes for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
  /// itself.
  void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
    ShapeMap.erase(Inst);
    Inst->eraseFromParent();
  }

  /// Erase \p V from \p BB and move \p II forward to avoid invalidating
  /// iterators.
  void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                              BasicBlock &BB) {
    auto *Inst = cast<Instruction>(V);
    // Still used, don't erase.
    if (!Inst->use_empty())
      return;
    if (II != BB.rend() && Inst == &*II)
      ++II;
    eraseFromParentAndRemoveFromShapeMap(Inst);
  }

  /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
  /// entry for \p Old and replace all uses of \p Old with \p New.
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap, otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo, so we insert
    // it conditionally instead. Copy the shape before erasing, as erase
    // invalidates the iterator.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }

  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II,
                             bool &Changed) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop when the shapes match.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(
                      m_Value(TATA), m_Specific(C), m_Specific(R)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      Changed = true;
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      Changed = true;
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      Changed = true;
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when a matrix is multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  R  x  C     RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      Changed = true;
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    // RxC RxC      CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      Changed = true;
      return NewInst;
    }

    return nullptr;
  }

  bool liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        eraseFromParentAndRemoveFromShapeMap(&T);
      if (A->use_empty())
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
      if (A != B && B->use_empty())
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
                                                           R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
      return true;
    }
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose.
    // If the shape of the second transpose is different, there's a shape
    // conflict which gets resolved by picking the shape of the first operand.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
                 computeShapeInfoForInst(&I, ShapeMap) &&
             "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      if (auto *AddI = dyn_cast<Instruction>(Add)) {
        setShapeInfo(AddI, {R, C});
        assert(
            computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
                ShapeMap[AddI] &&
            "Shape of updated addition doesn't match cached shape.");
      }
      return true;
    }
    return false;
  }

  /// Try moving transposes in order to fold them away or into multiplies.
  bool optimizeTransposes() {
    bool Changed = false;
    // First sink all transposes inside matmuls and adds, hoping that we end up
    // with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II.  By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II, Changed))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold it into a consuming multiply or add.
    for (BasicBlock &BB : Func) {
      for (Instruction &I : llvm::make_early_inc_range(BB)) {
        Changed |= liftTranspose(I);
      }
    }
    return Changed;
  }

  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    if (AM) {
      ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
      AA = &AM->getResult<AAManager>(Func);
      DT = &AM->getResult<DominatorTreeAnalysis>(Func);
      LI = &AM->getResult<LoopAnalysis>(Func);
    }

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    bool Changed = false;
    if (!isMinimal()) {
      Changed |= optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;
    SmallVector<IntrinsicInst *, 16> LifetimeEnds;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
          LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
        if (!ShapeMap.contains(&I))
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      if (!FusedInsts.contains(CI))
        LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

    Changed |= !FusedInsts.empty();

    // Fourth, pre-process all the PHINodes. The incoming values will be
    // assigned later in VisitPHI.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      auto *PHI = dyn_cast<PHINode>(Inst);
      if (!PHI)
        continue;

      const ShapeInfo &SI = ShapeMap.at(Inst);
      auto *EltTy = cast<FixedVectorType>(PHI->getType())->getElementType();
      MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);

      IRBuilder<> Builder(Inst);
      for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
        PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
                                             PHI->getNumIncomingValues(),
                                             PHI->getName()));
      assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");
      Inst2ColumnMatrix[PHI] = PhiM;
    }

    // Fifth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      const ShapeInfo &SI = ShapeMap.at(Inst);

      Value *Op1;
      Value *Op2;
      MatrixTy Result;
      IRBuilder<> Builder(Inst);
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Result = VisitBinaryOperator(BinOp, SI, Builder);
      else if (auto *Cast = dyn_cast<CastInst>(Inst))
        Result = VisitCastInstruction(Cast, SI, Builder);
      else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Result = VisitUnaryOperator(UnOp, SI, Builder);
      else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
        Result = VisitIntrinsicInst(Intr, SI, Builder);
      else if (auto *Select = dyn_cast<SelectInst>(Inst))
        Result = VisitSelectInst(Select, SI, Builder);
      else if (match(Inst, m_Load(m_Value(Op1))))
        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
      else if (auto *PHI = dyn_cast<PHINode>(Inst))
        Result = VisitPHI(PHI, SI, Builder);
      else
        continue;

      finalizeLowering(Inst, Result, Builder);
      Changed = true;
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses.  Change uses to poison temporarily as these should get
    // removed as well.
    //
    // For verification, we keep track of where we changed uses to poison in
    // PoisonedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }

  /// Replace intrinsic calls.
  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
                              IRBuilder<> &Builder) {
    assert(Inst->getCalledFunction() &&
           Inst->getCalledFunction()->isIntrinsic());

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      return LowerMultiply(Inst, Builder);
    case Intrinsic::matrix_transpose:
      return LowerTranspose(Inst, Builder);
    case Intrinsic::matrix_column_major_load:
      return LowerColumnMajorLoad(Inst, Builder);
    case Intrinsic::matrix_column_major_store:
      return LowerColumnMajorStore(Inst, Builder);
    case Intrinsic::abs:
    case Intrinsic::fabs: {
      MatrixTy Result;
      MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
      Builder.setFastMathFlags(getFastMathFlags(Inst));

      for (auto *Vector : M.vectors()) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::abs:
          Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
                                                         Inst->getOperand(1)));
          continue;
        case Intrinsic::fabs:
          Result.addVector(
              Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
          continue;
        default:
          llvm_unreachable("unexpected intrinsic");
        }
      }

      return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                     Result.getNumVectors());
    }
    default:
      break;
    }
    llvm_unreachable(
        "only intrinsics supporting shape info should be seen here");
  }

  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
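  /// E.g., for double elements, a constant stride of 3 (24 bytes) and an
  /// initial alignment of 16, the vector at \p Idx == 1 starts at byte offset
  /// 24, so its alignment is commonAlignment(16, 24) = 8.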
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
  /// vectors.
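  /// E.g., loading a column-major 2x2 matrix of doubles with a stride of 4
  /// emits two <2 x double> loads: one at \p Ptr and one 4 elements past it.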
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<FixedVectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from a matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                     Value *Stride, bool IsVolatile, ShapeInfo Shape,
                     IRBuilder<> &Builder) {
    return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
                      Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
                     cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
                     {Inst->getArgOperand(3), Inst->getArgOperand(4)}, Builder);
  }
1359 
1360   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
1361   /// MatrixPtr[I][J].
1362   void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
1363                    MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
1364                    Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
1365     Value *Offset = Builder.CreateAdd(
1366         Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1367 
1368     Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
1369     auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
1370                                                    StoreVal.getNumColumns());
1371 
1372     storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1373                 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1374   }
1375 
1376   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
1377   /// vectors.
1378   MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
1379                        MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1380                        IRBuilder<> &Builder) {
1381     auto *VType = cast<FixedVectorType>(Ty);
1382     Value *EltPtr = Ptr;
1383     for (auto Vec : enumerate(StoreVal.vectors())) {
1384       Value *GEP = computeVectorAddr(
1385           EltPtr,
1386           Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
1387                           Vec.index()),
1388           Stride, StoreVal.getStride(), VType->getElementType(), Builder);
1389       Builder.CreateAlignedStore(Vec.value(), GEP,
1390                                  getAlignForIndex(Vec.index(), Stride,
1391                                                   VType->getElementType(),
1392                                                   MAlign),
1393                                  IsVolatile);
1394     }
1395     return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
1396                                    StoreVal.getNumVectors());
1397   }
1398 
1399   /// Lower a store instruction with shape information.
1400   MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
1401                       MaybeAlign A, Value *Stride, bool IsVolatile,
1402                       ShapeInfo Shape, IRBuilder<> &Builder) {
1403     auto StoreVal = getMatrix(Matrix, Shape, Builder);
1404     return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
1405                        Builder);
1406   }
1407 
1408   /// Lowers llvm.matrix.column.major.store.
1409   ///
1410   /// The intrinsic stores a matrix back to memory using a stride between columns.
1411   MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) {
1412     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1413            "Intrinsic only supports column-major layout!");
1414     Value *Matrix = Inst->getArgOperand(0);
1415     Value *Ptr = Inst->getArgOperand(1);
1416     Value *Stride = Inst->getArgOperand(2);
1417     return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1418                       cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1419                       {Inst->getArgOperand(4), Inst->getArgOperand(5)},
1420                       Builder);
1421   }
1422 
1423   /// Set elements I..I+BlockNumElts-1 of \p Col to \p Block.
1424   Value *insertVector(Value *Col, unsigned I, Value *Block,
1425                       IRBuilder<> &Builder) {
1426 
1427     // First, bring Block to the same size as Col
1428     unsigned BlockNumElts =
1429         cast<FixedVectorType>(Block->getType())->getNumElements();
1430     unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
1431     assert(NumElts >= BlockNumElts && "Too few elements for current block");
1432 
1433     Block = Builder.CreateShuffleVector(
1434         Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
1435 
1436     // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
1437     // 8, 4, 5, 6
1438     SmallVector<int, 16> Mask;
1439     unsigned i;
1440     for (i = 0; i < I; i++)
1441       Mask.push_back(i);
1442 
1443     unsigned VecNumElts =
1444         cast<FixedVectorType>(Col->getType())->getNumElements();
1445     for (; i < I + BlockNumElts; i++)
1446       Mask.push_back(i - I + VecNumElts);
1447 
1448     for (; i < VecNumElts; i++)
1449       Mask.push_back(i);
1450 
1451     return Builder.CreateShuffleVector(Col, Block, Mask);
1452   }
1453 
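  /// Emit one multiply-accumulate step. With \p UseFPOp and \p
  /// AllowContraction the accumulation becomes a single call, roughly
  /// (an illustrative sketch):
  ///   %sum.next = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
  ///                                                    <4 x float> %b,
  ///                                                    <4 x float> %sum)
  /// otherwise separate fmul/fadd (or mul/add) instructions are emitted.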
1454   Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
1455                       IRBuilder<> &Builder, bool AllowContraction,
1456                       unsigned &NumComputeOps) {
1457     NumComputeOps += getNumOps(A->getType());
1458     if (!Sum)
1459       return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
1460 
1461     if (UseFPOp) {
1462       if (AllowContraction) {
1463         // Use fmuladd for floating point operations and let the backend decide
1464         // if that's profitable.
1465         return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
1466                                        {A, B, Sum});
1467       }
1468       NumComputeOps += getNumOps(A->getType());
1469       Value *Mul = Builder.CreateFMul(A, B);
1470       return Builder.CreateFAdd(Sum, Mul);
1471     }
1472 
1473     NumComputeOps += getNumOps(A->getType());
1474     Value *Mul = Builder.CreateMul(A, B);
1475     return Builder.CreateAdd(Sum, Mul);
1476   }
1477 
1478   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
1479   /// users with shape information, there's nothing to do: they will use the
1480   /// cached value when they are lowered. For other users, \p Matrix is
1481   /// flattened and the uses are updated to use it. Also marks \p Inst for
1482   /// deletion.
1483   void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
1484                         IRBuilder<> &Builder) {
1485     auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
1486     (void)inserted;
1487     assert((inserted.second || isa<PHINode>(Inst)) &&
1488            "multiple matrix lowering mapping");
1489 
1490     ToRemove.push_back(Inst);
1491     Value *Flattened = nullptr;
1492     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1493       if (ShapeMap.contains(U.getUser()))
1494         continue;
1495 
1496       if (!Flattened) {
1497         Flattened = Matrix.embedInVector(Builder);
1498         LLVM_DEBUG(
1499             if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs()
1500                 << "flattening a " << Matrix.shape() << " matrix:\n"
1501                 << *Inst
1502                 << "\nbecause we do not have a shape-aware lowering for its "
1503                    "user:\n"
1504                 << *User << '\n';);
1505         FlattenedMatrices++;
1506       }
1507       U.set(Flattened);
1508     }
1509   }
1510 
1511   /// Special case for MatMul lowering: prevents scalar loads of row-major
1512   /// vectors. Lowers to a vector reduction add instead of sequential adds
1513   /// when reassociation is enabled.
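  ///
  /// For example, a 1x4 * 4x1 double product lowers roughly to:
  ///   %mul = fmul <4 x double> %lhs, %rhs
  ///   %res = call double @llvm.vector.reduce.fadd.v4f64(double 0.0,
  ///                                                     <4 x double> %mul)
  /// with the multiply's fast-math flags applied to the reduction.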
1514   void lowerDotProduct(CallInst *MatMul,
1515                        SmallPtrSet<Instruction *, 16> &FusedInsts,
1516                        FastMathFlags FMF) {
1517     if (FusedInsts.contains(MatMul) ||
1518         MatrixLayout != MatrixLayoutTy::ColumnMajor)
1519       return;
1520     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1521     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1522 
1523     if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1524       return;
1525 
1526     Value *LHS = MatMul->getArgOperand(0);
1527     Value *RHS = MatMul->getArgOperand(1);
1528 
1529     Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
1530     bool IsIntVec = ElementType->isIntegerTy();
1531 
1532     // Floating point reductions require reassociation.
1533     if (!IsIntVec && !FMF.allowReassoc())
1534       return;
1535 
1536     auto CanBeFlattened = [](Value *Op) {
1537       if (match(Op, m_BinOp()))
1538         return true;
1539       return match(
1540           Op, m_OneUse(m_CombineOr(
1541                   m_Load(m_Value()),
1542                   m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1543                               m_Intrinsic<Intrinsic::matrix_column_major_load>(
1544                                   m_Value(), m_SpecificInt(1))))));
1545     };
1546     // Returns the cost benefit of using \p Op with the dot product lowering. If
1547     // the returned cost is < 0, the argument is cheaper to use in the
1548     // dot-product lowering.
1549     auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1550       if (!ShapeMap.contains(Op))
1551         return InstructionCost::getInvalid();
1552 
1553       if (!isa<Instruction>(Op))
1554         return InstructionCost(0);
1555 
1556       FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
1557       Type *EltTy = VecTy->getElementType();
1558 
1559       if (!CanBeFlattened(Op)) {
1560         InstructionCost EmbedCost(0);
1561         // Roughly estimate the cost for embedding the columns into a vector.
1562         for (unsigned I = 1; I < N; ++I)
1563           EmbedCost += TTI.getShuffleCost(
1564               TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1565               FixedVectorType::get(EltTy, 1), {}, TTI::TCK_RecipThroughput);
1566         return EmbedCost;
1567       }
1568 
1569       if (match(Op, m_BinOp()) && ShapeMap.contains(Op)) {
1570         InstructionCost OriginalCost =
1571             TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
1572                                        EltTy) *
1573             N;
1574         InstructionCost NewCost = TTI.getArithmeticInstrCost(
1575             cast<Instruction>(Op)->getOpcode(), VecTy);
1576         return NewCost - OriginalCost;
1577       }
1578 
1579       if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1580         // The transpose can be skipped for the dot product lowering, roughly
1581         // estimate the savings as the cost of embedding the columns in a
1582         // vector.
1583         InstructionCost EmbedCost(0);
1584         for (unsigned I = 1; I < N; ++I)
1585           EmbedCost -= TTI.getShuffleCost(
1586               TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1587               FixedVectorType::get(EltTy, 1), {}, TTI::TCK_RecipThroughput);
1588         return EmbedCost;
1589       }
1590 
1591       // Costs for loads.
1592       if (N == 1)
1593         return InstructionCost(0);
1594 
1595       return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1596              N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1597     };
1598 
1599     // Iterate over LHS and operations feeding LHS and check if it is profitable
1600     // to flatten the visited ops.  For each op, we compute the difference
1601     // between the flattened and matrix versions.
1602     SmallPtrSet<Value *, 4> Seen;
1603     SmallVector<Value *> WorkList;
1604     SmallVector<Value *> ToFlatten;
1605     WorkList.push_back(LHS);
1606     InstructionCost LHSCost(0);
1607     while (!WorkList.empty()) {
1608       Value *Op = WorkList.pop_back_val();
1609       if (!Seen.insert(Op).second)
1610         continue;
1611 
1612       InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
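      // Only flatten Op when it strictly lowers the accumulated cost. An
      // invalid OpCost (unknown shape) compares greater than any valid cost,
      // so such operands are skipped as well.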
1613       if (OpCost + LHSCost >= LHSCost)
1614         continue;
1615 
1616       LHSCost += OpCost;
1617       ToFlatten.push_back(Op);
1618       if (auto *I = dyn_cast<Instruction>(Op))
1619         WorkList.append(I->op_begin(), I->op_end());
1620     }
1621 
1622     // We compare the costs of a vector.reduce.add to sequential add.
1623     int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1624     int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1625     InstructionCost ReductionCost =
1626         TTI.getArithmeticReductionCost(
1627             AddOpCode, cast<FixedVectorType>(LHS->getType()),
1628             IsIntVec ? std::nullopt : std::optional(FMF)) +
1629         TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
1630     InstructionCost SequentialAddCost =
1631         TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
1632             (LShape.NumColumns - 1) +
1633         TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
1634             (LShape.NumColumns);
1635     if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1636       return;
1637 
1638     FusedInsts.insert(MatMul);
1639     IRBuilder<> Builder(MatMul);
1640     auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1641                        this](Value *Op) {
1642       // Matmul must be the only user of loads because we don't use LowerLoad
1643       // for row vectors (LowerLoad results in scalar loads and shufflevectors
1644       // instead of single vector load).
1645       if (!CanBeFlattened(Op))
1646         return;
1647 
1648       if (match(Op, m_BinOp())) {
1649         auto It = ShapeMap.find(Op);
1650         if (It != ShapeMap.end()) {
1651           It->second = It->second.t();
1652           return;
1653         }
1654       }
1655 
1656       FusedInsts.insert(cast<Instruction>(Op));
1657       // If vector uses the builtin load, lower to a LoadInst
1658       Value *Arg;
1659       if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1660                         m_Value(Arg)))) {
1661         auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1662         Op->replaceAllUsesWith(NewLoad);
1663         eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
1664         return;
1665       } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1666                                m_Value(Arg)))) {
1667         ToRemove.push_back(cast<Instruction>(Op));
1668         Op->replaceAllUsesWith(Arg);
1669         return;
1670       }
1671     };
1672 
1673     for (auto *V : ToFlatten)
1674       FlattenArg(V);
1675 
1676     LHS = MatMul->getArgOperand(0);
1677 
1678     // Insert mul/fmul and llvm.vector.reduce.fadd
1679     Value *Mul =
1680         IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
1681 
1682     Value *Result;
1683     if (IsIntVec)
1684       Result = Builder.CreateAddReduce(Mul);
1685     else {
1686       Result = Builder.CreateFAddReduce(
1687           ConstantFP::get(
1688               cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
1689           Mul);
1690       cast<Instruction>(Result)->setFastMathFlags(FMF);
1691     }
1692 
1693     // pack scalar back into a matrix and then replace matmul inst
1694     Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
1695                                          Result, uint64_t(0));
1696     MatMul->replaceAllUsesWith(Result);
1697     FusedInsts.insert(MatMul);
1698     ToRemove.push_back(MatMul);
1699   }
1700 
1701   /// Compute \p Result += \p A * \p B for input matrices with left-associating
1702   /// addition.
1703   ///
1704   /// We can fold a transpose into the operand that is used to extract scalars.
1705   /// This is the first operand with row-major layout and the second with
1706   /// column-major layout. If \p IsScalarMatrixTransposed we assume the
1707   /// appropriate operand is already transposed.
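  ///
  /// In the column-major case this computes, per result column J (a sketch):
  ///   for (K = 0; K < M; ++K)
  ///     Result(:, J) += A(:, K) * splat(B(K, J))
  /// vectorizing along the rows in blocks of at most VF elements.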
1708   void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1709                           const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1710                           bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1711     const unsigned VF = std::max<unsigned>(
1712         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1713                 .getFixedValue() /
1714             Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1715         1U);
1716     unsigned R = Result.getNumRows();
1717     unsigned C = Result.getNumColumns();
1718     unsigned M = A.getNumColumns();
1719 
1720     bool IsFP = Result.getElementType()->isFloatingPointTy();
1721     assert(A.isColumnMajor() == B.isColumnMajor() &&
1722            Result.isColumnMajor() == A.isColumnMajor() &&
1723            "operands must agree on matrix layout");
1724     unsigned NumComputeOps = 0;
1725 
1726     Builder.setFastMathFlags(FMF);
1727 
1728     if (A.isColumnMajor()) {
1729       // Multiply columns from the first operand with scalars from the second
1730       // operand. Then move along the K axis and accumulate the columns.  With
1731       // this the adds can be vectorized without reassociation.
1732       for (unsigned J = 0; J < C; ++J) {
1733         unsigned BlockSize = VF;
1734         // If Result is zero, we don't need to accumulate in the K==0 iteration.
1735         bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
1736 
1737         for (unsigned I = 0; I < R; I += BlockSize) {
1738           // Gradually lower the vectorization factor to cover the remainder.
1739           while (I + BlockSize > R)
1740             BlockSize /= 2;
1741 
1742           Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
1743                                : nullptr;
1744           for (unsigned K = 0; K < M; ++K) {
1745             Value *L = A.extractVector(I, K, BlockSize, Builder);
1746             Value *RH = Builder.CreateExtractElement(
1747                 B.getColumn(IsScalarMatrixTransposed ? K : J),
1748                 IsScalarMatrixTransposed ? J : K);
1749             Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
1750             Sum =
1751                 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
1752                              IsFP, Builder, FMF.allowContract(), NumComputeOps);
1753           }
1754           Result.setVector(J,
1755                            insertVector(Result.getVector(J), I, Sum, Builder));
1756         }
1757       }
1758     } else {
1759       // Multiply rows from the second operand with scalars from the first
1760       // operand. Then move along the K axis and accumulate the rows.  With this
1761       // the adds can be vectorized without reassociation.
1762       for (unsigned I = 0; I < R; ++I) {
1763         unsigned BlockSize = VF;
1764         bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1765         for (unsigned J = 0; J < C; J += BlockSize) {
1766           // Gradually lower the vectorization factor to cover the remainder.
1767           while (J + BlockSize > C)
1768             BlockSize /= 2;
1769 
1770           Value *Sum = nullptr;
1771           for (unsigned K = 0; K < M; ++K) {
1772             Value *R = B.extractVector(K, J, BlockSize, Builder);
1773             Value *LH = Builder.CreateExtractElement(
1774                 A.getVector(IsScalarMatrixTransposed ? K : I),
1775                 IsScalarMatrixTransposed ? I : K);
1776             Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1777             Sum =
1778                 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1779                              IsFP, Builder, FMF.allowContract(), NumComputeOps);
1780           }
1781           Result.setVector(I,
1782                            insertVector(Result.getVector(I), J, Sum, Builder));
1783         }
1784       }
1785     }
1786     Result.addNumComputeOps(NumComputeOps);
1787   }
1788 
1789   /// Ensure that the memory in \p Load does not alias \p Store by potentially
1790   /// copying it to a new location. Returns the new location, or the original
1791   /// one if no copy is needed.
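  ///
  /// When no-alias cannot be proven statically, a runtime check with the
  /// following CFG is emitted (a sketch of the blocks built below):
  ///
  ///   check0:      br (LoadBegin u< StoreEnd), alias_cont, no_alias
  ///   alias_cont:  br (StoreBegin u< LoadEnd), copy, no_alias
  ///   copy:        memcpy the loaded range into a fresh alloca; br no_alias
  ///   no_alias:    phi [load ptr, check0], [load ptr, alias_cont],
  ///                    [alloca, copy]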
1792   Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1793                                CallInst *MatMul) {
1794     MemoryLocation StoreLoc = MemoryLocation::get(Store);
1795     MemoryLocation LoadLoc = MemoryLocation::get(Load);
1796 
1797     // If we can statically determine noalias we're good.
1798     if (AA->isNoAlias(LoadLoc, StoreLoc))
1799       return Load->getPointerOperand();
1800 
1801     // Create code to check if the memory locations of the Load and Store
1802     // overlap and if they do, copy Load's operand to a new buffer.
1803 
1804     // First, create new blocks for the second part of the check and the copy.
1805     BasicBlock *Check0 = MatMul->getParent();
1806     // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1807     // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1808     // as we adjust Check0 and Check1's branches.
1809     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1810     for (BasicBlock *Succ : successors(Check0))
1811       DTUpdates.push_back({DT->Delete, Check0, Succ});
1812 
1813     BasicBlock *Check1 =
1814         SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1815                    nullptr, "alias_cont");
1816     BasicBlock *Copy =
1817         SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1818                    nullptr, "copy");
1819     BasicBlock *Fusion =
1820         SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1821                    nullptr, "no_alias");
1822 
1823     // Check if the loaded memory location begins before the end of the store
1824     // location. If the condition holds, they might overlap, otherwise they are
1825     // guaranteed to not overlap.
1826     IRBuilder<> Builder(MatMul);
1827     Check0->getTerminator()->eraseFromParent();
1828     Builder.SetInsertPoint(Check0);
1829     Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1830     Value *StoreBegin = Builder.CreatePtrToInt(
1831         const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1832     Value *StoreEnd = Builder.CreateAdd(
1833         StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1834         "store.end", true, true);
1835     Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1836                                               IntPtrTy, "load.begin");
1837     Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1838                          Fusion);
1839 
1840     // Check if the store begins before the end of the load location. If the
1841     // condition holds, they alias, otherwise they are guaranteed to not
1842     // overlap.
1843     Check1->getTerminator()->eraseFromParent();
1844     Builder.SetInsertPoint(Check1, Check1->begin());
1845     Value *LoadEnd = Builder.CreateAdd(
1846         LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1847         "load.end", true, true);
1848     Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1849                          Fusion);
1850 
1851     // Copy load operand to new alloca.
1852     Builder.SetInsertPoint(Copy, Copy->begin());
1853     auto *VT = cast<FixedVectorType>(Load->getType());
1854     // Use an array type for the alloca, to avoid potentially huge alignment
1855     // requirements for large vector types.
1856     auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1857     AllocaInst *Alloca =
1858         Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1859 
1860     Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
1861                          Load->getAlign(), LoadLoc.Size.getValue());
1862     Builder.SetInsertPoint(Fusion, Fusion->begin());
1863     PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
1864     PHI->addIncoming(Load->getPointerOperand(), Check0);
1865     PHI->addIncoming(Load->getPointerOperand(), Check1);
1866     PHI->addIncoming(Alloca, Copy);
1867 
1868     // Adjust DT.
1869     DTUpdates.push_back({DT->Insert, Check0, Check1});
1870     DTUpdates.push_back({DT->Insert, Check0, Fusion});
1871     DTUpdates.push_back({DT->Insert, Check1, Copy});
1872     DTUpdates.push_back({DT->Insert, Check1, Fusion});
1873     DT->applyUpdates(DTUpdates);
1874     return PHI;
1875   }
1876 
1877   bool isFusionProfitable(CallInst *MatMul) {
1878     if (ForceFusion)
1879       return true;
1880 
1881     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1882     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1883 
1884     const unsigned R = LShape.NumRows;
1885     const unsigned C = RShape.NumColumns;
1886     const unsigned M = LShape.NumColumns;
1887     auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1888 
1889     const unsigned VF = std::max<unsigned>(
1890         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
1891                 .getFixedValue() /
1892             EltType->getPrimitiveSizeInBits().getFixedValue(),
1893         1U);
1894 
1895     // Cost model for tiling
1896     //
1897     // For tiling to be beneficial, we need reuse either along the R or
1898     // the C axis.  We vectorize along the R axis so that means at least
1899     // 3 elements.
1900     // TODO: Also consider cost of copying if operands alias.
1901     if (R <= VF && C == 1)
1902       return false;
1903     // Then we need enough elements to exceed the number of vector
1904     // registers we have.  Note that this is an oversimplification since
1905     // fusing also takes some extra loads which may exceed the number of
1906     // reloads necessary.
1907     unsigned Op0Regs = (R + VF - 1) / VF * M;
1908     unsigned Op1Regs = (M + VF - 1) / VF * C;
1909     return Op0Regs + Op1Regs >
1910            TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
1911   }
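  // Worked example (illustrative numbers): with R = C = M = 8 and VF = 4,
  // Op0Regs = (8 + 3) / 4 * 8 = 16 and Op1Regs = 16, so fusion is considered
  // profitable whenever the target has fewer than 32 vector registers.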
1912 
1913   MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1914     MatrixTy Res;
1915     auto *ColumnType = FixedVectorType::get(EltType, R);
1916     for (unsigned I = 0; I < C; ++I)
1917       Res.addVector(ConstantAggregateZero::get(ColumnType));
1918     return Res;
1919   }
1920 
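  /// Emit a tiled loop nest for MatMul. The generated structure is roughly
  /// (see TileInfo for the exact loop order):
  ///   for (Col = 0; Col < RShape.NumColumns; Col += TileSize)
  ///     for (Row = 0; Row < LShape.NumRows; Row += TileSize) {
  ///       Tile = 0;
  ///       for (K = 0; K < LShape.NumColumns; K += TileSize)
  ///         Tile += A[Row.., K..] * B[K.., Col..]; // TileSize x TileSize
  ///       store Tile to the result at (Row, Col);
  ///     }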
1921   void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1922                         Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1923     auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1924 
1925     // Create the main tiling loop nest.
1926     TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1927     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1928     Instruction *InsertI = cast<Instruction>(MatMul);
1929     BasicBlock *Start = InsertI->getParent();
1930     BasicBlock *End =
1931         SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
1932     IRBuilder<> Builder(MatMul);
1933     BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1934 
1935     Type *TileVecTy =
1936         FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
1937     MatrixTy TileResult;
1938     // Insert in the inner loop header.
1939     Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1940     // Create PHI nodes for the result columns to accumulate across iterations.
1941     SmallVector<PHINode *, 4> ColumnPhis;
1942     for (unsigned I = 0; I < TileSize; I++) {
1943       auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1944       Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
1945                        TI.RowLoop.Header->getSingleSuccessor());
1946       TileResult.addVector(Phi);
1947       ColumnPhis.push_back(Phi);
1948     }
1949 
1950     // Insert in the inner loop body, which computes
1951     //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1952     Builder.SetInsertPoint(InnerBody->getTerminator());
1953     // Load tiles of the operands.
1954     MatrixTy A =
1955         loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1956                    {TileSize, TileSize}, EltType, Builder);
1957     MatrixTy B =
1958         loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1959                    {TileSize, TileSize}, EltType, Builder);
1960     emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1961                        getFastMathFlags(MatMul));
1962     // Store result after the inner loop is done.
1963     Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1964     storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1965                 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1966                 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1967 
1968     for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1969       ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1970 
1971     // Force unrolling of a few iterations of the inner loop, to make sure there
1972     // is enough work per iteration.
1973     // FIXME: The unroller should make this decision directly instead, but
1974     // currently the cost-model is not up to the task.
1975     unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1976     addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
1977                             "llvm.loop.unroll.count", InnerLoopUnrollCount);
1978   }
1979 
1980   void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1981                       StoreInst *Store,
1982                       SmallPtrSetImpl<Instruction *> &FusedInsts) {
1983     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1984            "Tiling only supported for column-major matrixes at the moment!");
1985     if (!isFusionProfitable(MatMul))
1986       return;
1987 
1988     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1989     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1990 
1991     const unsigned R = LShape.NumRows;
1992     const unsigned C = RShape.NumColumns;
1993     const unsigned M = LShape.NumColumns;
1994     auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
1995 
1996     Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1997     Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1998     Value *CPtr = Store->getPointerOperand();
1999 
2000     if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
2001       createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
2002     else {
2003       IRBuilder<> Builder(Store);
2004       for (unsigned J = 0; J < C; J += TileSize)
2005         for (unsigned I = 0; I < R; I += TileSize) {
2006           const unsigned TileR = std::min(R - I, unsigned(TileSize));
2007           const unsigned TileC = std::min(C - J, unsigned(TileSize));
2008           MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
2009 
2010           for (unsigned K = 0; K < M; K += TileSize) {
2011             const unsigned TileM = std::min(M - K, unsigned(TileSize));
2012             MatrixTy A =
2013                 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
2014                            LShape, Builder.getInt64(I), Builder.getInt64(K),
2015                            {TileR, TileM}, EltType, Builder);
2016             MatrixTy B =
2017                 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
2018                            RShape, Builder.getInt64(K), Builder.getInt64(J),
2019                            {TileM, TileC}, EltType, Builder);
2020             emitMatrixMultiply(Res, A, B, Builder, true, false,
2021                                getFastMathFlags(MatMul));
2022           }
2023           storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
2024                       Builder.getInt64(I), Builder.getInt64(J), EltType,
2025                       Builder);
2026         }
2027     }
2028 
2029     // Mark eliminated instructions as fused and remove them.
2030     FusedInsts.insert(Store);
2031     FusedInsts.insert(MatMul);
2032     eraseFromParentAndRemoveFromShapeMap(Store);
2033     eraseFromParentAndRemoveFromShapeMap(MatMul);
2034     if (LoadOp0->use_empty()) {
2035       FusedInsts.insert(LoadOp0);
2036       eraseFromParentAndRemoveFromShapeMap(LoadOp0);
2037     }
2038     if (LoadOp1 != LoadOp0 && LoadOp1->use_empty()) {
2039       FusedInsts.insert(LoadOp1);
2040       eraseFromParentAndRemoveFromShapeMap(LoadOp1);
2041     }
2042   }
2043 
2044   /// Try to lower matrix multiply chains by fusing operations.
2045   ///
2046   /// Call finalizeLowering on lowered instructions.  Instructions that are
2047   /// completely eliminated by fusion are added to \p FusedInsts.
2048   void
2049   LowerMatrixMultiplyFused(CallInst *MatMul,
2050                            SmallPtrSetImpl<Instruction *> &FusedInsts,
2051                            SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
2052     if (!FuseMatrix || !DT)
2053       return;
2054 
2055     assert(AA && LI && "Analyses should be available");
2056 
2057     Value *A = MatMul->getArgOperand(0);
2058     Value *B = MatMul->getArgOperand(1);
2059 
2060     // We can fold the transpose into the operand that is used to fetch scalars.
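    // E.g. with column-major layout, multiply(A, transpose(T)) is emitted
    // directly from A and T: scalars are fetched from T's columns instead of
    // materializing the transposed matrix first.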
2061     Value *T;
2062     if (MatrixLayout == MatrixLayoutTy::ColumnMajor
2063             ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
2064             : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
2065       IRBuilder<> Builder(MatMul);
2066       auto *EltType =
2067           cast<FixedVectorType>(MatMul->getType())->getElementType();
2068       ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2069       ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2070       const unsigned R = LShape.NumRows;
2071       const unsigned M = LShape.NumColumns;
2072       const unsigned C = RShape.NumColumns;
2073 
2074       MatrixTy MA;
2075       MatrixTy MB;
2076 
2077       Value *Transpose;
2078       if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
2079         MA = getMatrix(A, ShapeInfo(R, M), Builder);
2080         MB = getMatrix(T, ShapeInfo(C, M), Builder);
2081         Transpose = B;
2082       } else {
2083         MA = getMatrix(T, ShapeInfo(R, M), Builder);
2084         MB = getMatrix(B, ShapeInfo(C, M), Builder);
2085         Transpose = A;
2086       }
2087 
2088       // Initialize the output
2089       MatrixTy Result(R, C, EltType);
2090 
2091       emitMatrixMultiply(Result, MA, MB, Builder, false, true,
2092                          getFastMathFlags(MatMul));
2093 
2094       FusedInsts.insert(MatMul);
2095       if (Transpose->hasOneUse()) {
2096         FusedInsts.insert(cast<Instruction>(Transpose));
2097         ToRemove.push_back(cast<Instruction>(Transpose));
2098         // TODO: add a fake entry for the folded instruction so that this is
2099         // included in the expression in the remark.
2100         Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
2101       }
2102       finalizeLowering(MatMul, Result, Builder);
2103       return;
2104     }
2105 
2106     if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
2107       return;
2108 
2109     // Lower {ld, ld} -> matmul -> st chains.  No need to call finalizeLowering
2110     // since the single store user will be lowered as part of this.
2111     auto *LoadOp0 = dyn_cast<LoadInst>(A);
2112     auto *LoadOp1 = dyn_cast<LoadInst>(B);
2113     auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
2114     if (LoadOp0 && LoadOp1 && Store) {
2115       // The store address must dominate the MatMul instruction, otherwise
2116       // we create invalid IR.
2117       SetVector<Value *> WorkList;
2118       WorkList.insert(Store->getOperand(1));
2119       SmallVector<Instruction *> ToHoist;
2120       for (unsigned I = 0; I != WorkList.size(); ++I) {
2121         Value *Current = WorkList[I];
2122         auto *CurrI = dyn_cast<Instruction>(Current);
2123         if (!CurrI)
2124           continue;
2125         if (isa<PHINode>(CurrI))
2126           return;
2127         if (DT->dominates(CurrI, MatMul))
2128           continue;
2129         if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
2130           return;
2131         ToHoist.push_back(CurrI);
2132         WorkList.insert_range(CurrI->operands());
2133       }
2134 
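      // Hoist in dominance order so that defs are moved before their users.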
2135       sort(ToHoist, [this](Instruction *A, Instruction *B) {
2136         return DT->dominates(A, B);
2137       });
2138       for (Instruction *I : ToHoist)
2139         I->moveBefore(MatMul->getIterator());
2140 
2141       // Deal with lifetime.end calls that might be between Load0/Load1 and the
2142       // store. To avoid introducing loads to dead objects (i.e. after the
2143       // lifetime has been terminated by @llvm.lifetime.end), either sink them
2144       // after the store if in the same block, or remove the lifetime.end marker
2145       // otherwise. This might pessimize further optimizations by extending the
2146       // lifetime of the object until the function returns, but should be
2147       // conservatively correct.
2148       MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
2149       MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
2150       BasicBlock *StoreParent = Store->getParent();
2151       bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
2152                                    LoadOp1->getParent() == StoreParent;
2153       for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
2154         IntrinsicInst *End = LifetimeEnds[Idx];
2155         auto Inc = make_scope_exit([&Idx]() { Idx++; });
2156         // If the lifetime.end is guaranteed to be before the loads or after the
2157         // store, it won't interfere with fusion.
2158         if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
2159           continue;
2160         if (DT->dominates(Store, End))
2161           continue;
2162         // If all fusable ops are in the same block and the lifetime.end is in a
2163         // different block, it won't interfere with fusion.
2164         if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2165           continue;
2166 
2167         // If the loads don't alias the lifetime.end, it won't interfere with
2168         // fusion.
2169         MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
2170         if (!EndLoc.Ptr)
2171           continue;
2172         if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2173           continue;
2174 
2175         // If both lifetime.end and the store are in the same block, extend the
2176         // lifetime until after the store, so the new lifetime covers the loads
2177         // we introduce later.
2178         if (End->getParent() == StoreParent) {
2179           End->moveAfter(Store);
2180           continue;
2181         }
2182 
2183         // Otherwise remove the conflicting lifetime.end marker.
2184         ToRemove.push_back(End);
2185         std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
2186         LifetimeEnds.pop_back();
2187         Inc.release();
2188       }
2189 
2190       emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2191       return;
2192     }
2193   }
2194 
2195   /// Lowers llvm.matrix.multiply.
2196   MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) {
2197     auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
2198     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
2199     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2200 
2201     const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2202     const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2203     assert(Lhs.getElementType() == Rhs.getElementType() &&
2204            "Matrix multiply argument element types do not match.");
2205 
2206     const unsigned R = LShape.NumRows;
2207     const unsigned C = RShape.NumColumns;
2208     assert(LShape.NumColumns == RShape.NumRows && "inner dimensions must match");
2209 
2210     // Initialize the output
2211     MatrixTy Result(R, C, EltType);
2212     assert(Lhs.getElementType() == Result.getElementType() &&
2213            "Matrix multiply result element type does not match arguments.");
2214 
2215     emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
2216                        getFastMathFlags(MatMul));
2217     return Result;
2218   }
2219 
2220   /// Lowers llvm.matrix.transpose.
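  ///
  /// For example, a column-major 2x3 input held as 3 vectors of 2 elements is
  /// rebuilt as 2 result vectors of 3 elements, using one extractelement plus
  /// one insertelement per scalar.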
2221   MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) {
2222     MatrixTy Result;
2223     Value *InputVal = Inst->getArgOperand(0);
2224     FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
2225     ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
2226     MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2227 
2228     const unsigned NewNumVecs =
2229         InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2230     const unsigned NewNumElts =
2231         InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2232 
2233     for (unsigned I = 0; I < NewNumVecs; ++I) {
2234       // Build a single result vector. First initialize it.
2235       Value *ResultVector = PoisonValue::get(
2236           FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
2237       // Go through the old elements and insert them into the resulting vector.
2238       for (auto J : enumerate(InputMatrix.vectors())) {
2239         Value *Elt = Builder.CreateExtractElement(J.value(), I);
2240         // Row and column indices are transposed.
2241         ResultVector =
2242             Builder.CreateInsertElement(ResultVector, Elt, J.index());
2243       }
2244       Result.addVector(ResultVector);
2245     }
2246 
2247     // TODO: Improve estimate of operations needed for transposes. Currently we
2248     // just count the insertelement/extractelement instructions, but do not
2249     // account for later simplifications/combines.
2250     return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2251         .addNumExposedTransposes(1);
2252   }
2253 
2254   /// Lower load instructions.
2255   MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2256                      IRBuilder<> &Builder) {
2257     return LowerLoad(Inst, Ptr, Inst->getAlign(),
2258                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
2259                      Builder);
2260   }
2261 
2262   MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2263                       Value *Ptr, IRBuilder<> &Builder) {
2264     return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2265                       Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
2266                       Builder);
2267   }
2268 
2269   MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
2270     auto BlockIP = Inst->getParent()->getFirstInsertionPt();
2271     Builder.SetInsertPoint(BlockIP);
2272     MatrixTy PhiM = getMatrix(Inst, SI, Builder);
2273 
2274     for (auto [IncomingV, IncomingB] :
2275          llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) {
2276       // getMatrix() may insert some instructions to help with reshaping. The
2277       // safest place for those is at the top of the block after the rest of the
2278       // PHIs; even better, in the incoming block itself when possible.
2279       Builder.SetInsertPoint(BlockIP);
2280       if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
2281         if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef())
2282           Builder.SetInsertPoint(*MaybeIP);
2283 
2284       MatrixTy OpM = getMatrix(IncomingV, SI, Builder);
2285 
2286       for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
2287         PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
2288         NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
2289       }
2290     }
2291 
2292     // finalizeLowering() may also insert instructions in some cases. The safe
2293     // place for those is at the end of the initial block of PHIs.
2294     Builder.SetInsertPoint(BlockIP);
2295     return PhiM;
2296   }
2297 
2298   /// Lower binary operators.
2299   MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
2300                                IRBuilder<> &Builder) {
2301     Value *Lhs = Inst->getOperand(0);
2302     Value *Rhs = Inst->getOperand(1);
2303 
2304     MatrixTy Result;
2305     MatrixTy A = getMatrix(Lhs, SI, Builder);
2306     MatrixTy B = getMatrix(Rhs, SI, Builder);
2307     assert(A.isColumnMajor() == B.isColumnMajor() &&
2308            Result.isColumnMajor() == A.isColumnMajor() &&
2309            "operands must agree on matrix layout");
2310 
2311     Builder.setFastMathFlags(getFastMathFlags(Inst));
2312 
2313     for (auto [AV, BV] : llvm::zip_equal(A.vectors(), B.vectors()))
2314       Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), AV, BV));
2315 
2316     return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2317                                    Result.getNumVectors());
2318   }
2319 
2320   /// Lower unary operators.
2321   MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
2322                               IRBuilder<> &Builder) {
2323     Value *Op = Inst->getOperand(0);
2324 
2325     MatrixTy Result;
2326     MatrixTy M = getMatrix(Op, SI, Builder);
2327 
2328     Builder.setFastMathFlags(getFastMathFlags(Inst));
2329 
2330     // Helper to perform unary op on vectors.
2331     auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2332       switch (Inst->getOpcode()) {
2333       case Instruction::FNeg:
2334         return Builder.CreateFNeg(Op);
2335       default:
2336         llvm_unreachable("Unsupported unary operator for matrix");
2337       }
2338     };
2339 
2340     for (auto *Vector : M.vectors())
2341       Result.addVector(BuildVectorOp(Vector));
2342 
2343     return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2344                                    Result.getNumVectors());
2345   }
2346 
2347   /// Lower cast instructions.
2348   MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
2349                                 IRBuilder<> &Builder) {
2350     Value *Op = Inst->getOperand(0);
2351 
2352     MatrixTy Result;
2353     MatrixTy M = getMatrix(Op, Shape, Builder);
2354 
2355     Builder.setFastMathFlags(getFastMathFlags(Inst));
2356 
2357     auto *OrigVTy = cast<VectorType>(Inst->getType());
2358     auto *NewVTy = VectorType::get(OrigVTy->getElementType(),
2359                                    ElementCount::getFixed(M.getStride()));
2360 
2361     for (auto *Vector : M.vectors())
2362       Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy));
2363 
2364     return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2365                                    Result.getNumVectors());
2366   }
2367 
2368   /// Lower selects.
2369   MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,
2370                            IRBuilder<> &Builder) {
2371     Value *Cond = Inst->getOperand(0);
2372     Value *OpA = Inst->getOperand(1);
2373     Value *OpB = Inst->getOperand(2);
2374 
2375     MatrixTy Result;
2376     MatrixTy A = getMatrix(OpA, Shape, Builder);
2377     MatrixTy B = getMatrix(OpB, Shape, Builder);
2378 
2379     SmallVector<Value *> CondV;
2380     if (isa<FixedVectorType>(Cond->getType())) {
2381       MatrixTy C = getMatrix(Cond, Shape, Builder);
2382       llvm::copy(C.vectors(), std::back_inserter(CondV));
2383     } else {
2384       CondV.resize(A.getNumVectors());
2385       llvm::fill(CondV, Cond);
2386     }
2387 
2388     for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))
2389       Result.addVector(Builder.CreateSelect(CV, AV, BV));
2390 
2391     return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2392                                    Result.getNumVectors());
2393   }
2394 
2395   /// Helper to linearize a matrix expression tree into a string. Currently
2396   /// matrix expressions are linearized by starting at an expression leaf and
2397   /// linearizing bottom up.
2398   struct ExprLinearizer {
2399     unsigned LengthToBreak = 100;
2400     std::string Str;
2401     raw_string_ostream Stream;
2402     unsigned LineLength = 0;
2403     const DataLayout &DL;
2404 
2405     /// Mapping from instructions to matrices. It is used to identify
2406     /// matrix instructions.
2407     const MapVector<Value *, MatrixTy> &Inst2Matrix;
2408 
2409     /// Mapping from values to the leaves of all expressions that the value is
2410     /// part of.
2411     const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2412 
2413     /// Set of matrix expressions in the scope of a given DISubprogram.
2414     const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2415 
2416     /// Leaf node of the expression to linearize.
2417     Value *Leaf;
2418 
2419     /// Used to keep track of sub-expressions that get reused while linearizing
2420     /// the expression. Re-used sub-expressions are marked as (reused).
2421     SmallPtrSet<Value *, 8> ReusedExprs;
2422 
2423     ExprLinearizer(const DataLayout &DL,
2424                    const MapVector<Value *, MatrixTy> &Inst2Matrix,
2425                    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2426                    const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2427                    Value *Leaf)
2428         : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2429           ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2430 
2431     void indent(unsigned N) {
2432       LineLength += N;
2433       for (unsigned i = 0; i < N; i++)
2434         Stream << " ";
2435     }
2436 
2437     void lineBreak() {
2438       Stream << "\n";
2439       LineLength = 0;
2440     }
2441 
2442     void maybeIndent(unsigned Indent) {
2443       if (LineLength >= LengthToBreak)
2444         lineBreak();
2445 
2446       if (LineLength == 0)
2447         indent(Indent);
2448     }
2449 
2450     void write(StringRef S) {
2451       LineLength += S.size();
2452       Stream << S;
2453     }
2454 
2455     Value *getUnderlyingObjectThroughLoads(Value *V) {
2456       if (Value *Ptr = getPointerOperand(V))
2457         return getUnderlyingObjectThroughLoads(Ptr);
2458       else if (V->getType()->isPointerTy())
2459         return getUnderlyingObject(V);
2460       return V;
2461     }
2462 
2463     /// Returns true if \p V is a matrix value in the given subprogram.
2464     bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2465 
2466     /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2467     /// \p SS.
2468     void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2469       auto M = Inst2Matrix.find(V);
2470       if (M == Inst2Matrix.end())
2471         SS << "unknown";
2472       else {
2473         SS << M->second.getNumRows();
2474         SS << "x";
2475         SS << M->second.getNumColumns();
2476       }
2477     }
2478 
2479     /// Write the called function name. Handles calls to llvm.matrix.*
2480     /// specially: we write the name, followed by the dimensions of the input
2481     /// matrices, followed by the scalar type name.
2482     void writeFnName(CallInst *CI) {
2483       if (!CI->getCalledFunction())
2484         write("<no called fn>");
2485       else {
2486         StringRef Name = CI->getCalledFunction()->getName();
2487         if (!Name.starts_with("llvm.matrix")) {
2488           write(Name);
2489           return;
2490         }
2491         auto *II = cast<IntrinsicInst>(CI);
2492         write(Intrinsic::getBaseName(II->getIntrinsicID())
2493                   .drop_front(StringRef("llvm.matrix.").size()));
2494         write(".");
2495         std::string Tmp;
2496         raw_string_ostream SS(Tmp);
2497 
2498         switch (II->getIntrinsicID()) {
2499         case Intrinsic::matrix_multiply:
2500           prettyPrintMatrixType(II->getOperand(0), SS);
2501           SS << ".";
2502           prettyPrintMatrixType(II->getOperand(1), SS);
2503           SS << "." << *II->getType()->getScalarType();
2504           break;
2505         case Intrinsic::matrix_transpose:
2506           prettyPrintMatrixType(II->getOperand(0), SS);
2507           SS << "." << *II->getType()->getScalarType();
2508           break;
2509         case Intrinsic::matrix_column_major_load:
2510           prettyPrintMatrixType(II, SS);
2511           SS << "." << *II->getType()->getScalarType();
2512           break;
2513         case Intrinsic::matrix_column_major_store:
2514           prettyPrintMatrixType(II->getOperand(0), SS);
2515           SS << "." << *II->getOperand(0)->getType()->getScalarType();
2516           break;
2517         default:
2518           llvm_unreachable("Unhandled case");
2519         }
2520         write(Tmp);
2521       }
2522     }
2523 
2524     unsigned getNumShapeArgs(CallInst *CI) const {
2525       if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
2526         switch (II->getIntrinsicID()) {
2527         case Intrinsic::matrix_multiply:
2528           return 3;
2529         case Intrinsic::matrix_transpose:
2530           return 2;
2531         case Intrinsic::matrix_column_major_load:
2532         case Intrinsic::matrix_column_major_store:
2533           return 3;
2534         default:
2535           return 0;
2536         }
2537       }
2538       return 0;
2539     }
2540 
2541     /// Special printing for values: for pointers, we print whether they refer
2542     /// to a (function-)external address or a stack address; for other values
2543     /// we print either the constant or "scalar"/"matrix".
2544     void write(Value *V) {
2545       V = getUnderlyingObjectThroughLoads(V);
2546       if (V->getType()->isPointerTy()) {
2547         if (isa<AllocaInst>(V)) {
2548           Stream << "stack addr";
2549           LineLength += StringRef("stack addr").size();
2550         } else {
2551           Stream << "addr";
2552           LineLength += StringRef("addr").size();
2553         }
2554         if (!V->getName().empty()) {
2555           Stream << " %" << V->getName() << "";
2556           LineLength += V->getName().size() + 2;
2557         }
2558         return;
2559       }
2560 
2561       std::string Tmp;
2562       raw_string_ostream TmpStream(Tmp);
2563 
2564       if (auto *CI = dyn_cast<ConstantInt>(V))
2565         TmpStream << CI->getValue();
2566       else if (isa<Constant>(V))
2567         TmpStream << "constant";
2568       else {
2569         if (isMatrix(V))
2570           TmpStream << "matrix";
2571         else
2572           TmpStream << "scalar";
2573       }
2574       Tmp = std::string(StringRef(Tmp).trim());
2575       LineLength += Tmp.size();
2576       Stream << Tmp;
2577     }
2578 
2579     /// Linearize expression \p Expr starting at an indentation of \p Indent.
2580     /// Expressions that are re-used multiple times are prefixed with (reused)
2581     /// at the re-used root instruction.
2582     void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2583                        bool ParentShared) {
2584       auto *I = cast<Instruction>(Expr);
2585       maybeIndent(Indent);
2586       SmallVector<Value *, 8> Ops;
2587 
2588       // Is Expr shared with other expression leaves?
2589       bool ExprShared = false;
2590 
2591       // Deal with shared subtrees. Mark them as shared, if required.
2592       if (!ParentShared) {
2593         auto SI = Shared.find(Expr);
2594         assert(SI != Shared.end() && SI->second.count(Leaf));
2595 
2596         for (Value *S : SI->second) {
2597           if (S == Leaf)
2598             continue;
2599           DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2600           write("shared with remark at line " + std::to_string(DL.getLine()) +
2601                 " column " + std::to_string(DL.getCol()) + " (");
2602         }
2603         ExprShared = SI->second.size() > 1;
2604       }
2605 
2606       bool Reused = !ReusedExprs.insert(Expr).second;
2607       if (Reused && !ParentReused)
2608         write("(reused) ");
2609 
2610       if (auto *CI = dyn_cast<CallInst>(I)) {
2611         writeFnName(CI);
2612 
2613         Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2614       } else if (isa<BitCastInst>(Expr)) {
2615         // Special case bitcasts, which are used to materialize matrices from
2616         // non-matrix ops.
2617         write("matrix");
2618         return;
2619       } else {
2620         Ops.append(I->value_op_begin(), I->value_op_end());
2621         write(I->getOpcodeName());
2622       }
2623 
2624       write("(");
2625 
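      // Put operands on separate lines once there are more than NumOpsToBreak
      // of them. Column-major loads print only their pointer and stride
      // operands, which fit on a single line.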
2626       unsigned NumOpsToBreak = 1;
2627       if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2628         NumOpsToBreak = 2;
2629 
2630       for (Value *Op : Ops) {
2631         if (Ops.size() > NumOpsToBreak)
2632           lineBreak();
2633 
2634         maybeIndent(Indent + 1);
2635         if (isMatrix(Op))
2636           linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2637         else
2638           write(Op);
2639         if (Op != Ops.back())
2640           write(", ");
2641       }
2642 
2643       write(")");
2644     }
2645 
2646     const std::string &getResult() {
2647       return Str;
2648     }
2649   };
2650 
2651   /// Generate remarks for matrix operations in a function. To generate remarks
2652   /// for matrix expressions, the following approach is used:
2653   /// 1. Use the inlined-at debug information to group matrix operations to the
2654   ///    DISubprograms they are contained in.
2655   /// 2. Collect leaves of matrix expressions (done in
2656   ///    RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
2657   ///    mapping. Leaves are lowered matrix instructions without other matrix
2658   ///    users (like stores) in the current subprogram.
2659   /// 3. For each leaf, create a remark containing a linearized version of the
2660   ///    matrix expression. The expression is linearized by a recursive
2661   ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
2662   ///    that multiple leaves can share sub-expressions. Shared subexpressions
2663   ///    are explicitly marked as shared().
2664   struct RemarkGenerator {
2665     const MapVector<Value *, MatrixTy> &Inst2Matrix;
2666     OptimizationRemarkEmitter &ORE;
2667     Function &Func;
2668     const DataLayout &DL;
2669 
2670     RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2671                     OptimizationRemarkEmitter &ORE, Function &Func)
2672         : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2673           DL(Func.getDataLayout()) {}
2674 
2675     /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2676     /// instructions in Inst2Matrix returning void or without any users in
2677     /// \p ExprsInSubprogram. Currently that should only include stores.
2678     SmallVector<Value *, 4>
2679     getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2680       SmallVector<Value *, 4> Leaves;
2681       for (auto *Expr : ExprsInSubprogram)
2682         if (Expr->getType()->isVoidTy() ||
2683             !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2684               return ExprsInSubprogram.count(U);
2685             }))
2686           Leaves.push_back(Expr);
2687       return Leaves;
2688     }
2689 
2690     /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2691     /// to all visited expressions in \p Shared. Limit the matrix operations to
2692     /// the ones in \p ExprsInSubprogram.
2693     void collectSharedInfo(Value *Leaf, Value *V,
2694                            const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2695                            DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2696 
2697       if (!ExprsInSubprogram.count(V))
2698         return;
2699 
2700       Shared[V].insert(Leaf);
2701 
2702       for (Value *Op : cast<Instruction>(V)->operand_values())
2703         collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2704     }
2705 
2706     /// Calculate the exclusive and shared op counts for the expression
2707     /// starting at \p Root. Expressions used multiple times are counted once.
2708     /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
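    /// Ops of a sub-expression reachable from a single leaf only are
    /// attributed to the exclusive count; ops of a sub-expression shared
    /// between multiple leaves land in the shared count.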
2709     std::pair<OpInfoTy, OpInfoTy>
2710     sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2711                const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2712                DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2713       if (!ExprsInSubprogram.count(Root))
2714         return {};
2715 
2716       // Already counted this expression. Stop.
2717       if (!ReusedExprs.insert(Root).second)
2718         return {};
2719 
2720       OpInfoTy SharedCount;
2721       OpInfoTy Count;
2722 
2723       auto I = Shared.find(Root);
2724       auto CM = Inst2Matrix.find(Root);
2725       if (I->second.size() == 1)
2726         Count = CM->second.getOpInfo();
2727       else
2728         SharedCount = CM->second.getOpInfo();
2729 
2730       for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2731         auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2732         Count += C.first;
2733         SharedCount += C.second;
2734       }
2735       return {Count, SharedCount};
2736     }
2737 
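    /// Emit an optimization remark for each expression leaf. A remark reads,
    /// e.g. (with hypothetical shapes, operands and counts):
    ///   remark: test.c:35:7: Lowered with 6 stores, 6 loads, 24 compute ops,
    ///   0 exposed transposes
    ///   column.major.store.4x2.double(
    ///    multiply.4x2.2x2.double(
    ///     column.major.load.4x2.double(addr %A, 5),
    ///     column.major.load.2x2.double(addr %B, 3)),
    ///    addr %C,
    ///    10)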
2738     void emitRemarks() {
2739       if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2740         return;
2741 
2742       // Map matrix operations to their containing subprograms, by traversing
2743       // the inlinedAt chain. If the function does not have a DISubprogram, we
2744       // only map them to the containing function.
2745       MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2746       for (const auto &KV : Inst2Matrix) {
2747         if (Func.getSubprogram()) {
2748           auto *I = cast<Instruction>(KV.first);
2749           DILocation *Context = I->getDebugLoc();
2750           while (Context) {
2751             Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2752                 KV.first);
2753             Context = DebugLoc(Context).getInlinedAt();
2754           }
2755         } else {
2756           Subprog2Exprs[nullptr].push_back(KV.first);
2757         }
2758       }
2759       for (auto &KV : Subprog2Exprs) {
2760         SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2761                                                       KV.second.end());
2762         auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2763 
2764         DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2765         for (Value *Leaf : Leaves)
2766           collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2767 
2768         // Generate remarks for each leaf.
2769         for (auto *L : Leaves) {
2770 
2771           DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
2772           DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2773           while (Context) {
2774             if (getSubprogram(Context->getScope()) == KV.first) {
2775               Loc = Context;
2776               break;
2777             }
2778             Context = DebugLoc(Context).getInlinedAt();
2779           }
2780 
2781           SmallPtrSet<Value *, 8> ReusedExprs;
2782           OpInfoTy Counts, SharedCounts;
2783           std::tie(Counts, SharedCounts) =
2784               sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2785 
2786           OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2787                                  cast<Instruction>(L)->getParent());
2788 
2789           Rem << "Lowered with ";
2790           Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2791               << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2792               << ore::NV("NumComputeOps", Counts.NumComputeOps)
2793               << " compute ops, "
2794               << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2795               << " exposed transposes";
2796 
2797           if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2798               SharedCounts.NumComputeOps > 0) {
2799             Rem << ",\nadditionally "
2800                 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2801                 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2802                 << ore::NV("NumComputeOps", SharedCounts.NumComputeOps)
2803                 << " compute ops"
2804                 << " are shared with other expressions";
2805           }
2806 
2807           Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2808           ORE.emit(Rem);
2809         }
2810       }
2811     }
2812 
2813     std::string
2814     linearize(Value *L,
2815               const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2816               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2817               const DataLayout &DL) {
2818       ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2819       Lin.linearizeExpr(L, 0, false, false);
2820       return Lin.getResult();
2821     }
2822   };
2823 };
2824 } // namespace
2825 
2826 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2827                                                  FunctionAnalysisManager &AM) {
2828   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
2829 
2830   LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
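  // Visit() lowers all matrix intrinsics in F and reports whether the function
  // was changed. If it was, only the analyses kept up to date by the rewrite
  // (loop info and the dominator tree, unless running in minimal mode) are
  // preserved.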
2831   if (LMT.Visit()) {
2832     PreservedAnalyses PA;
2833     if (!Minimal) {
2834       PA.preserve<LoopAnalysis>();
2835       PA.preserve<DominatorTreeAnalysis>();
2836     }
2837     return PA;
2838   }
2839   return PreservedAnalyses::all();
2840 }
2841 
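// Print the pass options for -passes pipeline strings; e.g. the minimal
// variant round-trips as "lower-matrix-intrinsics<minimal>".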
2842 void LowerMatrixIntrinsicsPass::printPipeline(
2843     raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2844   static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2845       OS, MapClassName2PassName);
2846   OS << '<';
2847   if (Minimal)
2848     OS << "minimal";
2849   OS << '>';
2850 }
2851