//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Add remark, summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
                                            cl::init(true));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows),
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//        0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |      |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) &&
      cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}
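
// Illustrative only: for a column-major 4 x 4 matrix of doubles stored at
// %base with stride 4, requesting column 2 of a 2-row sub-matrix emits IR
// roughly like the following (constant operands are folded by IRBuilder;
// names and types are hypothetical):
//
//   %col.gep  = getelementptr double, double* %base, i32 8
//   %col.cast = bitcast double* %col.gep to <2 x double>*
//
// The exact instructions depend on the index and stride types passed in.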

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }
  };
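
  // For illustration: the matrix values this pass operates on are flat
  // vectors in column-major order. A 2 x 3 matrix of floats arrives as a
  // single <6 x float> value and is represented internally as a
  // ColumnMatrixTy with three <2 x float> column vectors; embedInVector
  // concatenates them back into one <6 x float>.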

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// users need to be removed after we have finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
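
  // For example (illustrative): splitting a <6 x double> value with shape
  // 2 x 3 creates three shufflevector instructions with the sequential masks
  // <0, 1>, <2, 3> and <4, 5>, one per column.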

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
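
  // For illustration: an elementwise op such as
  //   %sum = fadd <8 x double> %a, %b
  // is "uniform shape"; whatever shape is known for %a (say 2 x 4) is also
  // used for %b and for %sum, since the operation is applied per element.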

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
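
  // Forward propagation example (illustrative, intrinsic mangling elided):
  //   %c = call <4 x double> @llvm.matrix.multiply.*(%a, %b,
  //                                                  i32 2, i32 6, i32 2)
  // records the result shape 2 x 2 (M x K) for %c; %c's users are then queued
  // for the next round of propagation.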

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with a known shape. Traverse the operands: if an
    // operand's shape derives from the result shape and is not known yet, set
    // it and add the operand to the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this store, so we would only
        // backward-propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
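
  // Backward propagation example (illustrative, intrinsic mangling elided):
  // for the same multiply
  //   %c = call <4 x double> @llvm.matrix.multiply.*(%a, %b,
  //                                                  i32 2, i32 6, i32 2)
  // the operands get shapes 2 x 6 (M x N) for %a and 6 x 2 (N x K) for %b,
  // and any instructions defining %a or %b are queued in turn.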

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType,
                          IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
                 ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ColumnMatrixTy Result;
    // Stride is the distance between the start of one column and the start of
    // the next.
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }
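
  // Illustrative result: a columnwise load of a 2 x 2 double matrix with
  // stride 2 from %ptr is lowered to roughly
  //   %col.cast = bitcast double* %ptr to <2 x double>*
  //   %col.load = load <2 x double>, <2 x double>* %col.cast, align 8
  //   %col.gep  = getelementptr double, double* %ptr, i32 2
  //   ... second column cast and load ...
  // followed by embedding the two columns into a <4 x double> for any users
  // without shape information.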

  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts elements, starting at index
  /// (\p I, \p J) from the matrix \p LM represented as a vector of column
  /// vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Set elements I..I+NumElts-1 to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction) {

    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Value *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
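
  // Illustrative accumulation with contraction allowed (FP case, value names
  // hypothetical): the first product is a plain fmul, later products are
  // folded into the running sum via the llvm.fmuladd intrinsic, e.g.
  //   %m0 = fmul <2 x double> %l0, %splat0
  //   %m1 = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %l1,
  //                                               <2 x double> %splat1,
  //                                               <2 x double> %m0)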

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract =
        AllowContractEnabled ||
        (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract());
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axis and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    finalizeLowering(MatMul, Result, Builder);
  }
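
  // Shape of the generated code (illustrative): for a 4x4 * 4x4 double
  // multiply with VF == 2 (e.g. 128-bit vector registers), each of the 4
  // result columns is built from two 2-element blocks (I == 0 and I == 2);
  // every block accumulates M == 4 products of a <2 x double> slice of an
  // LHS column with a splat of the corresponding RHS element.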

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }
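
  // Example (illustrative): transposing a 2 x 3 matrix yields a 3 x 2 result;
  // result column 0 is assembled from element 0 of each of the 3 input
  // columns, and result column 1 from element 1 of each input column.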

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Apply the operator column by column and collect the result columns.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst, Result, Builder);
    return true;
  }
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, *TTI);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
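
// To exercise this pass in isolation (illustrative): with the legacy pass
// manager, "opt -lower-matrix-intrinsics -S input.ll" uses the DEBUG_TYPE
// name registered above; the new pass manager name is registered elsewhere
// in LLVM (typically "opt -passes=lower-matrix-intrinsics -S input.ll").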