1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 //  * Improve fusion:
13 //    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 //      transposed.
15 //    * Improve cost-modeling, e.g. choose different numbers of rows/columns
16 //      for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Analysis/AliasAnalysis.h"
26 #include "llvm/Analysis/DomTreeUpdater.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Analysis/ValueTracking.h"
31 #include "llvm/Analysis/VectorUtils.h"
32 #include "llvm/IR/CFG.h"
33 #include "llvm/IR/DataLayout.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/IRBuilder.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/IntrinsicInst.h"
39 #include "llvm/IR/MatrixBuilder.h"
40 #include "llvm/IR/PatternMatch.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/Utils/LoopUtils.h"
46 #include "llvm/Transforms/Utils/MatrixUtils.h"
47
48 #include <cmath>
49
50 using namespace llvm;
51 using namespace PatternMatch;
52
53 #define DEBUG_TYPE "lower-matrix-intrinsics"
54
55 static cl::opt<bool>
56     FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
57                cl::desc("Enable/disable fusing matrix instructions."));
58 // TODO: Allow and use non-square tiles.
59 static cl::opt<unsigned> TileSize(
60     "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
61     cl::desc(
62         "Tile size for matrix instruction fusion using square-shaped tiles."));
63 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
64                                   cl::Hidden,
65                                   cl::desc("Generate loop nest for tiling."));
66 static cl::opt<bool> ForceFusion(
67     "force-fuse-matrix", cl::init(false), cl::Hidden,
68     cl::desc("Force matrix instruction fusion even if not profitable."));
69 static cl::opt<bool> AllowContractEnabled(
70     "matrix-allow-contract", cl::init(false), cl::Hidden,
71     cl::desc("Allow the use of FMAs if available and profitable. This may "
72              "result in different results, due to less rounding error."));
73
74 static cl::opt<bool>
75     VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
76                     cl::desc("Enable/disable matrix shape verification."),
77                     cl::init(false));
78
79 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
80
81 static cl::opt<MatrixLayoutTy> MatrixLayout(
82     "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
83     cl::desc("Sets the default matrix layout"),
84     cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
85                           "Use column-major layout"),
86                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
87                           "Use row-major layout")));
88
89 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
90                                             cl::init(false));
91
92 /// Helper function to either return Scope, if it is a subprogram, or the
93 /// attached subprogram for a local scope.
94 static DISubprogram *getSubprogram(DIScope *Scope) {
95   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
96     return Subprogram;
97   return cast<DILocalScope>(Scope)->getSubprogram();
98 }
99
100 /// Erase \p V from \p BB and move \p II forward to avoid invalidating
101 /// iterators.
102 static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
103                                    BasicBlock &BB) {
104   auto *Inst = cast<Instruction>(V);
105   // Still used, don't erase.
106   if (!Inst->use_empty())
107     return;
108   if (II != BB.rend() && Inst == &*II)
109     ++II;
110   Inst->eraseFromParent();
111 }
112
113 /// Return true if V is a splat of a value (which is used when multiplying a
114 /// matrix with a scalar).
115 static bool isSplat(Value *V) {
116   if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
117     return SV->isZeroEltSplat();
118   return false;
119 }
120
121 /// Match any mul operation (fp or integer).
122 template <typename LTy, typename RTy>
123 auto m_AnyMul(const LTy &L, const RTy &R) {
124   return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
125 }
126
127 /// Match any add operation (fp or integer).
128 template <typename LTy, typename RTy>
129 auto m_AnyAdd(const LTy &L, const RTy &R) {
130   return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
131 }
132
133 namespace {
134
135 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
136 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
137 // assuming \p Stride elements between the starts of two consecutive vectors.
138 // \p Stride must be >= \p NumElements.
139 // For column-major matrixes, the function computes the address of a column
140 // vector and \p NumElements must be set to the number of elements in a column
141 // (= number of rows of the matrix). For row-major matrixes, the function
142 // computes the address of a row vector and \p NumElements must be set to the
143 // number of elements in a row (= number of columns of the matrix).
144 //
145 // Consider a 4x4 matrix in column-major layout like below
146 //
147 //      0       1      2      3
148 // 0   v_0_0  v_0_1  v_0_2  v_0_3
149 // 1   v_1_0  v_1_1  v_1_2  v_1_3
150 // 2   v_2_0  v_2_1  v_2_2  v_2_3
151 // 3   v_3_0  v_3_1  v_3_2  v_3_3
152
153 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
154 // we need a pointer to the first element of the submatrix as base pointer.
155 // Then we can use computeVectorAddr to compute the addresses for the columns
156 // of the sub-matrix.
157 //
158 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
159 //           -> just returns Base
160 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
161 // -> returns Base + (1 * 4) 162 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 163 // -> returns Base + (2 * 4) 164 // 165 // The graphic below illustrates the number of elements in a column (marked 166 // with |) and the number of skipped elements (marked with }). 167 // 168 // v_0_0 v_0_1 {v_0_2 {v_0_3 169 // Base Col 1 Col 2 170 // | | | 171 // v_1_0 |v_1_1 |v_1_2 |v_1_3 172 // v_2_0 |v_2_1 |v_2_2 |v_2_3 173 // v_3_0 {v_3_1 {v_3_2 v_3_3 174 // 175 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 176 unsigned NumElements, Type *EltType, 177 IRBuilder<> &Builder) { 178 179 assert((!isa<ConstantInt>(Stride) || 180 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 181 "Stride must be >= the number of elements in the result vector."); 182 183 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 184 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 185 186 // Get pointer to the start of the selected vector. Skip GEP creation, 187 // if we select vector 0. 188 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 189 VecStart = BasePtr; 190 else 191 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 192 193 return VecStart; 194 } 195 196 namespace { 197 struct ShapeInfo { 198 unsigned NumRows; 199 unsigned NumColumns; 200 201 bool IsColumnMajor; 202 203 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 204 : NumRows(NumRows), NumColumns(NumColumns), 205 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 206 207 ShapeInfo(Value *NumRows, Value *NumColumns) 208 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 209 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 210 211 bool operator==(const ShapeInfo &other) { 212 return NumRows == other.NumRows && NumColumns == other.NumColumns; 213 } 214 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 215 216 /// Returns true if shape-information is defined, meaning both dimensions 217 /// are != 0. 218 operator bool() const { 219 assert(NumRows == 0 || NumColumns != 0); 220 return NumRows != 0; 221 } 222 223 unsigned getStride() const { 224 if (IsColumnMajor) 225 return NumRows; 226 return NumColumns; 227 } 228 229 unsigned getNumVectors() const { 230 if (IsColumnMajor) 231 return NumColumns; 232 return NumRows; 233 } 234 235 /// Returns the transposed shape. 236 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } 237 }; 238 } // namespace 239 240 static bool isUniformShape(Value *V) { 241 Instruction *I = dyn_cast<Instruction>(V); 242 if (!I) 243 return true; 244 245 switch (I->getOpcode()) { 246 case Instruction::FAdd: 247 case Instruction::FSub: 248 case Instruction::FMul: // Scalar multiply. 249 case Instruction::FNeg: 250 case Instruction::Add: 251 case Instruction::Mul: 252 case Instruction::Sub: 253 return true; 254 default: 255 return false; 256 } 257 } 258 259 /// Return the ShapeInfo for the result of \p I, it it can be determined. 260 static std::optional<ShapeInfo> 261 computeShapeInfoForInst(Instruction *I, 262 const ValueMap<Value *, ShapeInfo> &ShapeMap) { 263 Value *M; 264 Value *N; 265 Value *K; 266 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>( 267 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K)))) 268 return ShapeInfo(M, K); 269 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M), 270 m_Value(N)))) { 271 // Flip dimensions. 
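// The intrinsic arguments (M, N) describe the shape of the transposed
// operand, so the result of the transpose has N rows and M columns.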
272 return ShapeInfo(N, M); 273 } 274 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>( 275 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M), 276 m_Value(N)))) 277 return ShapeInfo(N, M); 278 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>( 279 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N)))) 280 return ShapeInfo(M, N); 281 Value *MatrixA; 282 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) { 283 auto OpShape = ShapeMap.find(MatrixA); 284 if (OpShape != ShapeMap.end()) 285 return OpShape->second; 286 } 287 288 if (isUniformShape(I)) { 289 // Find the first operand that has a known shape and use that. 290 for (auto &Op : I->operands()) { 291 auto OpShape = ShapeMap.find(Op.get()); 292 if (OpShape != ShapeMap.end()) 293 return OpShape->second; 294 } 295 } 296 return std::nullopt; 297 } 298 299 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 300 /// 301 /// Currently, the lowering for each matrix intrinsic is done as follows: 302 /// 1. Propagate the shape information from intrinsics to connected 303 /// instructions. 304 /// 2. Lower instructions with shape information (assuming column-major layout). 305 /// The lowering works similarly using row-major layout. 306 /// 2.1. Get column vectors for each argument. If we already lowered the 307 /// definition of an argument, use the produced column vectors directly. 308 /// If not, split the operand vector containing an embedded matrix into 309 /// a set of column vectors, 310 /// 2.2. Lower the instruction in terms of column major operations, which 311 /// yields a set of column vectors containing result matrix. Note that we 312 /// lower all instructions that have shape information. Besides the 313 /// intrinsics, this includes stores for example. 314 /// 2.3. Update uses of the lowered instruction. If we have shape information 315 /// for a user, there is nothing to do, as we will look up the result 316 /// column matrix when lowering the user. For other uses, we embed the 317 /// result matrix in a flat vector and update the use. 318 /// 2.4. Cache the result column matrix for the instruction we lowered 319 /// 3. After we lowered all instructions in a function, remove the now 320 /// obsolete instructions. 321 /// 322 class LowerMatrixIntrinsics { 323 Function &Func; 324 const DataLayout &DL; 325 const TargetTransformInfo &TTI; 326 AliasAnalysis *AA; 327 DominatorTree *DT; 328 LoopInfo *LI; 329 OptimizationRemarkEmitter *ORE; 330 331 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 332 struct OpInfoTy { 333 /// Number of stores emitted to generate this matrix. 334 unsigned NumStores = 0; 335 /// Number of loads emitted to generate this matrix. 336 unsigned NumLoads = 0; 337 /// Number of compute operations emitted to generate this matrix. 338 unsigned NumComputeOps = 0; 339 /// Most of the time transposes can be fused with matrix multiplies or can 340 /// be folded away via algebraic simplifications. This is the number of 341 /// transposes that we failed to make "free" via such optimizations. 342 unsigned NumExposedTransposes = 0; 343 344 OpInfoTy &operator+=(const OpInfoTy &RHS) { 345 NumStores += RHS.NumStores; 346 NumLoads += RHS.NumLoads; 347 NumComputeOps += RHS.NumComputeOps; 348 NumExposedTransposes += RHS.NumExposedTransposes; 349 return *this; 350 } 351 }; 352 353 /// Wrapper class representing a matrix as a set of vectors, either in row or 354 /// column major layout. 
All vectors must have the same vector type. 355 class MatrixTy { 356 SmallVector<Value *, 16> Vectors; 357 358 OpInfoTy OpInfo; 359 360 bool IsColumnMajor = true; 361 362 public: 363 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 364 MatrixTy(ArrayRef<Value *> Vectors) 365 : Vectors(Vectors.begin(), Vectors.end()), 366 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 367 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 368 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 369 370 unsigned D = isColumnMajor() ? NumColumns : NumRows; 371 for (unsigned J = 0; J < D; ++J) 372 addVector(PoisonValue::get(FixedVectorType::get( 373 EltTy, isColumnMajor() ? NumRows : NumColumns))); 374 } 375 376 Value *getVector(unsigned i) const { return Vectors[i]; } 377 Value *getColumn(unsigned i) const { 378 assert(isColumnMajor() && "only supported for column-major matrixes"); 379 return Vectors[i]; 380 } 381 Value *getRow(unsigned i) const { 382 assert(!isColumnMajor() && "only supported for row-major matrixes"); 383 return Vectors[i]; 384 } 385 386 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 387 388 Type *getElementType() const { return getVectorTy()->getElementType(); } 389 390 unsigned getNumVectors() const { 391 if (isColumnMajor()) 392 return getNumColumns(); 393 return getNumRows(); 394 } 395 396 unsigned getNumColumns() const { 397 if (isColumnMajor()) 398 return Vectors.size(); 399 else { 400 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 401 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 402 } 403 } 404 unsigned getNumRows() const { 405 if (isColumnMajor()) { 406 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 407 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 408 } else 409 return Vectors.size(); 410 } 411 412 void addVector(Value *V) { Vectors.push_back(V); } 413 VectorType *getColumnTy() { 414 assert(isColumnMajor() && "only supported for column-major matrixes"); 415 return getVectorTy(); 416 } 417 418 VectorType *getVectorTy() const { 419 return cast<VectorType>(Vectors[0]->getType()); 420 } 421 422 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 423 assert(isColumnMajor() && 424 "columns() only supported for column-major matrixes"); 425 return make_range(Vectors.begin(), Vectors.end()); 426 } 427 428 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 429 return make_range(Vectors.begin(), Vectors.end()); 430 } 431 432 /// Embed the vectors of the matrix into a flat vector by concatenating 433 /// them. 434 Value *embedInVector(IRBuilder<> &Builder) const { 435 return Vectors.size() == 1 ? 
Vectors[0] 436 : concatenateVectors(Builder, Vectors); 437 } 438 439 MatrixTy &addNumLoads(unsigned N) { 440 OpInfo.NumLoads += N; 441 return *this; 442 } 443 444 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 445 446 MatrixTy &addNumStores(unsigned N) { 447 OpInfo.NumStores += N; 448 return *this; 449 } 450 451 MatrixTy &addNumExposedTransposes(unsigned N) { 452 OpInfo.NumExposedTransposes += N; 453 return *this; 454 } 455 456 MatrixTy &addNumComputeOps(unsigned N) { 457 OpInfo.NumComputeOps += N; 458 return *this; 459 } 460 461 unsigned getNumStores() const { return OpInfo.NumStores; } 462 unsigned getNumLoads() const { return OpInfo.NumLoads; } 463 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 464 465 const OpInfoTy &getOpInfo() const { return OpInfo; } 466 467 bool isColumnMajor() const { return IsColumnMajor; } 468 469 unsigned getStride() const { 470 if (isColumnMajor()) 471 return getNumRows(); 472 return getNumColumns(); 473 } 474 475 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 476 /// matrix is column-major, the result vector is extracted from a column 477 /// vector, otherwise from a row vector. 478 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 479 IRBuilder<> &Builder) const { 480 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 481 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 482 NumElts && 483 "Extracted vector will contain poison values"); 484 return Builder.CreateShuffleVector( 485 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 486 "block"); 487 } 488 }; 489 490 /// Maps instructions to their shape information. The shape information 491 /// describes the shape to be used while lowering. This matches the shape of 492 /// the result value of the instruction, with the only exceptions being store 493 /// instructions and the matrix_column_major_store intrinsics. For those, the 494 /// shape information indicates that those instructions should be lowered 495 /// using shape information as well. A ValueMap is used so that when 496 /// sub-passes like optimizeTransposes performs RAUW the map stays 497 /// up-to-date. 498 ValueMap<Value *, ShapeInfo> ShapeMap; 499 500 /// List of instructions to remove. While lowering, we are not replacing all 501 /// users of a lowered instruction, if shape information is available and 502 /// those need to be removed after we finished lowering. 503 SmallVector<Instruction *, 16> ToRemove; 504 505 /// Map from instructions to their produced column matrix. 506 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 507 508 private: 509 static FastMathFlags getFastMathFlags(Instruction *Inst) { 510 FastMathFlags FMF; 511 512 if (isa<FPMathOperator>(*Inst)) 513 FMF = Inst->getFastMathFlags(); 514 515 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 516 517 return FMF; 518 } 519 520 public: 521 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 522 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 523 OptimizationRemarkEmitter *ORE) 524 : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT), 525 LI(LI), ORE(ORE) {} 526 527 unsigned getNumOps(Type *VT) { 528 assert(isa<VectorType>(VT) && "Expected vector type"); 529 return getNumOps(VT->getScalarType(), 530 cast<FixedVectorType>(VT)->getNumElements()); 531 } 532 533 /// Is this the minimal version executed in the backend pipelines. 
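/// In the minimal version the optional analyses (DominatorTree, AliasAnalysis,
/// LoopInfo, ORE) are not available; only the plain lowering is performed and
/// extra optimizations such as optimizeTransposes (see Visit()) are skipped.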
534 bool isMinimal() const { 535 return !DT; 536 } 537 538 /// Return the estimated number of vector ops required for an operation on 539 /// \p VT * N. 540 unsigned getNumOps(Type *ST, unsigned N) { 541 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / 542 double(TTI.getRegisterBitWidth( 543 TargetTransformInfo::RGK_FixedWidthVector) 544 .getFixedValue())); 545 } 546 547 /// Return the set of vectors that a matrix value is lowered to. 548 /// 549 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 550 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 551 /// into vectors. 552 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 553 IRBuilder<> &Builder) { 554 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 555 assert(VType && "MatrixVal must be a vector type"); 556 assert(cast<FixedVectorType>(VType)->getNumElements() == 557 SI.NumRows * SI.NumColumns && 558 "The vector size must match the number of matrix elements"); 559 560 // Check if we lowered MatrixVal using shape information. In that case, 561 // return the existing matrix, if it matches the requested shape 562 // information. If there is a mis-match, embed the result in a flat 563 // vector and split it later. 564 auto Found = Inst2ColumnMatrix.find(MatrixVal); 565 if (Found != Inst2ColumnMatrix.end()) { 566 MatrixTy &M = Found->second; 567 // Return the found matrix, if its shape matches the requested shape 568 // information 569 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 570 return M; 571 572 MatrixVal = M.embedInVector(Builder); 573 } 574 575 // Otherwise split MatrixVal. 576 SmallVector<Value *, 16> SplitVecs; 577 for (unsigned MaskStart = 0; 578 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 579 MaskStart += SI.getStride()) { 580 Value *V = Builder.CreateShuffleVector( 581 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 582 "split"); 583 SplitVecs.push_back(V); 584 } 585 586 return {SplitVecs}; 587 } 588 589 /// If \p V already has a known shape return false. Otherwise set the shape 590 /// for instructions that support it. 591 bool setShapeInfo(Value *V, ShapeInfo Shape) { 592 assert(Shape && "Shape not set"); 593 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 594 return false; 595 596 auto SIter = ShapeMap.find(V); 597 if (SIter != ShapeMap.end()) { 598 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || 599 SIter->second.NumColumns != Shape.NumColumns)) { 600 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" 601 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" 602 << Shape.NumColumns << ") for " << *V << "\n"; 603 report_fatal_error( 604 "Matrix shape verification failed, compilation aborted!"); 605 } 606 607 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 608 << SIter->second.NumRows << " " 609 << SIter->second.NumColumns << " for " << *V << "\n"); 610 return false; 611 } 612 613 ShapeMap.insert({V, Shape}); 614 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 615 << " for " << *V << "\n"); 616 return true; 617 } 618 619 /// Returns true if shape information can be used for \p V. The supported 620 /// instructions must match the instructions that can be lowered by this pass. 
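/// Currently this is the case for the matrix intrinsics handled below, loads,
/// stores and the element-wise operations accepted by isUniformShape.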
621 bool supportsShapeInfo(Value *V) { 622 Instruction *Inst = dyn_cast<Instruction>(V); 623 if (!Inst) 624 return false; 625 626 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 627 if (II) 628 switch (II->getIntrinsicID()) { 629 case Intrinsic::matrix_multiply: 630 case Intrinsic::matrix_transpose: 631 case Intrinsic::matrix_column_major_load: 632 case Intrinsic::matrix_column_major_store: 633 return true; 634 default: 635 return false; 636 } 637 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 638 } 639 640 /// Propagate the shape information of instructions to their users. 641 /// The work list contains instructions for which we can compute the shape, 642 /// either based on the information provided by matrix intrinsics or known 643 /// shapes of operands. 644 SmallVector<Instruction *, 32> 645 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 646 SmallVector<Instruction *, 32> NewWorkList; 647 // Pop an element for which we guaranteed to have at least one of the 648 // operand shapes. Add the shape for this and then add users to the work 649 // list. 650 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 651 while (!WorkList.empty()) { 652 Instruction *Inst = WorkList.pop_back_val(); 653 654 // New entry, set the value and insert operands 655 bool Propagate = false; 656 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap)) 657 Propagate = setShapeInfo(Inst, *SI); 658 659 if (Propagate) { 660 NewWorkList.push_back(Inst); 661 for (auto *User : Inst->users()) 662 if (ShapeMap.count(User) == 0) 663 WorkList.push_back(cast<Instruction>(User)); 664 } 665 } 666 667 return NewWorkList; 668 } 669 670 /// Propagate the shape to operands of instructions with shape information. 671 /// \p Worklist contains the instruction for which we already know the shape. 672 SmallVector<Instruction *, 32> 673 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 674 SmallVector<Instruction *, 32> NewWorkList; 675 676 auto pushInstruction = [](Value *V, 677 SmallVectorImpl<Instruction *> &WorkList) { 678 Instruction *I = dyn_cast<Instruction>(V); 679 if (I) 680 WorkList.push_back(I); 681 }; 682 // Pop an element with known shape. Traverse the operands, if their shape 683 // derives from the result shape and is unknown, add it and add them to the 684 // worklist. 685 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 686 while (!WorkList.empty()) { 687 Value *V = WorkList.pop_back_val(); 688 689 size_t BeforeProcessingV = WorkList.size(); 690 if (!isa<Instruction>(V)) 691 continue; 692 693 Value *MatrixA; 694 Value *MatrixB; 695 Value *M; 696 Value *N; 697 Value *K; 698 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 699 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 700 m_Value(N), m_Value(K)))) { 701 if (setShapeInfo(MatrixA, {M, N})) 702 pushInstruction(MatrixA, WorkList); 703 704 if (setShapeInfo(MatrixB, {N, K})) 705 pushInstruction(MatrixB, WorkList); 706 707 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 708 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 709 // Flip dimensions. 
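// Here the intrinsic arguments (M, N) already give the shape of the
// operand; only the transpose's result (handled during forward
// propagation) has the flipped N x M shape.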
710 if (setShapeInfo(MatrixA, {M, N})) 711 pushInstruction(MatrixA, WorkList); 712 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 713 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 714 m_Value(M), m_Value(N)))) { 715 if (setShapeInfo(MatrixA, {M, N})) { 716 pushInstruction(MatrixA, WorkList); 717 } 718 } else if (isa<LoadInst>(V) || 719 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 720 // Nothing to do, no matrix input. 721 } else if (isa<StoreInst>(V)) { 722 // Nothing to do. We forward-propagated to this so we would just 723 // backward propagate to an instruction with an already known shape. 724 } else if (isUniformShape(V)) { 725 // Propagate to all operands. 726 ShapeInfo Shape = ShapeMap[V]; 727 for (Use &U : cast<Instruction>(V)->operands()) { 728 if (setShapeInfo(U.get(), Shape)) 729 pushInstruction(U.get(), WorkList); 730 } 731 } 732 // After we discovered new shape info for new instructions in the 733 // worklist, we use their users as seeds for the next round of forward 734 // propagation. 735 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 736 for (User *U : WorkList[I]->users()) 737 if (isa<Instruction>(U) && V != U) 738 NewWorkList.push_back(cast<Instruction>(U)); 739 } 740 return NewWorkList; 741 } 742 743 /// (Op0 op Op1)^T -> Op0^T op Op1^T 744 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use 745 /// them on both sides of \p Operation. 746 Instruction *distributeTransposes( 747 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1, 748 MatrixBuilder &Builder, 749 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)> 750 Operation) { 751 Value *T0 = Builder.CreateMatrixTranspose( 752 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t"); 753 // We are being run after shape prop, add shape for newly created 754 // instructions so that we lower them later. 755 setShapeInfo(T0, Shape0.t()); 756 Value *T1 = Builder.CreateMatrixTranspose( 757 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t"); 758 setShapeInfo(T1, Shape1.t()); 759 return Operation(T0, Shape0.t(), T1, Shape1.t()); 760 } 761 762 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { 763 // We need to remove Old from the ShapeMap otherwise RAUW will replace it 764 // with New. We should only add New it it supportsShapeInfo so we insert 765 // it conditionally instead. 766 auto S = ShapeMap.find(&Old); 767 if (S != ShapeMap.end()) { 768 ShapeMap.erase(S); 769 if (supportsShapeInfo(New)) 770 ShapeMap.insert({New, S->second}); 771 } 772 Old.replaceAllUsesWith(New); 773 } 774 775 /// Sink a top-level transpose inside matmuls and adds. 776 /// This creates and erases instructions as needed, and returns the newly 777 /// created instruction while updating the iterator to avoid invalidation. If 778 /// this returns nullptr, no new instruction was created. 
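/// Handled rewrites include:
///   (A^t)^t      -> A
///   k^t          -> k            (for splats)
///   (A * B)^t    -> B^t * A^t
///   (A * k)^t    -> A^t * k      (k a splat)
///   (A + B)^t    -> A^t + B^t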
779 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) { 780 BasicBlock &BB = *I.getParent(); 781 IRBuilder<> IB(&I); 782 MatrixBuilder Builder(IB); 783 784 Value *TA, *TAMA, *TAMB; 785 ConstantInt *R, *K, *C; 786 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( 787 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) 788 return nullptr; 789 790 // Transpose of a transpose is a nop 791 Value *TATA; 792 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 793 updateShapeAndReplaceAllUsesWith(I, TATA); 794 eraseFromParentAndMove(&I, II, BB); 795 eraseFromParentAndMove(TA, II, BB); 796 return nullptr; 797 } 798 799 // k^T -> k 800 if (isSplat(TA)) { 801 updateShapeAndReplaceAllUsesWith(I, TA); 802 eraseFromParentAndMove(&I, II, BB); 803 return nullptr; 804 } 805 806 // (A * B)^t -> B^t * A^t 807 // RxK KxC CxK KxR 808 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 809 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 810 m_ConstantInt(K), m_ConstantInt(C)))) { 811 auto NewInst = distributeTransposes( 812 TAMB, {K, C}, TAMA, {R, K}, Builder, 813 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 814 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, 815 Shape0.NumColumns, 816 Shape1.NumColumns, "mmul"); 817 }); 818 updateShapeAndReplaceAllUsesWith(I, NewInst); 819 eraseFromParentAndMove(&I, II, BB); 820 eraseFromParentAndMove(TA, II, BB); 821 return NewInst; 822 } 823 824 // Same as above, but with a mul, which occurs when multiplied 825 // with a scalar. 826 // (A * k)^t -> A^t * k 827 // R x C RxC 828 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && 829 (isSplat(TAMA) || isSplat(TAMB))) { 830 IRBuilder<> LocalBuilder(&I); 831 // We know that the transposed operand is of shape RxC. 832 // An when multiplied with a scalar, the shape is preserved. 833 auto NewInst = distributeTransposes( 834 TAMA, {R, C}, TAMB, {R, C}, Builder, 835 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 836 bool IsFP = I.getType()->isFPOrFPVectorTy(); 837 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul") 838 : LocalBuilder.CreateMul(T0, T1, "mmul"); 839 auto *Result = cast<Instruction>(Mul); 840 setShapeInfo(Result, Shape0); 841 return Result; 842 }); 843 updateShapeAndReplaceAllUsesWith(I, NewInst); 844 eraseFromParentAndMove(&I, II, BB); 845 eraseFromParentAndMove(TA, II, BB); 846 return NewInst; 847 } 848 849 // (A + B)^t -> A^t + B^t 850 // RxC RxC CxR CxR 851 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { 852 IRBuilder<> LocalBuilder(&I); 853 auto NewInst = distributeTransposes( 854 TAMA, {R, C}, TAMB, {R, C}, Builder, 855 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 856 bool IsFP = I.getType()->isFPOrFPVectorTy(); 857 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd") 858 : LocalBuilder.CreateAdd(T0, T1, "madd"); 859 860 auto *Result = cast<Instruction>(Add); 861 setShapeInfo(Result, Shape0); 862 return Result; 863 }); 864 updateShapeAndReplaceAllUsesWith(I, NewInst); 865 eraseFromParentAndMove(&I, II, BB); 866 eraseFromParentAndMove(TA, II, BB); 867 return NewInst; 868 } 869 870 return nullptr; 871 } 872 873 void liftTranspose(Instruction &I) { 874 // Erase dead Instructions after lifting transposes from binops. 
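// Lifting rewrites A^t * B^t -> (B * A)^t and A^t + B^t -> (A + B)^t; the
// operand transposes often become dead afterwards and are erased here.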
875 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { 876 if (T.use_empty()) 877 T.eraseFromParent(); 878 if (A->use_empty()) 879 cast<Instruction>(A)->eraseFromParent(); 880 if (A != B && B->use_empty()) 881 cast<Instruction>(B)->eraseFromParent(); 882 }; 883 884 Value *A, *B, *AT, *BT; 885 ConstantInt *R, *K, *C; 886 // A^t * B ^t -> (B * A)^t 887 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 888 m_Value(A), m_Value(B), m_ConstantInt(R), 889 m_ConstantInt(K), m_ConstantInt(C))) && 890 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 891 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 892 IRBuilder<> IB(&I); 893 MatrixBuilder Builder(IB); 894 Value *M = Builder.CreateMatrixMultiply( 895 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 896 setShapeInfo(M, {C, R}); 897 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), 898 R->getZExtValue()); 899 updateShapeAndReplaceAllUsesWith(I, NewInst); 900 CleanupBinOp(I, A, B); 901 } 902 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If 903 // the shape of the second transpose is different, there's a shape conflict 904 // which gets resolved by picking the shape of the first operand. 905 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && 906 match(A, m_Intrinsic<Intrinsic::matrix_transpose>( 907 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && 908 match(B, m_Intrinsic<Intrinsic::matrix_transpose>( 909 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) { 910 IRBuilder<> Builder(&I); 911 auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); 912 setShapeInfo(Add, {R, C}); 913 MatrixBuilder MBuilder(Builder); 914 Instruction *NewInst = MBuilder.CreateMatrixTranspose( 915 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t"); 916 updateShapeAndReplaceAllUsesWith(I, NewInst); 917 assert(computeShapeInfoForInst(NewInst, ShapeMap) == 918 computeShapeInfoForInst(&I, ShapeMap) && 919 "Shape of new instruction doesn't match original shape."); 920 CleanupBinOp(I, A, B); 921 assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) == 922 ShapeMap[Add] && 923 "Shape of updated addition doesn't match cached shape."); 924 } 925 } 926 927 /// Try moving transposes in order to fold them away or into multiplies. 928 void optimizeTransposes() { 929 // First sink all transposes inside matmuls and adds, hoping that we end up 930 // with NN, NT or TN variants. 931 for (BasicBlock &BB : reverse(Func)) { 932 for (auto II = BB.rbegin(); II != BB.rend();) { 933 Instruction &I = *II; 934 // We may remove II. By default continue on the next/prev instruction. 935 ++II; 936 if (Instruction *NewInst = sinkTranspose(I, II)) 937 II = std::next(BasicBlock::reverse_iterator(NewInst)); 938 } 939 } 940 941 // If we have a TT matmul or a TT add, lift the transpose. We may be able 942 // to fold into consuming multiply or add. 943 for (BasicBlock &BB : Func) { 944 for (Instruction &I : llvm::make_early_inc_range(BB)) { 945 liftTranspose(I); 946 } 947 } 948 } 949 950 bool Visit() { 951 SmallVector<Instruction *, 32> WorkList; 952 953 // Initially only the shape of matrix intrinsics is known. 954 // Initialize the work list with ops carrying shape information. 
955 for (BasicBlock &BB : Func) 956 for (Instruction &Inst : BB) { 957 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 958 if (!II) 959 continue; 960 961 switch (II->getIntrinsicID()) { 962 case Intrinsic::matrix_multiply: 963 case Intrinsic::matrix_transpose: 964 case Intrinsic::matrix_column_major_load: 965 case Intrinsic::matrix_column_major_store: 966 WorkList.push_back(&Inst); 967 break; 968 default: 969 break; 970 } 971 } 972 973 // Avoid unnecessary work if there are no matrix intrinsics in the function. 974 if (WorkList.empty()) 975 return false; 976 977 // Propagate shapes until nothing changes any longer. 978 while (!WorkList.empty()) { 979 WorkList = propagateShapeForward(WorkList); 980 WorkList = propagateShapeBackward(WorkList); 981 } 982 983 if (!isMinimal()) { 984 optimizeTransposes(); 985 if (PrintAfterTransposeOpt) { 986 dbgs() << "Dump after matrix transpose optimization:\n"; 987 Func.print(dbgs()); 988 } 989 } 990 991 bool Changed = false; 992 SmallVector<CallInst *, 16> MaybeFusableInsts; 993 SmallVector<Instruction *, 16> MatrixInsts; 994 SmallVector<IntrinsicInst *, 16> LifetimeEnds; 995 996 // First, collect all instructions with shape information and candidates for 997 // fusion (currently only matrix multiplies). 998 ReversePostOrderTraversal<Function *> RPOT(&Func); 999 for (auto *BB : RPOT) 1000 for (Instruction &I : *BB) { 1001 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>())) 1002 LifetimeEnds.push_back(cast<IntrinsicInst>(&I)); 1003 if (ShapeMap.find(&I) == ShapeMap.end()) 1004 continue; 1005 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 1006 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 1007 MatrixInsts.push_back(&I); 1008 } 1009 1010 // Second, try to lower any dot products 1011 SmallPtrSet<Instruction *, 16> FusedInsts; 1012 for (CallInst *CI : MaybeFusableInsts) 1013 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); 1014 1015 // Third, try to fuse candidates. 1016 for (CallInst *CI : MaybeFusableInsts) 1017 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); 1018 1019 Changed = !FusedInsts.empty(); 1020 1021 // Fourth, lower remaining instructions with shape information. 1022 for (Instruction *Inst : MatrixInsts) { 1023 if (FusedInsts.count(Inst)) 1024 continue; 1025 1026 IRBuilder<> Builder(Inst); 1027 1028 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 1029 Changed |= VisitCallInst(CInst); 1030 1031 Value *Op1; 1032 Value *Op2; 1033 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1034 Changed |= VisitBinaryOperator(BinOp); 1035 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1036 Changed |= VisitUnaryOperator(UnOp); 1037 if (match(Inst, m_Load(m_Value(Op1)))) 1038 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 1039 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 1040 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 1041 } 1042 1043 if (ORE) { 1044 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 1045 RemarkGen.emitRemarks(); 1046 } 1047 1048 // Delete the instructions backwards, as it has a reduced likelihood of 1049 // having to update as many def-use and use-def chains. 1050 // 1051 // Because we add to ToRemove during fusion we can't guarantee that defs 1052 // are before uses. Change uses to poison temporarily as these should get 1053 // removed as well. 1054 // 1055 // For verification, we keep track of where we changed uses to poison in 1056 // PoisonedInsts and then check that we in fact remove them. 
1057 SmallSet<Instruction *, 16> PoisonedInsts; 1058 for (auto *Inst : reverse(ToRemove)) { 1059 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1060 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 1061 PoisonedInsts.insert(Poisoned); 1062 U.set(PoisonValue::get(Inst->getType())); 1063 } 1064 Inst->eraseFromParent(); 1065 PoisonedInsts.erase(Inst); 1066 } 1067 if (!PoisonedInsts.empty()) { 1068 // If we didn't remove all poisoned instructions, it's a hard error. 1069 dbgs() << "Poisoned but present instructions:\n"; 1070 for (auto *I : PoisonedInsts) 1071 dbgs() << *I << "\n"; 1072 llvm_unreachable("Poisoned but instruction not removed"); 1073 } 1074 1075 return Changed; 1076 } 1077 1078 /// Replace intrinsic calls 1079 bool VisitCallInst(CallInst *Inst) { 1080 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 1081 return false; 1082 1083 switch (Inst->getCalledFunction()->getIntrinsicID()) { 1084 case Intrinsic::matrix_multiply: 1085 LowerMultiply(Inst); 1086 break; 1087 case Intrinsic::matrix_transpose: 1088 LowerTranspose(Inst); 1089 break; 1090 case Intrinsic::matrix_column_major_load: 1091 LowerColumnMajorLoad(Inst); 1092 break; 1093 case Intrinsic::matrix_column_major_store: 1094 LowerColumnMajorStore(Inst); 1095 break; 1096 default: 1097 return false; 1098 } 1099 return true; 1100 } 1101 1102 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1103 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1104 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1105 /// non-ConstantInt strides, return the common alignment of the initial 1106 /// alignment and the element size in bytes. 1107 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 1108 MaybeAlign A) const { 1109 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 1110 if (Idx == 0) 1111 return InitialAlign; 1112 1113 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 1114 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 1115 uint64_t StrideInBytes = 1116 ConstStride->getZExtValue() * ElementSizeInBits / 8; 1117 return commonAlignment(InitialAlign, Idx * StrideInBytes); 1118 } 1119 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 1120 } 1121 1122 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1123 /// vectors. 1124 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 1125 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1126 auto *VType = cast<VectorType>(Ty); 1127 Type *EltTy = VType->getElementType(); 1128 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 1129 Value *EltPtr = Ptr; 1130 MatrixTy Result; 1131 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1132 Value *GEP = computeVectorAddr( 1133 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1134 Stride, Shape.getStride(), EltTy, Builder); 1135 Value *Vector = Builder.CreateAlignedLoad( 1136 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 1137 IsVolatile, "col.load"); 1138 1139 Result.addVector(Vector); 1140 } 1141 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 1142 Result.getNumVectors()); 1143 } 1144 1145 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1146 /// starting at \p MatrixPtr[I][J]. 
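/// The offset of the tile start is computed as J * Stride + I. For example,
/// loading a tile at row 1, column 2 of a column-major 4x4 matrix (stride 4)
/// starts at element 2 * 4 + 1 = 9 of the flat matrix.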
1147 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 1148 ShapeInfo MatrixShape, Value *I, Value *J, 1149 ShapeInfo ResultShape, Type *EltTy, 1150 IRBuilder<> &Builder) { 1151 1152 Value *Offset = Builder.CreateAdd( 1153 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1154 1155 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1156 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1157 ResultShape.NumColumns); 1158 1159 return loadMatrix(TileTy, TileStart, Align, 1160 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1161 ResultShape, Builder); 1162 } 1163 1164 /// Lower a load instruction with shape information. 1165 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1166 bool IsVolatile, ShapeInfo Shape) { 1167 IRBuilder<> Builder(Inst); 1168 finalizeLowering(Inst, 1169 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1170 Shape, Builder), 1171 Builder); 1172 } 1173 1174 /// Lowers llvm.matrix.column.major.load. 1175 /// 1176 /// The intrinsic loads a matrix from memory using a stride between columns. 1177 void LowerColumnMajorLoad(CallInst *Inst) { 1178 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1179 "Intrinsic only supports column-major layout!"); 1180 Value *Ptr = Inst->getArgOperand(0); 1181 Value *Stride = Inst->getArgOperand(1); 1182 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1183 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1184 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1185 } 1186 1187 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1188 /// MatrixPtr[I][J]. 1189 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1190 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1191 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1192 Value *Offset = Builder.CreateAdd( 1193 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1194 1195 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1196 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1197 StoreVal.getNumColumns()); 1198 1199 storeMatrix(TileTy, StoreVal, TileStart, MAlign, 1200 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1201 } 1202 1203 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1204 /// vectors. 1205 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1206 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1207 IRBuilder<> &Builder) { 1208 auto VType = cast<VectorType>(Ty); 1209 Value *EltPtr = Ptr; 1210 for (auto Vec : enumerate(StoreVal.vectors())) { 1211 Value *GEP = computeVectorAddr( 1212 EltPtr, 1213 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1214 Vec.index()), 1215 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1216 Builder.CreateAlignedStore(Vec.value(), GEP, 1217 getAlignForIndex(Vec.index(), Stride, 1218 VType->getElementType(), 1219 MAlign), 1220 IsVolatile); 1221 } 1222 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1223 StoreVal.getNumVectors()); 1224 } 1225 1226 /// Lower a store instruction with shape information. 
1227 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 1228 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 1229 IRBuilder<> Builder(Inst); 1230 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1231 finalizeLowering(Inst, 1232 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 1233 IsVolatile, Builder), 1234 Builder); 1235 } 1236 1237 /// Lowers llvm.matrix.column.major.store. 1238 /// 1239 /// The intrinsic store a matrix back memory using a stride between columns. 1240 void LowerColumnMajorStore(CallInst *Inst) { 1241 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1242 "Intrinsic only supports column-major layout!"); 1243 Value *Matrix = Inst->getArgOperand(0); 1244 Value *Ptr = Inst->getArgOperand(1); 1245 Value *Stride = Inst->getArgOperand(2); 1246 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1247 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1248 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1249 } 1250 1251 // Set elements I..I+NumElts-1 to Block 1252 Value *insertVector(Value *Col, unsigned I, Value *Block, 1253 IRBuilder<> &Builder) { 1254 1255 // First, bring Block to the same size as Col 1256 unsigned BlockNumElts = 1257 cast<FixedVectorType>(Block->getType())->getNumElements(); 1258 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1259 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1260 1261 Block = Builder.CreateShuffleVector( 1262 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1263 1264 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1265 // 8, 4, 5, 6 1266 SmallVector<int, 16> Mask; 1267 unsigned i; 1268 for (i = 0; i < I; i++) 1269 Mask.push_back(i); 1270 1271 unsigned VecNumElts = 1272 cast<FixedVectorType>(Col->getType())->getNumElements(); 1273 for (; i < I + BlockNumElts; i++) 1274 Mask.push_back(i - I + VecNumElts); 1275 1276 for (; i < VecNumElts; i++) 1277 Mask.push_back(i); 1278 1279 return Builder.CreateShuffleVector(Col, Block, Mask); 1280 } 1281 1282 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1283 IRBuilder<> &Builder, bool AllowContraction, 1284 unsigned &NumComputeOps) { 1285 NumComputeOps += getNumOps(A->getType()); 1286 if (!Sum) 1287 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1288 1289 if (UseFPOp) { 1290 if (AllowContraction) { 1291 // Use fmuladd for floating point operations and let the backend decide 1292 // if that's profitable. 1293 Function *FMulAdd = Intrinsic::getDeclaration( 1294 Func.getParent(), Intrinsic::fmuladd, A->getType()); 1295 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1296 } 1297 NumComputeOps += getNumOps(A->getType()); 1298 Value *Mul = Builder.CreateFMul(A, B); 1299 return Builder.CreateFAdd(Sum, Mul); 1300 } 1301 1302 NumComputeOps += getNumOps(A->getType()); 1303 Value *Mul = Builder.CreateMul(A, B); 1304 return Builder.CreateAdd(Sum, Mul); 1305 } 1306 1307 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1308 /// users with shape information, there's nothing to do: they will use the 1309 /// cached value when they are lowered. For other users, \p Matrix is 1310 /// flattened and the uses are updated to use it. Also marks \p Inst for 1311 /// deletion. 
1312 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1313 IRBuilder<> &Builder) { 1314 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1315 (void)inserted; 1316 assert(inserted.second && "multiple matrix lowering mapping"); 1317 1318 ToRemove.push_back(Inst); 1319 Value *Flattened = nullptr; 1320 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1321 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1322 if (!Flattened) 1323 Flattened = Matrix.embedInVector(Builder); 1324 U.set(Flattened); 1325 } 1326 } 1327 } 1328 1329 /// Special case for MatMul lowering. Prevents scalar loads of row-major 1330 /// vectors Lowers to vector reduction add instead of sequential add if 1331 /// reassocation is enabled. 1332 void lowerDotProduct(CallInst *MatMul, 1333 SmallPtrSet<Instruction *, 16> &FusedInsts, 1334 FastMathFlags FMF) { 1335 if (FusedInsts.contains(MatMul) || 1336 MatrixLayout != MatrixLayoutTy::ColumnMajor) 1337 return; 1338 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1339 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1340 1341 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product 1342 return; 1343 1344 Value *LHS = MatMul->getArgOperand(0); 1345 Value *RHS = MatMul->getArgOperand(1); 1346 1347 Type *ElementType = cast<VectorType>(LHS->getType())->getElementType(); 1348 bool IsIntVec = ElementType->isIntegerTy(); 1349 1350 // Floating point reductions require reassocation. 1351 if (!IsIntVec && !FMF.allowReassoc()) 1352 return; 1353 1354 auto CanBeFlattened = [](Value *Op) { 1355 if (match(Op, m_BinOp())) 1356 return true; 1357 return match( 1358 Op, m_OneUse(m_CombineOr( 1359 m_Load(m_Value()), 1360 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(), 1361 m_Intrinsic<Intrinsic::matrix_column_major_load>( 1362 m_Value(), m_SpecificInt(1)))))); 1363 }; 1364 // Returns the cost benefit of using \p Op with the dot product lowering. If 1365 // the returned cost is < 0, the argument is cheaper to use in the 1366 // dot-product lowering. 1367 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) { 1368 if (ShapeMap.find(Op) == ShapeMap.end()) 1369 return InstructionCost::getInvalid(); 1370 1371 if (!isa<Instruction>(Op)) 1372 return InstructionCost(0); 1373 1374 FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType()); 1375 Type *EltTy = VecTy->getElementType(); 1376 1377 if (!CanBeFlattened(Op)) { 1378 InstructionCost EmbedCost(0); 1379 // Roughly estimate the cost for embedding the columns into a vector. 1380 for (unsigned I = 1; I < N; ++I) 1381 EmbedCost += 1382 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1383 std::nullopt, TTI::TCK_RecipThroughput); 1384 return EmbedCost; 1385 } 1386 1387 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { 1388 InstructionCost OriginalCost = 1389 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(), 1390 EltTy) * 1391 N; 1392 InstructionCost NewCost = TTI.getArithmeticInstrCost( 1393 cast<Instruction>(Op)->getOpcode(), VecTy); 1394 return NewCost - OriginalCost; 1395 } 1396 1397 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) { 1398 // The transpose can be skipped for the dot product lowering, roughly 1399 // estimate the savings as the cost of embedding the columns in a 1400 // vector. 
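// The estimate is accumulated as a negative cost (a saving), since skipping
// the transpose removes the embedding shuffles instead of adding them.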
1401 InstructionCost EmbedCost(0); 1402 for (unsigned I = 1; I < N; ++I) 1403 EmbedCost -= 1404 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1405 std::nullopt, TTI::TCK_RecipThroughput); 1406 return EmbedCost; 1407 } 1408 1409 // Costs for loads. 1410 if (N == 1) 1411 return InstructionCost(0); 1412 1413 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - 1414 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); 1415 }; 1416 1417 // Iterate over LHS and operations feeding LHS and check if it is profitable 1418 // to flatten the visited ops. For each op, we compute the difference 1419 // between the flattened and matrix versions. 1420 SmallPtrSet<Value *, 4> Seen; 1421 SmallVector<Value *> WorkList; 1422 SmallVector<Value *> ToFlatten; 1423 WorkList.push_back(LHS); 1424 InstructionCost LHSCost(0); 1425 while (!WorkList.empty()) { 1426 Value *Op = WorkList.pop_back_val(); 1427 if (!Seen.insert(Op).second) 1428 continue; 1429 1430 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns); 1431 if (OpCost + LHSCost >= LHSCost) 1432 continue; 1433 1434 LHSCost += OpCost; 1435 ToFlatten.push_back(Op); 1436 if (auto *I = dyn_cast<Instruction>(Op)) 1437 WorkList.append(I->op_begin(), I->op_end()); 1438 } 1439 1440 // We compare the costs of a vector.reduce.add to sequential add. 1441 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; 1442 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; 1443 InstructionCost ReductionCost = 1444 TTI.getArithmeticReductionCost( 1445 AddOpCode, cast<VectorType>(LHS->getType()), 1446 IsIntVec ? std::nullopt : std::optional(FMF)) + 1447 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); 1448 InstructionCost SequentialAddCost = 1449 TTI.getArithmeticInstrCost(AddOpCode, ElementType) * 1450 (LShape.NumColumns - 1) + 1451 TTI.getArithmeticInstrCost(MulOpCode, ElementType) * 1452 (LShape.NumColumns); 1453 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) 1454 return; 1455 1456 FusedInsts.insert(MatMul); 1457 IRBuilder<> Builder(MatMul); 1458 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, 1459 this](Value *Op) { 1460 // Matmul must be the only user of loads because we don't use LowerLoad 1461 // for row vectors (LowerLoad results in scalar loads and shufflevectors 1462 // instead of single vector load). 1463 if (!CanBeFlattened(Op)) 1464 return; 1465 1466 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { 1467 ShapeMap[Op] = ShapeMap[Op].t(); 1468 return; 1469 } 1470 1471 FusedInsts.insert(cast<Instruction>(Op)); 1472 // If vector uses the builtin load, lower to a LoadInst 1473 Value *Arg; 1474 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( 1475 m_Value(Arg)))) { 1476 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); 1477 Op->replaceAllUsesWith(NewLoad); 1478 cast<Instruction>(Op)->eraseFromParent(); 1479 return; 1480 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( 1481 m_Value(Arg)))) { 1482 ToRemove.push_back(cast<Instruction>(Op)); 1483 Op->replaceAllUsesWith(Arg); 1484 return; 1485 } 1486 }; 1487 1488 for (auto *V : ToFlatten) 1489 FlattenArg(V); 1490 1491 LHS = MatMul->getArgOperand(0); 1492 1493 // Insert mul/fmul and llvm.vector.reduce.fadd 1494 Value *Mul = 1495 IsIntVec ? 
Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS); 1496 1497 Value *Result; 1498 if (IsIntVec) 1499 Result = Builder.CreateAddReduce(Mul); 1500 else { 1501 Result = Builder.CreateFAddReduce( 1502 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(), 1503 0.0), 1504 Mul); 1505 cast<Instruction>(Result)->setFastMathFlags(FMF); 1506 } 1507 1508 // pack scalar back into a matrix and then replace matmul inst 1509 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()), 1510 Result, uint64_t(0)); 1511 MatMul->replaceAllUsesWith(Result); 1512 FusedInsts.insert(MatMul); 1513 ToRemove.push_back(MatMul); 1514 } 1515 1516 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1517 /// addition. 1518 /// 1519 /// We can fold a transpose into the operand that is used to extract scalars. 1520 /// This is the first operands with row-major and the second with 1521 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1522 /// operand is transposed. 1523 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1524 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1525 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1526 const unsigned VF = std::max<unsigned>( 1527 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1528 .getFixedValue() / 1529 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1530 1U); 1531 unsigned R = Result.getNumRows(); 1532 unsigned C = Result.getNumColumns(); 1533 unsigned M = A.getNumColumns(); 1534 1535 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1536 assert(A.isColumnMajor() == B.isColumnMajor() && 1537 Result.isColumnMajor() == A.isColumnMajor() && 1538 "operands must agree on matrix layout"); 1539 unsigned NumComputeOps = 0; 1540 1541 Builder.setFastMathFlags(FMF); 1542 1543 if (A.isColumnMajor()) { 1544 // Multiply columns from the first operand with scalars from the second 1545 // operand. Then move along the K axes and accumulate the columns. With 1546 // this the adds can be vectorized without reassociation. 1547 for (unsigned J = 0; J < C; ++J) { 1548 unsigned BlockSize = VF; 1549 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1550 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1551 1552 for (unsigned I = 0; I < R; I += BlockSize) { 1553 // Gradually lower the vectorization factor to cover the remainder. 1554 while (I + BlockSize > R) 1555 BlockSize /= 2; 1556 1557 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1558 : nullptr; 1559 for (unsigned K = 0; K < M; ++K) { 1560 Value *L = A.extractVector(I, K, BlockSize, Builder); 1561 Value *RH = Builder.CreateExtractElement( 1562 B.getColumn(IsScalarMatrixTransposed ? K : J), 1563 IsScalarMatrixTransposed ? J : K); 1564 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1565 Sum = 1566 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1567 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1568 } 1569 Result.setVector(J, 1570 insertVector(Result.getVector(J), I, Sum, Builder)); 1571 } 1572 } 1573 } else { 1574 // Multiply rows from the second operand with scalars from the first 1575 // operand. Then move along the K axes and accumulate the rows. With this 1576 // the adds can be vectorized without reassociation. 
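// For example, for a 2x2 result this computes each row as
//   Result.row(I) = splat(A[I][0]) * B.row(0) + splat(A[I][1]) * B.row(1).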
1577 for (unsigned I = 0; I < R; ++I) {
1578 unsigned BlockSize = VF;
1579 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1580 for (unsigned J = 0; J < C; J += BlockSize) {
1581 // Gradually lower the vectorization factor to cover the remainder.
1582 while (J + BlockSize > C)
1583 BlockSize /= 2;
1584
1585 Value *Sum = nullptr;
1586 for (unsigned K = 0; K < M; ++K) {
1587 Value *R = B.extractVector(K, J, BlockSize, Builder);
1588 Value *LH = Builder.CreateExtractElement(
1589 A.getVector(IsScalarMatrixTransposed ? K : I),
1590 IsScalarMatrixTransposed ? I : K);
1591 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1592 Sum =
1593 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1594 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1595 }
1596 Result.setVector(I,
1597 insertVector(Result.getVector(I), J, Sum, Builder));
1598 }
1599 }
1600 }
1601 Result.addNumComputeOps(NumComputeOps);
1602 }
1603
1604 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1605 /// copying it to a new location. Returns the new location if a copy was
1606 /// made and the original location otherwise.
1607 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1608 CallInst *MatMul) {
1609 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1610 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1611
1612 // If we can statically determine noalias we're good.
1613 if (AA->isNoAlias(LoadLoc, StoreLoc))
1614 return Load->getPointerOperand();
1615
1616 // Create code to check if the memory locations of the Load and Store
1617 // overlap and if they do, copy Load's operand to a new buffer.
1618
1619 // First, create new blocks for the 2nd part of the check and the copy.
1620 BasicBlock *Check0 = MatMul->getParent();
1621 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1622 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1623 // as we adjust Check0 and Check1's branches.
1624 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1625 for (BasicBlock *Succ : successors(Check0))
1626 DTUpdates.push_back({DT->Delete, Check0, Succ});
1627
1628 BasicBlock *Check1 =
1629 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1630 nullptr, "alias_cont");
1631 BasicBlock *Copy =
1632 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1633 nullptr, "copy");
1634 BasicBlock *Fusion =
1635 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1636 nullptr, "no_alias");
1637
1638 // Check if the loaded memory location begins before the end of the store
1639 // location. If the condition holds, they might overlap, otherwise they are
1640 // guaranteed to not overlap.
1641 IRBuilder<> Builder(MatMul);
1642 Check0->getTerminator()->eraseFromParent();
1643 Builder.SetInsertPoint(Check0);
1644 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
1645 Value *StoreBegin = Builder.CreatePtrToInt(
1646 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1647 Value *StoreEnd = Builder.CreateAdd(
1648 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1649 "store.end", true, true);
1650 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1651 IntPtrTy, "load.begin");
1652 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1653 Fusion);
1654
1655 // Check if the store begins before the end of the load location. If the
1656 // condition holds, they alias, otherwise they are guaranteed to not
1657 // overlap.
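// Together, the two checks implement the usual interval-overlap test:
//   load.begin < store.end  &&  store.begin < load.end
// Only when both hold is the loaded memory copied to a fresh alloca; the
// phi in the "no_alias" block selects between the original pointer and the
// copy.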
1658 Check1->getTerminator()->eraseFromParent(); 1659 Builder.SetInsertPoint(Check1, Check1->begin()); 1660 Value *LoadEnd = Builder.CreateAdd( 1661 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1662 "load.end", true, true); 1663 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1664 Fusion); 1665 1666 // Copy load operand to new alloca. 1667 Builder.SetInsertPoint(Copy, Copy->begin()); 1668 auto *VT = cast<FixedVectorType>(Load->getType()); 1669 // Use an array type for the alloca, to avoid potentially huge alignment 1670 // requirements for large vector types. 1671 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1672 AllocaInst *Alloca = 1673 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1674 1675 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), 1676 Load->getAlign(), LoadLoc.Size.getValue()); 1677 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1678 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1679 PHI->addIncoming(Load->getPointerOperand(), Check0); 1680 PHI->addIncoming(Load->getPointerOperand(), Check1); 1681 PHI->addIncoming(Alloca, Copy); 1682 1683 // Adjust DT. 1684 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1685 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1686 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1687 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1688 DT->applyUpdates(DTUpdates); 1689 return PHI; 1690 } 1691 1692 bool isFusionProfitable(CallInst *MatMul) { 1693 if (ForceFusion) 1694 return true; 1695 1696 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1697 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1698 1699 const unsigned R = LShape.NumRows; 1700 const unsigned C = RShape.NumColumns; 1701 const unsigned M = LShape.NumColumns; 1702 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1703 1704 const unsigned VF = std::max<unsigned>( 1705 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1706 .getFixedValue() / 1707 EltType->getPrimitiveSizeInBits().getFixedValue(), 1708 1U); 1709 1710 // Cost model for tiling 1711 // 1712 // For tiling to be beneficial, we need reuse either along the R or 1713 // the C axis. We vectorize along the R axis so that means at least 1714 // 3 elements. 1715 // TODO: Also consider cost of copying if operands alias. 1716 if (R <= VF && C == 1) 1717 return false; 1718 // Then we need enough elements to exceed the number of vector 1719 // registers we have. Note that this is an oversimplification since 1720 // fusing also takes some extra loads which may exceed the number of 1721 // reloads necessary. 1722 unsigned Op0Regs = (R + VF - 1) / VF * M; 1723 unsigned Op1Regs = (M + VF - 1) / VF * C; 1724 return Op0Regs + Op1Regs > 1725 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1726 } 1727 1728 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1729 MatrixTy Res; 1730 auto *ColumType = FixedVectorType::get(EltType, R); 1731 for (unsigned I = 0; I < C; ++I) 1732 Res.addVector(ConstantAggregateZero::get(ColumType)); 1733 return Res; 1734 } 1735 1736 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1737 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1738 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1739 1740 // Create the main tiling loop nest. 
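// The nest produced by TileInfo is expected to iterate over
// TileSize x TileSize tiles of the result, with the K dimension as the
// innermost loop; the PHIs created below accumulate the partial products
// across the K iterations.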
1741 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1742 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1743 Instruction *InsertI = cast<Instruction>(MatMul); 1744 BasicBlock *Start = InsertI->getParent(); 1745 BasicBlock *End = 1746 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1747 IRBuilder<> Builder(MatMul); 1748 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1749 1750 Type *TileVecTy = 1751 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1752 MatrixTy TileResult; 1753 // Insert in the inner loop header. 1754 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); 1755 // Create PHI nodes for the result columns to accumulate across iterations. 1756 SmallVector<PHINode *, 4> ColumnPhis; 1757 for (unsigned I = 0; I < TileSize; I++) { 1758 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1759 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1760 TI.RowLoop.Header->getSingleSuccessor()); 1761 TileResult.addVector(Phi); 1762 ColumnPhis.push_back(Phi); 1763 } 1764 1765 // Insert in the inner loop body, which computes 1766 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1767 Builder.SetInsertPoint(InnerBody->getTerminator()); 1768 // Load tiles of the operands. 1769 MatrixTy A = 1770 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, 1771 {TileSize, TileSize}, EltType, Builder); 1772 MatrixTy B = 1773 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, 1774 {TileSize, TileSize}, EltType, Builder); 1775 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1776 getFastMathFlags(MatMul)); 1777 // Store result after the inner loop is done. 1778 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); 1779 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1780 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1781 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); 1782 1783 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1784 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); 1785 1786 // Force unrolling of a few iterations of the inner loop, to make sure there 1787 // is enough work per iteration. 1788 // FIXME: The unroller should make this decision directly instead, but 1789 // currently the cost-model is not up to the task. 
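// The count below is the number of K-tiles, capped at 10; e.g. 64 columns
// with the default tile size of 4 yields an unroll count of 10.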
1790 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1791 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header), 1792 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1793 } 1794 1795 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1796 StoreInst *Store, 1797 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1798 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1799 "Tiling only supported for column-major matrixes at the moment!"); 1800 if (!isFusionProfitable(MatMul)) 1801 return; 1802 1803 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1804 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1805 1806 const unsigned R = LShape.NumRows; 1807 const unsigned C = RShape.NumColumns; 1808 const unsigned M = LShape.NumColumns; 1809 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1810 1811 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1812 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1813 Value *CPtr = Store->getPointerOperand(); 1814 1815 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1816 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 1817 else { 1818 IRBuilder<> Builder(Store); 1819 for (unsigned J = 0; J < C; J += TileSize) 1820 for (unsigned I = 0; I < R; I += TileSize) { 1821 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1822 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1823 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1824 1825 for (unsigned K = 0; K < M; K += TileSize) { 1826 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1827 MatrixTy A = 1828 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1829 LShape, Builder.getInt64(I), Builder.getInt64(K), 1830 {TileR, TileM}, EltType, Builder); 1831 MatrixTy B = 1832 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1833 RShape, Builder.getInt64(K), Builder.getInt64(J), 1834 {TileM, TileC}, EltType, Builder); 1835 emitMatrixMultiply(Res, A, B, Builder, true, false, 1836 getFastMathFlags(MatMul)); 1837 } 1838 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1839 Builder.getInt64(I), Builder.getInt64(J), EltType, 1840 Builder); 1841 } 1842 } 1843 1844 // Mark eliminated instructions as fused and remove them. 1845 FusedInsts.insert(Store); 1846 FusedInsts.insert(MatMul); 1847 Store->eraseFromParent(); 1848 MatMul->eraseFromParent(); 1849 if (LoadOp0->hasNUses(0)) { 1850 FusedInsts.insert(LoadOp0); 1851 LoadOp0->eraseFromParent(); 1852 } 1853 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { 1854 FusedInsts.insert(LoadOp1); 1855 LoadOp1->eraseFromParent(); 1856 } 1857 } 1858 1859 /// Try to lower matrix multiply chains by fusing operations. 1860 /// 1861 /// Call finalizeLowering on lowered instructions. Instructions that are 1862 /// completely eliminated by fusion are added to \p FusedInsts. 1863 void 1864 LowerMatrixMultiplyFused(CallInst *MatMul, 1865 SmallPtrSetImpl<Instruction *> &FusedInsts, 1866 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) { 1867 if (!FuseMatrix || !DT) 1868 return; 1869 1870 assert(AA && LI && "Analyses should be available"); 1871 1872 Value *A = MatMul->getArgOperand(0); 1873 Value *B = MatMul->getArgOperand(1); 1874 1875 // We can fold the transpose into the operand that is used to fetch scalars. 1876 Value *T; 1877 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1878 ? 
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1879 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
1880 IRBuilder<> Builder(MatMul);
1881 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1882 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
1883 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
1884 const unsigned R = LShape.NumRows;
1885 const unsigned M = LShape.NumColumns;
1886 const unsigned C = RShape.NumColumns;
1887
1888 MatrixTy MA;
1889 MatrixTy MB;
1890
1891 Value *Transpose;
1892 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1893 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1894 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1895 Transpose = B;
1896 } else {
1897 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1898 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1899 Transpose = A;
1900 }
1901
1902 // Initialize the output
1903 MatrixTy Result(R, C, EltType);
1904
1905 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1906 getFastMathFlags(MatMul));
1907
1908 FusedInsts.insert(MatMul);
1909 if (Transpose->hasOneUse()) {
1910 FusedInsts.insert(cast<Instruction>(Transpose));
1911 ToRemove.push_back(cast<Instruction>(Transpose));
1912 // TODO: add a fake entry for the folded instruction so that this is
1913 // included in the expression in the remark.
1914 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1915 }
1916 finalizeLowering(MatMul, Result, Builder);
1917 return;
1918 }
1919
1920 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1921 return;
1922
1923 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1924 // since the single store user will be lowered as part of this.
1925 auto *LoadOp0 = dyn_cast<LoadInst>(A);
1926 auto *LoadOp1 = dyn_cast<LoadInst>(B);
1927 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
1928 if (LoadOp0 && LoadOp1 && Store) {
1929 // The store address must dominate the MatMul instruction, otherwise
1930 // we create invalid IR.
1931 SetVector<Value *> WorkList;
1932 WorkList.insert(Store->getOperand(1));
1933 SmallVector<Instruction *> ToHoist;
1934 for (unsigned I = 0; I != WorkList.size(); ++I) {
1935 Value *Current = WorkList[I];
1936 auto *CurrI = dyn_cast<Instruction>(Current);
1937 if (!CurrI)
1938 continue;
1939 if (isa<PHINode>(CurrI))
1940 return;
1941 if (DT->dominates(CurrI, MatMul))
1942 continue;
1943 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1944 return;
1945 ToHoist.push_back(CurrI);
1946 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1947 }
1948
1949 sort(ToHoist, [this](Instruction *A, Instruction *B) {
1950 return DT->dominates(A, B);
1951 });
1952 for (Instruction *I : ToHoist)
1953 I->moveBefore(MatMul);
1954
1955 // Deal with lifetime.end calls that might be between Load0/Load1 and the
1956 // store. To avoid introducing loads to dead objects (i.e. after the
1957 // lifetime has been terminated by @llvm.lifetime.end), either sink them
1958 // after the store if in the same block, or remove the lifetime.end marker
1959 // otherwise. This might pessimize further optimizations, by extending the
1960 // lifetime of the object until the function returns, but should be
1961 // conservatively correct.
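// Concretely, each lifetime.end below is either left in place (it provably
// does not interfere), moved directly after the store when it is in the
// same block, or removed altogether.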
1962 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0); 1963 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1); 1964 BasicBlock *StoreParent = Store->getParent(); 1965 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent && 1966 LoadOp1->getParent() == StoreParent; 1967 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) { 1968 IntrinsicInst *End = LifetimeEnds[Idx]; 1969 auto Inc = make_scope_exit([&Idx]() { Idx++; }); 1970 // If the lifetime.end is guaranteed to be before the loads or after the 1971 // store, it won't interfere with fusion. 1972 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1)) 1973 continue; 1974 if (DT->dominates(Store, End)) 1975 continue; 1976 // If all fusable ops are in the same block and the lifetime.end is in a 1977 // different block, it won't interfere with fusion. 1978 if (FusableOpsInSameBlock && End->getParent() != StoreParent) 1979 continue; 1980 1981 // If the loads don't alias the lifetime.end, it won't interfere with 1982 // fusion. 1983 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr); 1984 if (!EndLoc.Ptr) 1985 continue; 1986 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc)) 1987 continue; 1988 1989 // If both lifetime.end and the store are in the same block, extend the 1990 // lifetime until after the store, so the new lifetime covers the loads 1991 // we introduce later. 1992 if (End->getParent() == StoreParent) { 1993 End->moveAfter(Store); 1994 continue; 1995 } 1996 1997 // Otherwise remove the conflicting lifetime.end marker. 1998 ToRemove.push_back(End); 1999 std::swap(LifetimeEnds[Idx], LifetimeEnds.back()); 2000 LifetimeEnds.pop_back(); 2001 Inc.release(); 2002 } 2003 2004 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 2005 return; 2006 } 2007 } 2008 2009 /// Lowers llvm.matrix.multiply. 2010 void LowerMultiply(CallInst *MatMul) { 2011 IRBuilder<> Builder(MatMul); 2012 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 2013 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 2014 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 2015 2016 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 2017 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 2018 assert(Lhs.getElementType() == Rhs.getElementType() && 2019 "Matrix multiply argument element types do not match."); 2020 2021 const unsigned R = LShape.NumRows; 2022 const unsigned C = RShape.NumColumns; 2023 assert(LShape.NumColumns == RShape.NumRows); 2024 2025 // Initialize the output 2026 MatrixTy Result(R, C, EltType); 2027 assert(Lhs.getElementType() == Result.getElementType() && 2028 "Matrix multiply result element type does not match arguments."); 2029 2030 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 2031 getFastMathFlags(MatMul)); 2032 finalizeLowering(MatMul, Result, Builder); 2033 } 2034 2035 /// Lowers llvm.matrix.transpose. 2036 void LowerTranspose(CallInst *Inst) { 2037 MatrixTy Result; 2038 IRBuilder<> Builder(Inst); 2039 Value *InputVal = Inst->getArgOperand(0); 2040 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 2041 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 2042 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 2043 2044 const unsigned NewNumVecs = 2045 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 2046 const unsigned NewNumElts = 2047 InputMatrix.isColumnMajor() ? 
ArgShape.NumColumns : ArgShape.NumRows; 2048 2049 for (unsigned I = 0; I < NewNumVecs; ++I) { 2050 // Build a single result vector. First initialize it. 2051 Value *ResultVector = PoisonValue::get( 2052 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 2053 // Go through the old elements and insert it into the resulting vector. 2054 for (auto J : enumerate(InputMatrix.vectors())) { 2055 Value *Elt = Builder.CreateExtractElement(J.value(), I); 2056 // Row and column indices are transposed. 2057 ResultVector = 2058 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 2059 } 2060 Result.addVector(ResultVector); 2061 } 2062 2063 // TODO: Improve estimate of operations needed for transposes. Currently we 2064 // just count the insertelement/extractelement instructions, but do not 2065 // account for later simplifications/combines. 2066 finalizeLowering( 2067 Inst, 2068 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 2069 .addNumExposedTransposes(1), 2070 Builder); 2071 } 2072 2073 /// Lower load instructions, if shape information is available. 2074 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 2075 auto I = ShapeMap.find(Inst); 2076 if (I == ShapeMap.end()) 2077 return false; 2078 2079 LowerLoad(Inst, Ptr, Inst->getAlign(), 2080 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 2081 I->second); 2082 return true; 2083 } 2084 2085 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 2086 IRBuilder<> &Builder) { 2087 auto I = ShapeMap.find(StoredVal); 2088 if (I == ShapeMap.end()) 2089 return false; 2090 2091 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 2092 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 2093 I->second); 2094 return true; 2095 } 2096 2097 /// Lower binary operators, if shape information is available. 2098 bool VisitBinaryOperator(BinaryOperator *Inst) { 2099 auto I = ShapeMap.find(Inst); 2100 if (I == ShapeMap.end()) 2101 return false; 2102 2103 Value *Lhs = Inst->getOperand(0); 2104 Value *Rhs = Inst->getOperand(1); 2105 2106 IRBuilder<> Builder(Inst); 2107 ShapeInfo &Shape = I->second; 2108 2109 MatrixTy Result; 2110 MatrixTy A = getMatrix(Lhs, Shape, Builder); 2111 MatrixTy B = getMatrix(Rhs, Shape, Builder); 2112 assert(A.isColumnMajor() == B.isColumnMajor() && 2113 Result.isColumnMajor() == A.isColumnMajor() && 2114 "operands must agree on matrix layout"); 2115 2116 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2117 2118 // Helper to perform binary op on vectors. 2119 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 2120 switch (Inst->getOpcode()) { 2121 case Instruction::Add: 2122 return Builder.CreateAdd(LHS, RHS); 2123 case Instruction::Mul: 2124 return Builder.CreateMul(LHS, RHS); 2125 case Instruction::Sub: 2126 return Builder.CreateSub(LHS, RHS); 2127 case Instruction::FAdd: 2128 return Builder.CreateFAdd(LHS, RHS); 2129 case Instruction::FMul: 2130 return Builder.CreateFMul(LHS, RHS); 2131 case Instruction::FSub: 2132 return Builder.CreateFSub(LHS, RHS); 2133 default: 2134 llvm_unreachable("Unsupported binary operator for matrix"); 2135 } 2136 }; 2137 2138 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2139 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 2140 2141 finalizeLowering(Inst, 2142 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2143 Result.getNumVectors()), 2144 Builder); 2145 return true; 2146 } 2147 2148 /// Lower unary operators, if shape information is available. 
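/// Currently FNeg is the only unary operator that is lowered this way.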
2149 bool VisitUnaryOperator(UnaryOperator *Inst) { 2150 auto I = ShapeMap.find(Inst); 2151 if (I == ShapeMap.end()) 2152 return false; 2153 2154 Value *Op = Inst->getOperand(0); 2155 2156 IRBuilder<> Builder(Inst); 2157 ShapeInfo &Shape = I->second; 2158 2159 MatrixTy Result; 2160 MatrixTy M = getMatrix(Op, Shape, Builder); 2161 2162 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2163 2164 // Helper to perform unary op on vectors. 2165 auto BuildVectorOp = [&Builder, Inst](Value *Op) { 2166 switch (Inst->getOpcode()) { 2167 case Instruction::FNeg: 2168 return Builder.CreateFNeg(Op); 2169 default: 2170 llvm_unreachable("Unsupported unary operator for matrix"); 2171 } 2172 }; 2173 2174 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2175 Result.addVector(BuildVectorOp(M.getVector(I))); 2176 2177 finalizeLowering(Inst, 2178 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2179 Result.getNumVectors()), 2180 Builder); 2181 return true; 2182 } 2183 2184 /// Helper to linearize a matrix expression tree into a string. Currently 2185 /// matrix expressions are linarized by starting at an expression leaf and 2186 /// linearizing bottom up. 2187 struct ExprLinearizer { 2188 unsigned LengthToBreak = 100; 2189 std::string Str; 2190 raw_string_ostream Stream; 2191 unsigned LineLength = 0; 2192 const DataLayout &DL; 2193 2194 /// Mapping from instructions to matrixes. It is used to identify 2195 /// matrix instructions. 2196 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2197 2198 /// Mapping from values to the leaves of all expressions that the value is 2199 /// part of. 2200 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 2201 2202 /// Set of matrix expressions in the scope of a given DISubprogram. 2203 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 2204 2205 /// Leaf node of the expression to linearize. 2206 Value *Leaf; 2207 2208 /// Used to keep track of sub-expressions that get reused while linearizing 2209 /// the expression. Re-used sub-expressions are marked as (reused). 2210 SmallPtrSet<Value *, 8> ReusedExprs; 2211 2212 ExprLinearizer(const DataLayout &DL, 2213 const MapVector<Value *, MatrixTy> &Inst2Matrix, 2214 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2215 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2216 Value *Leaf) 2217 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 2218 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 2219 2220 void indent(unsigned N) { 2221 LineLength += N; 2222 for (unsigned i = 0; i < N; i++) 2223 Stream << " "; 2224 } 2225 2226 void lineBreak() { 2227 Stream << "\n"; 2228 LineLength = 0; 2229 } 2230 2231 void maybeIndent(unsigned Indent) { 2232 if (LineLength >= LengthToBreak) 2233 lineBreak(); 2234 2235 if (LineLength == 0) 2236 indent(Indent); 2237 } 2238 2239 void write(StringRef S) { 2240 LineLength += S.size(); 2241 Stream << S; 2242 } 2243 2244 Value *getUnderlyingObjectThroughLoads(Value *V) { 2245 if (Value *Ptr = getPointerOperand(V)) 2246 return getUnderlyingObjectThroughLoads(Ptr); 2247 else if (V->getType()->isPointerTy()) 2248 return getUnderlyingObject(V); 2249 return V; 2250 } 2251 2252 /// Returns true if \p V is a matrix value in the given subprogram. 2253 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 2254 2255 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 2256 /// \p SS. 
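/// For example, a matrix with 4 rows and 2 columns is printed as "4x2";
/// values without shape information are printed as "unknown".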
2257 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 2258 auto M = Inst2Matrix.find(V); 2259 if (M == Inst2Matrix.end()) 2260 SS << "unknown"; 2261 else { 2262 SS << M->second.getNumRows(); 2263 SS << "x"; 2264 SS << M->second.getNumColumns(); 2265 } 2266 } 2267 2268 /// Write the called function name. Handles calls to llvm.matrix.* 2269 /// specially: we write the name, followed by the dimensions of the input 2270 /// matrixes, followed by the scalar type name. 2271 void writeFnName(CallInst *CI) { 2272 if (!CI->getCalledFunction()) 2273 write("<no called fn>"); 2274 else { 2275 StringRef Name = CI->getCalledFunction()->getName(); 2276 if (!Name.starts_with("llvm.matrix")) { 2277 write(Name); 2278 return; 2279 } 2280 auto *II = cast<IntrinsicInst>(CI); 2281 write(Intrinsic::getBaseName(II->getIntrinsicID()) 2282 .drop_front(StringRef("llvm.matrix.").size())); 2283 write("."); 2284 std::string Tmp; 2285 raw_string_ostream SS(Tmp); 2286 2287 switch (II->getIntrinsicID()) { 2288 case Intrinsic::matrix_multiply: 2289 prettyPrintMatrixType(II->getOperand(0), SS); 2290 SS << "."; 2291 prettyPrintMatrixType(II->getOperand(1), SS); 2292 SS << "." << *II->getType()->getScalarType(); 2293 break; 2294 case Intrinsic::matrix_transpose: 2295 prettyPrintMatrixType(II->getOperand(0), SS); 2296 SS << "." << *II->getType()->getScalarType(); 2297 break; 2298 case Intrinsic::matrix_column_major_load: 2299 prettyPrintMatrixType(II, SS); 2300 SS << "." << *II->getType()->getScalarType(); 2301 break; 2302 case Intrinsic::matrix_column_major_store: 2303 prettyPrintMatrixType(II->getOperand(0), SS); 2304 SS << "." << *II->getOperand(0)->getType()->getScalarType(); 2305 break; 2306 default: 2307 llvm_unreachable("Unhandled case"); 2308 } 2309 SS.flush(); 2310 write(Tmp); 2311 } 2312 } 2313 2314 unsigned getNumShapeArgs(CallInst *CI) const { 2315 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 2316 switch (II->getIntrinsicID()) { 2317 case Intrinsic::matrix_multiply: 2318 return 3; 2319 case Intrinsic::matrix_transpose: 2320 return 2; 2321 case Intrinsic::matrix_column_major_load: 2322 case Intrinsic::matrix_column_major_store: 2323 return 3; 2324 default: 2325 return 0; 2326 } 2327 } 2328 return 0; 2329 } 2330 2331 /// Special printing for values: for pointers, we print if they refer to an 2332 /// (function) external address or a stack address, for other values we 2333 /// either print the constant or "scalar"/"matrix" for other values. 2334 void write(Value *V) { 2335 V = getUnderlyingObjectThroughLoads(V); 2336 if (V->getType()->isPointerTy()) { 2337 if (isa<AllocaInst>(V)) { 2338 Stream << "stack addr"; 2339 LineLength += StringRef("stack addr").size(); 2340 } else { 2341 Stream << "addr"; 2342 LineLength += StringRef("addr").size(); 2343 } 2344 if (!V->getName().empty()) { 2345 Stream << " %" << V->getName() << ""; 2346 LineLength += V->getName().size() + 2; 2347 } 2348 return; 2349 } 2350 2351 std::string Tmp; 2352 raw_string_ostream TmpStream(Tmp); 2353 2354 if (auto *CI = dyn_cast<ConstantInt>(V)) 2355 TmpStream << CI->getValue(); 2356 else if (isa<Constant>(V)) 2357 TmpStream << "constant"; 2358 else { 2359 if (isMatrix(V)) 2360 TmpStream << "matrix"; 2361 else 2362 TmpStream << "scalar"; 2363 } 2364 TmpStream.flush(); 2365 Tmp = std::string(StringRef(Tmp).trim()); 2366 LineLength += Tmp.size(); 2367 Stream << Tmp; 2368 } 2369 2370 /// Linearize expression \p Expr starting at an indentation of \p Indent. 
/// Expressions that are re-used multiple times are prefixed with (reused)
2372 /// at the re-used root instruction.
2373 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2374 bool ParentShared) {
2375 auto *I = cast<Instruction>(Expr);
2376 maybeIndent(Indent);
2377 SmallVector<Value *, 8> Ops;
2378
2379 // Is Expr shared with other expression leaves?
2380 bool ExprShared = false;
2381
2382 // Deal with shared subtrees. Mark them as shared, if required.
2383 if (!ParentShared) {
2384 auto SI = Shared.find(Expr);
2385 assert(SI != Shared.end() && SI->second.count(Leaf));
2386
2387 for (Value *S : SI->second) {
2388 if (S == Leaf)
2389 continue;
2390 DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
2391 write("shared with remark at line " + std::to_string(DL.getLine()) +
2392 " column " + std::to_string(DL.getCol()) + " (");
2393 }
2394 ExprShared = SI->second.size() > 1;
2395 }
2396
2397 bool Reused = !ReusedExprs.insert(Expr).second;
2398 if (Reused && !ParentReused)
2399 write("(reused) ");
2400
2401 if (auto *CI = dyn_cast<CallInst>(I)) {
2402 writeFnName(CI);
2403
2404 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
2405 } else if (isa<BitCastInst>(Expr)) {
2406 // Special case bitcasts, which are used to materialize matrixes from
2407 // non-matrix ops.
2408 write("matrix");
2409 return;
2410 } else {
2411 Ops.append(I->value_op_begin(), I->value_op_end());
2412 write(std::string(I->getOpcodeName()));
2413 }
2414
2415 write(std::string("("));
2416
2417 unsigned NumOpsToBreak = 1;
2418 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2419 NumOpsToBreak = 2;
2420
2421 for (Value *Op : Ops) {
2422 if (Ops.size() > NumOpsToBreak)
2423 lineBreak();
2424
2425 maybeIndent(Indent + 1);
2426 if (isMatrix(Op))
2427 linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2428 else
2429 write(Op);
2430 if (Op != Ops.back())
2431 write(", ");
2432 }
2433
2434 write(")");
2435 }
2436
2437 const std::string &getResult() {
2438 Stream.flush();
2439 return Str;
2440 }
2441 };
2442
2443 /// Generate remarks for matrix operations in a function. To generate remarks
2444 /// for matrix expressions, the following approach is used:
2445 /// 1. Use the inlined-at debug information to group matrix operations to the
2446 /// DISubprograms they are contained in.
2447 /// 2. Collect leaves of matrix expressions (done in
2448 /// RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
2449 /// mapping. Leaves are lowered matrix instructions without other matrix
2450 /// users (like stores) in the current subprogram.
2451 /// 3. For each leaf, create a remark containing a linearized version of the
2452 /// matrix expression. The expression is linearized by a recursive
2453 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2454 /// that multiple leaves can share sub-expressions. Shared subexpressions
2455 /// are explicitly marked as shared().
2456 struct RemarkGenerator {
2457 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2458 OptimizationRemarkEmitter &ORE;
2459 Function &Func;
2460 const DataLayout &DL;
2461
2462 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2463 OptimizationRemarkEmitter &ORE, Function &Func)
2464 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2465 DL(Func.getDataLayout()) {}
2466
2467 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2468 /// instructions in Inst2Matrix returning void or without any users in
2469 /// \p ExprsInSubprogram. Currently that should only include stores.
2470 SmallVector<Value *, 4>
2471 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2472 SmallVector<Value *, 4> Leaves;
2473 for (auto *Expr : ExprsInSubprogram)
2474 if (Expr->getType()->isVoidTy() ||
2475 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2476 return ExprsInSubprogram.count(U);
2477 }))
2478 Leaves.push_back(Expr);
2479 return Leaves;
2480 }
2481
2482 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2483 /// to all visited expressions in \p Shared. Limit the matrix operations to
2484 /// the ones in \p ExprsInSubprogram.
2485 void collectSharedInfo(Value *Leaf, Value *V,
2486 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2487 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2488
2489 if (!ExprsInSubprogram.count(V))
2490 return;
2491
2492 auto I = Shared.insert({V, {}});
2493 I.first->second.insert(Leaf);
2494
2495 for (Value *Op : cast<Instruction>(V)->operand_values())
2496 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2497 }
2498
2499 /// Calculate the exclusive and shared op counts for the expression starting
2500 /// at \p Root. Expressions used multiple times are counted once.
2501 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2502 std::pair<OpInfoTy, OpInfoTy>
2503 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2504 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2505 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2506 if (!ExprsInSubprogram.count(Root))
2507 return {};
2508
2509 // Already counted this expression. Stop.
2510 if (!ReusedExprs.insert(Root).second)
2511 return {};
2512
2513 OpInfoTy SharedCount;
2514 OpInfoTy Count;
2515
2516 auto I = Shared.find(Root);
2517 auto CM = Inst2Matrix.find(Root);
2518 if (I->second.size() == 1)
2519 Count = CM->second.getOpInfo();
2520 else
2521 SharedCount = CM->second.getOpInfo();
2522
2523 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2524 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2525 Count += C.first;
2526 SharedCount += C.second;
2527 }
2528 return {Count, SharedCount};
2529 }
2530
2531 void emitRemarks() {
2532 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2533 return;
2534
2535 // Map matrix operations to their containing subprograms, by traversing
2536 // the inlinedAt chain. If the function does not have a DISubprogram, we
2537 // only map them to the containing function.
2538 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2539 for (const auto &KV : Inst2Matrix) {
2540 if (Func.getSubprogram()) {
2541 auto *I = cast<Instruction>(KV.first);
2542 DILocation *Context = I->getDebugLoc();
2543 while (Context) {
2544 auto I =
2545 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2546 I.first->second.push_back(KV.first);
2547 Context = DebugLoc(Context).getInlinedAt();
2548 }
2549 } else {
2550 auto I = Subprog2Exprs.insert({nullptr, {}});
2551 I.first->second.push_back(KV.first);
2552 }
2553 }
2554 for (auto &KV : Subprog2Exprs) {
2555 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2556 KV.second.end());
2557 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2558
2559 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2560 for (Value *Leaf : Leaves)
2561 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2562
2563 // Generate remarks for each leaf.
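// Each remark is placed at the leaf's location within the current
// subprogram and reports the exclusive and shared op counts together with
// a linearized form of the expression.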
2564 for (auto *L : Leaves) { 2565 2566 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 2567 DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 2568 while (Context) { 2569 if (getSubprogram(Context->getScope()) == KV.first) { 2570 Loc = Context; 2571 break; 2572 } 2573 Context = DebugLoc(Context).getInlinedAt(); 2574 } 2575 2576 SmallPtrSet<Value *, 8> ReusedExprs; 2577 OpInfoTy Counts, SharedCounts; 2578 std::tie(Counts, SharedCounts) = 2579 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 2580 2581 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 2582 cast<Instruction>(L)->getParent()); 2583 2584 Rem << "Lowered with "; 2585 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 2586 << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 2587 << ore::NV("NumComputeOps", Counts.NumComputeOps) 2588 << " compute ops, " 2589 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2590 << " exposed transposes"; 2591 2592 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 2593 SharedCounts.NumComputeOps > 0) { 2594 Rem << ",\nadditionally " 2595 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 2596 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 2597 << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 2598 << " compute ops" 2599 << " are shared with other expressions"; 2600 } 2601 2602 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 2603 ORE.emit(Rem); 2604 } 2605 } 2606 } 2607 2608 std::string 2609 linearize(Value *L, 2610 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2611 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2612 const DataLayout &DL) { 2613 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 2614 Lin.linearizeExpr(L, 0, false, false); 2615 return Lin.getResult(); 2616 } 2617 }; 2618 }; 2619 } // namespace 2620 2621 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2622 FunctionAnalysisManager &AM) { 2623 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2624 OptimizationRemarkEmitter *ORE = nullptr; 2625 AAResults *AA = nullptr; 2626 DominatorTree *DT = nullptr; 2627 LoopInfo *LI = nullptr; 2628 2629 if (!Minimal) { 2630 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 2631 AA = &AM.getResult<AAManager>(F); 2632 DT = &AM.getResult<DominatorTreeAnalysis>(F); 2633 LI = &AM.getResult<LoopAnalysis>(F); 2634 } 2635 2636 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 2637 if (LMT.Visit()) { 2638 PreservedAnalyses PA; 2639 if (!Minimal) { 2640 PA.preserve<LoopAnalysis>(); 2641 PA.preserve<DominatorTreeAnalysis>(); 2642 } 2643 return PA; 2644 } 2645 return PreservedAnalyses::all(); 2646 } 2647 2648 void LowerMatrixIntrinsicsPass::printPipeline( 2649 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2650 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2651 OS, MapClassName2PassName); 2652 OS << '<'; 2653 if (Minimal) 2654 OS << "minimal"; 2655 OS << '>'; 2656 } 2657