xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (revision 99282790b7d01ec3c4072621d46a0d7302517ad4)
1 //===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 //  * Implement multiply & add fusion
13 //  * Add remark, summarizing the available matrix optimization opportunities.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
18 #include "llvm/ADT/GraphTraits.h"
19 #include "llvm/ADT/PostOrderIterator.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/Analysis/VectorUtils.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/DataLayout.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/PatternMatch.h"
30 #include "llvm/InitializePasses.h"
31 #include "llvm/Pass.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Transforms/Scalar.h"
34 
35 using namespace llvm;
36 using namespace PatternMatch;
37 
38 #define DEBUG_TYPE "lower-matrix-intrinsics"
39 
40 static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
41                                             cl::init(true));
42 
43 static cl::opt<bool> AllowContractEnabled(
44     "matrix-allow-contract", cl::init(false), cl::Hidden,
45     cl::desc("Allow the use of FMAs if available and profitable. This may "
46              "result in different results, due to less rounding error."));
47 
48 namespace {
49 
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the starts of two consecutive columns.
53 // \p Stride must be >= \p NumRows.
54 //
55 // Consider a 4x4 matrix like below
56 //
57 //      0       1      2      3
58 // 0   v_0_0  v_0_1  v_0_2  v_0_3
59 // 1   v_1_0  v_1_1  v_1_2  v_1_3
60 // 2   v_2_0  v_2_1  v_2_2  v_2_3
61 // 3   v_3_0  v_3_1  v_3_2  v_3_3
62 
63 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
64 // we need a pointer to the first element of the submatrix as base pointer.
65 // Then we can use computeColumnAddr to compute the addresses for the columns
66 // of the sub-matrix.
67 //
68 // Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
69 //           -> just returns Base
70 // Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
71 //           -> returns Base + (1 * 4)
72 // Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
73 //           -> returns Base + (2 * 4)
74 //
75 // The graphic below illustrates the number of elements in a column (marked
76 // with |) and the number of skipped elements (marked with }).
77 //
78 //         v_0_0  v_0_1 {v_0_2 {v_0_3
79 //                Base   Col 1  Col 2
80 //                  |     |      |
81 //         v_1_0 |v_1_1 |v_1_2 |v_1_3
82 //         v_2_0 |v_2_1 |v_2_2 |v_2_3
83 //         v_3_0 {v_3_1 {v_3_2  v_3_3
84 //
85 Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
86                          unsigned NumRows, Type *EltType,
87                          IRBuilder<> &Builder) {
88 
89   assert((!isa<ConstantInt>(Stride) ||
90           cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
91          "Stride must be >= the number of rows.");
92   unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
93 
94   // Compute the start of the column with index Col as Col * Stride.
95   Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");
96 
97   // Get pointer to the start of the selected column. Skip GEP creation,
98   // if we select column 0.
99   if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
100     ColumnStart = BasePtr;
101   else
102     ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");
103 
104   // Cast elementwise column start pointer to a pointer to a column
105   // (EltType x NumRows)*.
106   Type *ColumnType = VectorType::get(EltType, NumRows);
107   Type *ColumnPtrType = PointerType::get(ColumnType, AS);
108   return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
109 }
110 
111 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
112 ///
113 /// Currently, the lowering for each matrix intrinsic is done as follows:
114 /// 1. Propagate the shape information from intrinsics to connected
115 /// instructions.
116 /// 2. Lower instructions with shape information.
117 ///  2.1. Get column vectors for each argument. If we already lowered the
118 ///       definition of an argument, use the produced column vectors directly.
119 ///       If not, split the operand vector containing an embedded matrix into
120 ///       a set of column vectors,
121 ///  2.2. Lower the instruction in terms of columnwise operations, which yields
122 ///       a set of column vectors containing result matrix. Note that we lower
123 ///       all instructions that have shape information. Besides the intrinsics,
124 ///       this includes stores for example.
125 ///  2.3. Update uses of the lowered instruction. If we have shape information
126 ///       for a user, there is nothing to do, as we will look up the result
127 ///       column matrix when lowering the user. For other uses, we embed the
128 ///       result matrix in a flat vector and update the use.
129 ///  2.4. Cache the result column matrix for the instruction we lowered
130 /// 3. After we lowered all instructions in a function, remove the now
131 ///    obsolete instructions.
132 ///
133 class LowerMatrixIntrinsics {
  /// The function whose matrix intrinsics are being lowered.
  Function &Func;
  /// Data layout of Func's module; supplies the element ABI alignment used
  /// for the column loads and stores.
  const DataLayout &DL;
  /// Target information; LowerMultiply queries the vector register width.
  const TargetTransformInfo &TTI;
137 
138   /// Wrapper class representing a matrix as a set of column vectors.
139   /// All column vectors must have the same vector type.
140   class ColumnMatrixTy {
141     SmallVector<Value *, 16> Columns;
142 
143   public:
144     ColumnMatrixTy() : Columns() {}
145     ColumnMatrixTy(ArrayRef<Value *> Cols)
146         : Columns(Cols.begin(), Cols.end()) {}
147 
148     Value *getColumn(unsigned i) const { return Columns[i]; }
149 
150     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
151 
152     size_t getNumColumns() const { return Columns.size(); }
153     size_t getNumRows() const {
154       assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
155       return cast<VectorType>(Columns[0]->getType())->getNumElements();
156     }
157 
158     const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
159 
160     SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
161 
162     void addColumn(Value *V) { Columns.push_back(V); }
163 
164     iterator_range<SmallVector<Value *, 8>::iterator> columns() {
165       return make_range(Columns.begin(), Columns.end());
166     }
167 
168     /// Embed the columns of the matrix into a flat vector by concatenating
169     /// them.
170     Value *embedInVector(IRBuilder<> &Builder) const {
171       return Columns.size() == 1 ? Columns[0]
172                                  : concatenateVectors(Builder, Columns);
173     }
174   };
175 
176   struct ShapeInfo {
177     unsigned NumRows;
178     unsigned NumColumns;
179 
180     ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
181         : NumRows(NumRows), NumColumns(NumColumns) {}
182 
183     ShapeInfo(Value *NumRows, Value *NumColumns)
184         : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
185           NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}
186 
187     bool operator==(const ShapeInfo &other) {
188       return NumRows == other.NumRows && NumColumns == other.NumColumns;
189     }
190     bool operator!=(const ShapeInfo &other) { return !(*this == other); }
191 
192     /// Returns true if shape-information is defined, meaning both dimensions
193     /// are != 0.
194     operator bool() const {
195       assert(NumRows == 0 || NumColumns != 0);
196       return NumRows != 0;
197     }
198   };
199 
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics, which produce no
  /// value: for those, an entry indicates that the instruction should be
  /// lowered using shape information as well. Entries are added by
  /// setShapeInfo during shape propagation and are never overwritten.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// Instructions to erase once lowering is finished. A lowered instruction
  /// cannot be erased right away: users that have shape information keep
  /// using it until they are lowered themselves. Visit() erases the entries in
  /// reverse order after all blocks were processed.
  SmallVector<Instruction *, 16> ToRemove;

  /// Cache from lowered instructions to the column matrix they produced.
  /// getMatrix consults it so lowered users can reuse the columns directly.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;
215 
216 public:
217   LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
218       : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}
219 
220   /// Return the set of column vectors that a matrix value is lowered to.
221   ///
222   /// If we lowered \p MatrixVal, just return the cache result column matrix.
223   /// Otherwie split the flat vector \p MatrixVal containing a matrix with
224   /// shape \p SI into column vectors.
225   ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
226                            IRBuilder<> Builder) {
227     VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
228     assert(VType && "MatrixVal must be a vector type");
229     assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
230            "The vector size must match the number of matrix elements");
231 
232     // Check if we lowered MatrixVal using shape information. In that case,
233     // return the existing column matrix, if it matches the requested shape
234     // information. If there is a mis-match, embed the result in a flat
235     // vector and split it later.
236     auto Found = Inst2ColumnMatrix.find(MatrixVal);
237     if (Found != Inst2ColumnMatrix.end()) {
238       ColumnMatrixTy &M = Found->second;
239       // Return the found matrix, if its shape matches the requested shape
240       // information
241       if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
242         return M;
243 
244       MatrixVal = M.embedInVector(Builder);
245     }
246 
247     // Otherwise split MatrixVal.
248     SmallVector<Value *, 16> SplitVecs;
249     Value *Undef = UndefValue::get(VType);
250     for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
251          MaskStart += SI.NumRows) {
252       Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
253       Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
254       SplitVecs.push_back(V);
255     }
256 
257     return {SplitVecs};
258   }
259 
260   /// If \p V already has a known shape return false.  Otherwise set the shape
261   /// for instructions that support it.
262   bool setShapeInfo(Value *V, ShapeInfo Shape) {
263     assert(Shape && "Shape not set");
264     if (isa<UndefValue>(V) || !supportsShapeInfo(V))
265       return false;
266 
267     auto SIter = ShapeMap.find(V);
268     if (SIter != ShapeMap.end()) {
269       LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
270                         << SIter->second.NumRows << " "
271                         << SIter->second.NumColumns << " for " << *V << "\n");
272       return false;
273     }
274 
275     ShapeMap.insert({V, Shape});
276     LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
277                       << " for " << *V << "\n");
278     return true;
279   }
280 
281   bool isUniformShape(Value *V) {
282     Instruction *I = dyn_cast<Instruction>(V);
283     if (!I)
284       return true;
285 
286     switch (I->getOpcode()) {
287     case Instruction::FAdd:
288     case Instruction::FSub:
289     case Instruction::FMul: // Scalar multiply.
290     case Instruction::Add:
291     case Instruction::Mul:
292     case Instruction::Sub:
293       return true;
294     default:
295       return false;
296     }
297   }
298 
299   /// Returns true if shape information can be used for \p V. The supported
300   /// instructions must match the instructions that can be lowered by this pass.
301   bool supportsShapeInfo(Value *V) {
302     Instruction *Inst = dyn_cast<Instruction>(V);
303     if (!Inst)
304       return false;
305 
306     IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
307     if (II)
308       switch (II->getIntrinsicID()) {
309       case Intrinsic::matrix_multiply:
310       case Intrinsic::matrix_transpose:
311       case Intrinsic::matrix_columnwise_load:
312       case Intrinsic::matrix_columnwise_store:
313         return true;
314       default:
315         return false;
316       }
317     return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
318   }
319 
320   /// Propagate the shape information of instructions to their users.
321   /// The work list contains instructions for which we can compute the shape,
322   /// either based on the information provided by matrix intrinsics or known
323   /// shapes of operands.
324   SmallVector<Instruction *, 32>
325   propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
326     SmallVector<Instruction *, 32> NewWorkList;
327     // Pop an element for which we guaranteed to have at least one of the
328     // operand shapes.  Add the shape for this and then add users to the work
329     // list.
330     LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
331     while (!WorkList.empty()) {
332       Instruction *Inst = WorkList.back();
333       WorkList.pop_back();
334 
335       // New entry, set the value and insert operands
336       bool Propagate = false;
337 
338       Value *MatrixA;
339       Value *MatrixB;
340       Value *M;
341       Value *N;
342       Value *K;
343       if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
344                           m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
345                           m_Value(N), m_Value(K)))) {
346         Propagate = setShapeInfo(Inst, {M, K});
347       } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
348                                  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
349         // Flip dimensions.
350         Propagate = setShapeInfo(Inst, {N, M});
351       } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
352                                  m_Value(MatrixA), m_Value(), m_Value(),
353                                  m_Value(M), m_Value(N)))) {
354         Propagate = setShapeInfo(Inst, {N, M});
355       } else if (match(Inst,
356                        m_Intrinsic<Intrinsic::matrix_columnwise_load>(
357                            m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
358         Propagate = setShapeInfo(Inst, {M, N});
359       } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
360         auto OpShape = ShapeMap.find(MatrixA);
361         if (OpShape != ShapeMap.end())
362           setShapeInfo(Inst, OpShape->second);
363         continue;
364       } else if (isUniformShape(Inst)) {
365         // Find the first operand that has a known shape and use that.
366         for (auto &Op : Inst->operands()) {
367           auto OpShape = ShapeMap.find(Op.get());
368           if (OpShape != ShapeMap.end()) {
369             Propagate |= setShapeInfo(Inst, OpShape->second);
370             break;
371           }
372         }
373       }
374 
375       if (Propagate) {
376         NewWorkList.push_back(Inst);
377         for (auto *User : Inst->users())
378           if (ShapeMap.count(User) == 0)
379             WorkList.push_back(cast<Instruction>(User));
380       }
381     }
382 
383     return NewWorkList;
384   }
385 
386   /// Propagate the shape to operands of instructions with shape information.
387   /// \p Worklist contains the instruction for which we already know the shape.
388   SmallVector<Instruction *, 32>
389   propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
390     SmallVector<Instruction *, 32> NewWorkList;
391 
392     auto pushInstruction = [](Value *V,
393                               SmallVectorImpl<Instruction *> &WorkList) {
394       Instruction *I = dyn_cast<Instruction>(V);
395       if (I)
396         WorkList.push_back(I);
397     };
398     // Pop an element with known shape.  Traverse the operands, if their shape
399     // derives from the result shape and is unknown, add it and add them to the
400     // worklist.
401     LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
402     while (!WorkList.empty()) {
403       Value *V = WorkList.back();
404       WorkList.pop_back();
405 
406       size_t BeforeProcessingV = WorkList.size();
407       if (!isa<Instruction>(V))
408         continue;
409 
410       Value *MatrixA;
411       Value *MatrixB;
412       Value *M;
413       Value *N;
414       Value *K;
415       if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
416                        m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
417                        m_Value(N), m_Value(K)))) {
418         if (setShapeInfo(MatrixA, {M, N}))
419           pushInstruction(MatrixA, WorkList);
420 
421         if (setShapeInfo(MatrixB, {N, K}))
422           pushInstruction(MatrixB, WorkList);
423 
424       } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
425                               m_Value(MatrixA), m_Value(M), m_Value(N)))) {
426         // Flip dimensions.
427         if (setShapeInfo(MatrixA, {M, N}))
428           pushInstruction(MatrixA, WorkList);
429       } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
430                               m_Value(MatrixA), m_Value(), m_Value(),
431                               m_Value(M), m_Value(N)))) {
432         if (setShapeInfo(MatrixA, {M, N})) {
433           pushInstruction(MatrixA, WorkList);
434         }
435       } else if (isa<LoadInst>(V) ||
436                  match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
437         // Nothing to do, no matrix input.
438       } else if (isa<StoreInst>(V)) {
439         // Nothing to do.  We forward-propagated to this so we would just
440         // backward propagate to an instruction with an already known shape.
441       } else if (isUniformShape(V)) {
442         // Propagate to all operands.
443         ShapeInfo Shape = ShapeMap[V];
444         for (Use &U : cast<Instruction>(V)->operands()) {
445           if (setShapeInfo(U.get(), Shape))
446             pushInstruction(U.get(), WorkList);
447         }
448       }
449       // After we discovered new shape info for new instructions in the
450       // worklist, we use their users as seeds for the next round of forward
451       // propagation.
452       for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
453         for (User *U : WorkList[I]->users())
454           if (isa<Instruction>(U) && V != U)
455             NewWorkList.push_back(cast<Instruction>(U));
456     }
457     return NewWorkList;
458   }
459 
460   bool Visit() {
461     if (EnableShapePropagation) {
462       SmallVector<Instruction *, 32> WorkList;
463 
464       // Initially only the shape of matrix intrinsics is known.
465       // Initialize the work list with ops carrying shape information.
466       for (BasicBlock &BB : Func)
467         for (Instruction &Inst : BB) {
468           IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
469           if (!II)
470             continue;
471 
472           switch (II->getIntrinsicID()) {
473           case Intrinsic::matrix_multiply:
474           case Intrinsic::matrix_transpose:
475           case Intrinsic::matrix_columnwise_load:
476           case Intrinsic::matrix_columnwise_store:
477             WorkList.push_back(&Inst);
478             break;
479           default:
480             break;
481           }
482         }
483       // Propagate shapes until nothing changes any longer.
484       while (!WorkList.empty()) {
485         WorkList = propagateShapeForward(WorkList);
486         WorkList = propagateShapeBackward(WorkList);
487       }
488     }
489 
490     ReversePostOrderTraversal<Function *> RPOT(&Func);
491     bool Changed = false;
492     for (auto *BB : RPOT) {
493       for (Instruction &Inst : make_early_inc_range(*BB)) {
494         IRBuilder<> Builder(&Inst);
495 
496         if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
497           Changed |= VisitCallInst(CInst);
498 
499         Value *Op1;
500         Value *Op2;
501         if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
502           Changed |= VisitBinaryOperator(BinOp);
503         if (match(&Inst, m_Load(m_Value(Op1))))
504           Changed |= VisitLoad(&Inst, Op1, Builder);
505         else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
506           Changed |= VisitStore(&Inst, Op1, Op2, Builder);
507       }
508     }
509 
510     for (Instruction *Inst : reverse(ToRemove))
511       Inst->eraseFromParent();
512 
513     return Changed;
514   }
515 
516   LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
517                              IRBuilder<> Builder) {
518     unsigned Align = DL.getABITypeAlignment(EltType);
519     return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
520   }
521 
522   StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
523                                Type *EltType, IRBuilder<> Builder) {
524     unsigned Align = DL.getABITypeAlignment(EltType);
525     return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
526   }
527 
528 
529   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
530   Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
531     unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
532     Type *EltPtrType = PointerType::get(EltType, AS);
533     return Builder.CreatePointerCast(BasePtr, EltPtrType);
534   }
535 
536   /// Replace intrinsic calls
537   bool VisitCallInst(CallInst *Inst) {
538     if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
539       return false;
540 
541     switch (Inst->getCalledFunction()->getIntrinsicID()) {
542     case Intrinsic::matrix_multiply:
543       LowerMultiply(Inst);
544       break;
545     case Intrinsic::matrix_transpose:
546       LowerTranspose(Inst);
547       break;
548     case Intrinsic::matrix_columnwise_load:
549       LowerColumnwiseLoad(Inst);
550       break;
551     case Intrinsic::matrix_columnwise_store:
552       LowerColumnwiseStore(Inst);
553       break;
554     default:
555       return false;
556     }
557     return true;
558   }
559 
560   void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
561                  ShapeInfo Shape) {
562     IRBuilder<> Builder(Inst);
563     auto VType = cast<VectorType>(Inst->getType());
564     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
565     ColumnMatrixTy Result;
566     // Distance between start of one column and the start of the next
567     for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
568       Value *GEP =
569           computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
570                             VType->getElementType(), Builder);
571       Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
572       Result.addColumn(Column);
573     }
574 
575     finalizeLowering(Inst, Result, Builder);
576   }
577 
578   /// Lowers llvm.matrix.columnwise.load.
579   ///
580   /// The intrinsic loads a matrix from memory using a stride between columns.
581   void LowerColumnwiseLoad(CallInst *Inst) {
582     Value *Ptr = Inst->getArgOperand(0);
583     Value *Stride = Inst->getArgOperand(1);
584     LowerLoad(Inst, Ptr, Stride,
585               {Inst->getArgOperand(2), Inst->getArgOperand(3)});
586   }
587 
588   void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
589                   ShapeInfo Shape) {
590     IRBuilder<> Builder(Inst);
591     auto VType = cast<VectorType>(Matrix->getType());
592     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
593     auto LM = getMatrix(Matrix, Shape, Builder);
594     for (auto C : enumerate(LM.columns())) {
595       Value *GEP =
596           computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
597                             Shape.NumRows, VType->getElementType(), Builder);
598       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
599     }
600 
601     ToRemove.push_back(Inst);
602   }
603 
604   /// Lowers llvm.matrix.columnwise.store.
605   ///
606   /// The intrinsic store a matrix back memory using a stride between columns.
607   void LowerColumnwiseStore(CallInst *Inst) {
608     Value *Matrix = Inst->getArgOperand(0);
609     Value *Ptr = Inst->getArgOperand(1);
610     Value *Stride = Inst->getArgOperand(2);
611     LowerStore(Inst, Matrix, Ptr, Stride,
612                {Inst->getArgOperand(3), Inst->getArgOperand(4)});
613   }
614 
615   /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
616   /// the matrix \p LM represented as a vector of column vectors.
617   Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
618                        unsigned NumElts, IRBuilder<> Builder) {
619     Value *Col = LM.getColumn(J);
620     Value *Undef = UndefValue::get(Col->getType());
621     Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
622     return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
623   }
624 
625   // Set elements I..I+NumElts-1 to Block
626   Value *insertVector(Value *Col, unsigned I, Value *Block,
627                       IRBuilder<> Builder) {
628 
629     // First, bring Block to the same size as Col
630     unsigned BlockNumElts =
631         cast<VectorType>(Block->getType())->getNumElements();
632     unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
633     assert(NumElts >= BlockNumElts && "Too few elements for current block");
634 
635     Value *ExtendMask =
636         createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
637     Value *Undef = UndefValue::get(Block->getType());
638     Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
639 
640     // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
641     // 8, 4, 5, 6
642     SmallVector<Constant *, 16> Mask;
643     unsigned i;
644     for (i = 0; i < I; i++)
645       Mask.push_back(Builder.getInt32(i));
646 
647     unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
648     for (; i < I + BlockNumElts; i++)
649       Mask.push_back(Builder.getInt32(i - I + VecNumElts));
650 
651     for (; i < VecNumElts; i++)
652       Mask.push_back(Builder.getInt32(i));
653 
654     Value *MaskVal = ConstantVector::get(Mask);
655 
656     return Builder.CreateShuffleVector(Col, Block, MaskVal);
657   }
658 
659   Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
660                       IRBuilder<> &Builder, bool AllowContraction) {
661 
662     if (!Sum)
663       return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
664 
665     if (UseFPOp) {
666       if (AllowContraction) {
667         // Use fmuladd for floating point operations and let the backend decide
668         // if that's profitable.
669         Value *FMulAdd = Intrinsic::getDeclaration(
670             Func.getParent(), Intrinsic::fmuladd, A->getType());
671         return Builder.CreateCall(FMulAdd, {A, B, Sum});
672       }
673       Value *Mul = Builder.CreateFMul(A, B);
674       return Builder.CreateFAdd(Sum, Mul);
675     }
676 
677     Value *Mul = Builder.CreateMul(A, B);
678     return Builder.CreateAdd(Sum, Mul);
679   }
680 
681   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
682   /// users with shape information, there's nothing to do: the will use the
683   /// cached value when they are lowered. For other users, \p Matrix is
684   /// flattened and the uses are updated to use it. Also marks \p Inst for
685   /// deletion.
686   void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
687                         IRBuilder<> &Builder) {
688     Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
689 
690     ToRemove.push_back(Inst);
691     Value *Flattened = nullptr;
692     for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
693       Use &U = *I++;
694       if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
695         if (!Flattened)
696           Flattened = Matrix.embedInVector(Builder);
697         U.set(Flattened);
698       }
699     }
700   }
701 
702   /// Lowers llvm.matrix.multiply.
703   void LowerMultiply(CallInst *MatMul) {
704     IRBuilder<> Builder(MatMul);
705     auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
706     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
707     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
708 
709     const ColumnMatrixTy &Lhs =
710         getMatrix(MatMul->getArgOperand(0), LShape, Builder);
711     const ColumnMatrixTy &Rhs =
712         getMatrix(MatMul->getArgOperand(1), RShape, Builder);
713 
714     const unsigned R = LShape.NumRows;
715     const unsigned M = LShape.NumColumns;
716     const unsigned C = RShape.NumColumns;
717     assert(M == RShape.NumRows);
718 
719     // Initialize the output
720     ColumnMatrixTy Result;
721     for (unsigned J = 0; J < C; ++J)
722       Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
723 
724     const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
725                                      EltType->getPrimitiveSizeInBits(),
726                                  uint64_t(1));
727 
728     bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
729                                                   MatMul->hasAllowContract());
730     // Multiply columns from the first operand with scalars from the second
731     // operand.  Then move along the K axes and accumulate the columns.  With
732     // this the adds can be vectorized without reassociation.
733     for (unsigned J = 0; J < C; ++J) {
734       unsigned BlockSize = VF;
735       for (unsigned I = 0; I < R; I += BlockSize) {
736         // Gradually lower the vectorization factor to cover the remainder.
737         while (I + BlockSize > R)
738           BlockSize /= 2;
739 
740         Value *Sum = nullptr;
741         for (unsigned K = 0; K < M; ++K) {
742           Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
743           Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
744           Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
745           Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
746                              Builder, AllowContract);
747         }
748         Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
749       }
750     }
751     finalizeLowering(MatMul, Result, Builder);
752   }
753 
754   /// Lowers llvm.matrix.transpose.
755   void LowerTranspose(CallInst *Inst) {
756     ColumnMatrixTy Result;
757     IRBuilder<> Builder(Inst);
758     Value *InputVal = Inst->getArgOperand(0);
759     VectorType *VectorTy = cast<VectorType>(InputVal->getType());
760     ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
761     ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
762 
763     for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
764       // Build a single column vector for this row. First initialize it.
765       Value *ResultColumn = UndefValue::get(
766           VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
767 
768       // Go through the elements of this row and insert it into the resulting
769       // column vector.
770       for (auto C : enumerate(InputMatrix.columns())) {
771         Value *Elt = Builder.CreateExtractElement(C.value(), Row);
772         // We insert at index Column since that is the row index after the
773         // transpose.
774         ResultColumn =
775             Builder.CreateInsertElement(ResultColumn, Elt, C.index());
776       }
777       Result.addColumn(ResultColumn);
778     }
779 
780     finalizeLowering(Inst, Result, Builder);
781   }
782 
783   /// Lower load instructions, if shape information is available.
784   bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
785     auto I = ShapeMap.find(Inst);
786     if (I == ShapeMap.end())
787       return false;
788 
789     LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
790     return true;
791   }
792 
793   bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
794                   IRBuilder<> &Builder) {
795     auto I = ShapeMap.find(StoredVal);
796     if (I == ShapeMap.end())
797       return false;
798 
799     LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
800     return true;
801   }
802 
803   /// Lower binary operators, if shape information is available.
804   bool VisitBinaryOperator(BinaryOperator *Inst) {
805     auto I = ShapeMap.find(Inst);
806     if (I == ShapeMap.end())
807       return false;
808 
809     Value *Lhs = Inst->getOperand(0);
810     Value *Rhs = Inst->getOperand(1);
811 
812     IRBuilder<> Builder(Inst);
813     ShapeInfo &Shape = I->second;
814 
815     ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
816     ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
817 
818     // Add each column and store the result back into the opmapping
819     ColumnMatrixTy Result;
820     auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
821       switch (Inst->getOpcode()) {
822       case Instruction::Add:
823         return Builder.CreateAdd(LHS, RHS);
824       case Instruction::Mul:
825         return Builder.CreateMul(LHS, RHS);
826       case Instruction::Sub:
827         return Builder.CreateSub(LHS, RHS);
828       case Instruction::FAdd:
829         return Builder.CreateFAdd(LHS, RHS);
830       case Instruction::FMul:
831         return Builder.CreateFMul(LHS, RHS);
832       case Instruction::FSub:
833         return Builder.CreateFSub(LHS, RHS);
834       default:
835         llvm_unreachable("Unsupported binary operator for matrix");
836       }
837     };
838     for (unsigned C = 0; C < Shape.NumColumns; ++C)
839       Result.addColumn(
840           BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
841 
842     finalizeLowering(Inst, Result, Builder);
843     return true;
844   }
845 };
846 } // namespace
847 
848 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
849                                                  FunctionAnalysisManager &AM) {
850   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
851   LowerMatrixIntrinsics LMT(F, TTI);
852   if (LMT.Visit()) {
853     PreservedAnalyses PA;
854     PA.preserveSet<CFGAnalyses>();
855     return PA;
856   }
857   return PreservedAnalyses::all();
858 }
859 
860 namespace {
861 
862 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
863 public:
864   static char ID;
865 
866   LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
867     initializeLowerMatrixIntrinsicsLegacyPassPass(
868         *PassRegistry::getPassRegistry());
869   }
870 
871   bool runOnFunction(Function &F) override {
872     auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
873     LowerMatrixIntrinsics LMT(F, *TTI);
874     bool C = LMT.Visit();
875     return C;
876   }
877 
878   void getAnalysisUsage(AnalysisUsage &AU) const override {
879     AU.addRequired<TargetTransformInfoWrapperPass>();
880     AU.setPreservesCFG();
881   }
882 };
883 } // namespace
884 
885 static const char pass_name[] = "Lower the matrix intrinsics";
886 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
887 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
888                       false, false)
889 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
890                     false, false)
891 
892 Pass *llvm::createLowerMatrixIntrinsicsPass() {
893   return new LowerMatrixIntrinsicsLegacyPass();
894 }
895