//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different numbers of rows/columns
//     for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "lead to different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

/// Helper function to either return \p Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
160 // -> returns Base + (1 * 4) 161 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 162 // -> returns Base + (2 * 4) 163 // 164 // The graphic below illustrates the number of elements in a column (marked 165 // with |) and the number of skipped elements (marked with }). 166 // 167 // v_0_0 v_0_1 {v_0_2 {v_0_3 168 // Base Col 1 Col 2 169 // | | | 170 // v_1_0 |v_1_1 |v_1_2 |v_1_3 171 // v_2_0 |v_2_1 |v_2_2 |v_2_3 172 // v_3_0 {v_3_1 {v_3_2 v_3_3 173 // 174 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 175 unsigned NumElements, Type *EltType, 176 IRBuilder<> &Builder) { 177 178 assert((!isa<ConstantInt>(Stride) || 179 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 180 "Stride must be >= the number of elements in the result vector."); 181 182 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 183 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 184 185 // Get pointer to the start of the selected vector. Skip GEP creation, 186 // if we select vector 0. 187 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 188 VecStart = BasePtr; 189 else 190 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 191 192 return VecStart; 193 } 194 195 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 196 /// 197 /// Currently, the lowering for each matrix intrinsic is done as follows: 198 /// 1. Propagate the shape information from intrinsics to connected 199 /// instructions. 200 /// 2. Lower instructions with shape information (assuming column-major layout). 201 /// The lowering works similarly using row-major layout. 202 /// 2.1. Get column vectors for each argument. If we already lowered the 203 /// definition of an argument, use the produced column vectors directly. 204 /// If not, split the operand vector containing an embedded matrix into 205 /// a set of column vectors, 206 /// 2.2. Lower the instruction in terms of column major operations, which 207 /// yields a set of column vectors containing result matrix. Note that we 208 /// lower all instructions that have shape information. Besides the 209 /// intrinsics, this includes stores for example. 210 /// 2.3. Update uses of the lowered instruction. If we have shape information 211 /// for a user, there is nothing to do, as we will look up the result 212 /// column matrix when lowering the user. For other uses, we embed the 213 /// result matrix in a flat vector and update the use. 214 /// 2.4. Cache the result column matrix for the instruction we lowered 215 /// 3. After we lowered all instructions in a function, remove the now 216 /// obsolete instructions. 217 /// 218 class LowerMatrixIntrinsics { 219 Function &Func; 220 const DataLayout &DL; 221 const TargetTransformInfo &TTI; 222 AliasAnalysis *AA; 223 DominatorTree *DT; 224 LoopInfo *LI; 225 OptimizationRemarkEmitter *ORE; 226 227 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 228 struct OpInfoTy { 229 /// Number of stores emitted to generate this matrix. 230 unsigned NumStores = 0; 231 /// Number of loads emitted to generate this matrix. 232 unsigned NumLoads = 0; 233 /// Number of compute operations emitted to generate this matrix. 234 unsigned NumComputeOps = 0; 235 /// Most of the time transposes can be fused with matrix multiplies or can 236 /// be folded away via algebraic simplifications. 
This is the number of 237 /// transposes that we failed to make "free" via such optimizations. 238 unsigned NumExposedTransposes = 0; 239 240 OpInfoTy &operator+=(const OpInfoTy &RHS) { 241 NumStores += RHS.NumStores; 242 NumLoads += RHS.NumLoads; 243 NumComputeOps += RHS.NumComputeOps; 244 NumExposedTransposes += RHS.NumExposedTransposes; 245 return *this; 246 } 247 }; 248 249 /// Wrapper class representing a matrix as a set of vectors, either in row or 250 /// column major layout. All vectors must have the same vector type. 251 class MatrixTy { 252 SmallVector<Value *, 16> Vectors; 253 254 OpInfoTy OpInfo; 255 256 bool IsColumnMajor = true; 257 258 public: 259 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 260 MatrixTy(ArrayRef<Value *> Vectors) 261 : Vectors(Vectors.begin(), Vectors.end()), 262 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 263 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 264 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 265 266 unsigned D = isColumnMajor() ? NumColumns : NumRows; 267 for (unsigned J = 0; J < D; ++J) 268 addVector(PoisonValue::get(FixedVectorType::get( 269 EltTy, isColumnMajor() ? NumRows : NumColumns))); 270 } 271 272 Value *getVector(unsigned i) const { return Vectors[i]; } 273 Value *getColumn(unsigned i) const { 274 assert(isColumnMajor() && "only supported for column-major matrixes"); 275 return Vectors[i]; 276 } 277 Value *getRow(unsigned i) const { 278 assert(!isColumnMajor() && "only supported for row-major matrixes"); 279 return Vectors[i]; 280 } 281 282 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 283 284 Type *getElementType() const { return getVectorTy()->getElementType(); } 285 286 unsigned getNumVectors() const { 287 if (isColumnMajor()) 288 return getNumColumns(); 289 return getNumRows(); 290 } 291 292 unsigned getNumColumns() const { 293 if (isColumnMajor()) 294 return Vectors.size(); 295 else { 296 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 297 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 298 } 299 } 300 unsigned getNumRows() const { 301 if (isColumnMajor()) { 302 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 303 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 304 } else 305 return Vectors.size(); 306 } 307 308 void addVector(Value *V) { Vectors.push_back(V); } 309 VectorType *getColumnTy() { 310 assert(isColumnMajor() && "only supported for column-major matrixes"); 311 return getVectorTy(); 312 } 313 314 VectorType *getVectorTy() const { 315 return cast<VectorType>(Vectors[0]->getType()); 316 } 317 318 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 319 assert(isColumnMajor() && 320 "columns() only supported for column-major matrixes"); 321 return make_range(Vectors.begin(), Vectors.end()); 322 } 323 324 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 325 return make_range(Vectors.begin(), Vectors.end()); 326 } 327 328 /// Embed the vectors of the matrix into a flat vector by concatenating 329 /// them. 330 Value *embedInVector(IRBuilder<> &Builder) const { 331 return Vectors.size() == 1 ? 
Vectors[0] 332 : concatenateVectors(Builder, Vectors); 333 } 334 335 MatrixTy &addNumLoads(unsigned N) { 336 OpInfo.NumLoads += N; 337 return *this; 338 } 339 340 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 341 342 MatrixTy &addNumStores(unsigned N) { 343 OpInfo.NumStores += N; 344 return *this; 345 } 346 347 MatrixTy &addNumExposedTransposes(unsigned N) { 348 OpInfo.NumExposedTransposes += N; 349 return *this; 350 } 351 352 MatrixTy &addNumComputeOps(unsigned N) { 353 OpInfo.NumComputeOps += N; 354 return *this; 355 } 356 357 unsigned getNumStores() const { return OpInfo.NumStores; } 358 unsigned getNumLoads() const { return OpInfo.NumLoads; } 359 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 360 361 const OpInfoTy &getOpInfo() const { return OpInfo; } 362 363 bool isColumnMajor() const { return IsColumnMajor; } 364 365 unsigned getStride() const { 366 if (isColumnMajor()) 367 return getNumRows(); 368 return getNumColumns(); 369 } 370 371 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 372 /// matrix is column-major, the result vector is extracted from a column 373 /// vector, otherwise from a row vector. 374 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 375 IRBuilder<> &Builder) const { 376 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 377 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 378 NumElts && 379 "Extracted vector will contain poison values"); 380 return Builder.CreateShuffleVector( 381 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 382 "block"); 383 } 384 }; 385 386 struct ShapeInfo { 387 unsigned NumRows; 388 unsigned NumColumns; 389 390 bool IsColumnMajor; 391 392 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 393 : NumRows(NumRows), NumColumns(NumColumns), 394 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 395 396 ShapeInfo(Value *NumRows, Value *NumColumns) 397 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 398 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 399 400 bool operator==(const ShapeInfo &other) { 401 return NumRows == other.NumRows && NumColumns == other.NumColumns; 402 } 403 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 404 405 /// Returns true if shape-information is defined, meaning both dimensions 406 /// are != 0. 407 operator bool() const { 408 assert(NumRows == 0 || NumColumns != 0); 409 return NumRows != 0; 410 } 411 412 unsigned getStride() const { 413 if (IsColumnMajor) 414 return NumRows; 415 return NumColumns; 416 } 417 418 unsigned getNumVectors() const { 419 if (IsColumnMajor) 420 return NumColumns; 421 return NumRows; 422 } 423 424 /// Returns the transposed shape. 425 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } 426 }; 427 428 /// Maps instructions to their shape information. The shape information 429 /// describes the shape to be used while lowering. This matches the shape of 430 /// the result value of the instruction, with the only exceptions being store 431 /// instructions and the matrix_column_major_store intrinsics. For those, the 432 /// shape information indicates that those instructions should be lowered 433 /// using shape information as well. A ValueMap is used so that when 434 /// sub-passes like optimizeTransposes performs RAUW the map stays 435 /// up-to-date. 436 ValueMap<Value *, ShapeInfo> ShapeMap; 437 438 /// List of instructions to remove. 
While lowering, we are not replacing all 439 /// users of a lowered instruction, if shape information is available and 440 /// those need to be removed after we finished lowering. 441 SmallVector<Instruction *, 16> ToRemove; 442 443 /// Map from instructions to their produced column matrix. 444 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 445 446 private: 447 static FastMathFlags getFastMathFlags(Instruction *Inst) { 448 FastMathFlags FMF; 449 450 if (isa<FPMathOperator>(*Inst)) 451 FMF = Inst->getFastMathFlags(); 452 453 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 454 455 return FMF; 456 } 457 458 public: 459 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 460 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 461 OptimizationRemarkEmitter *ORE) 462 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 463 LI(LI), ORE(ORE) {} 464 465 unsigned getNumOps(Type *VT) { 466 assert(isa<VectorType>(VT) && "Expected vector type"); 467 return getNumOps(VT->getScalarType(), 468 cast<FixedVectorType>(VT)->getNumElements()); 469 } 470 471 /// Is this the minimal version executed in the backend pipelines. 472 bool isMinimal() const { 473 return !DT; 474 } 475 476 /// Return the estimated number of vector ops required for an operation on 477 /// \p VT * N. 478 unsigned getNumOps(Type *ST, unsigned N) { 479 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / 480 double(TTI.getRegisterBitWidth( 481 TargetTransformInfo::RGK_FixedWidthVector) 482 .getFixedValue())); 483 } 484 485 /// Return the set of vectors that a matrix value is lowered to. 486 /// 487 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 488 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 489 /// into vectors. 490 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 491 IRBuilder<> &Builder) { 492 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 493 assert(VType && "MatrixVal must be a vector type"); 494 assert(cast<FixedVectorType>(VType)->getNumElements() == 495 SI.NumRows * SI.NumColumns && 496 "The vector size must match the number of matrix elements"); 497 498 // Check if we lowered MatrixVal using shape information. In that case, 499 // return the existing matrix, if it matches the requested shape 500 // information. If there is a mis-match, embed the result in a flat 501 // vector and split it later. 502 auto Found = Inst2ColumnMatrix.find(MatrixVal); 503 if (Found != Inst2ColumnMatrix.end()) { 504 MatrixTy &M = Found->second; 505 // Return the found matrix, if its shape matches the requested shape 506 // information 507 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 508 return M; 509 510 MatrixVal = M.embedInVector(Builder); 511 } 512 513 // Otherwise split MatrixVal. 514 SmallVector<Value *, 16> SplitVecs; 515 for (unsigned MaskStart = 0; 516 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 517 MaskStart += SI.getStride()) { 518 Value *V = Builder.CreateShuffleVector( 519 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 520 "split"); 521 SplitVecs.push_back(V); 522 } 523 524 return {SplitVecs}; 525 } 526 527 /// If \p V already has a known shape return false. Otherwise set the shape 528 /// for instructions that support it. 
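  /// Returns true if new shape information was recorded for \p V.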
529 bool setShapeInfo(Value *V, ShapeInfo Shape) { 530 assert(Shape && "Shape not set"); 531 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 532 return false; 533 534 auto SIter = ShapeMap.find(V); 535 if (SIter != ShapeMap.end()) { 536 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || 537 SIter->second.NumColumns != Shape.NumColumns)) { 538 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" 539 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" 540 << Shape.NumColumns << ") for " << *V << "\n"; 541 report_fatal_error( 542 "Matrix shape verification failed, compilation aborted!"); 543 } 544 545 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 546 << SIter->second.NumRows << " " 547 << SIter->second.NumColumns << " for " << *V << "\n"); 548 return false; 549 } 550 551 ShapeMap.insert({V, Shape}); 552 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 553 << " for " << *V << "\n"); 554 return true; 555 } 556 557 bool isUniformShape(Value *V) { 558 Instruction *I = dyn_cast<Instruction>(V); 559 if (!I) 560 return true; 561 562 switch (I->getOpcode()) { 563 case Instruction::FAdd: 564 case Instruction::FSub: 565 case Instruction::FMul: // Scalar multiply. 566 case Instruction::FNeg: 567 case Instruction::Add: 568 case Instruction::Mul: 569 case Instruction::Sub: 570 return true; 571 default: 572 return false; 573 } 574 } 575 576 /// Returns true if shape information can be used for \p V. The supported 577 /// instructions must match the instructions that can be lowered by this pass. 578 bool supportsShapeInfo(Value *V) { 579 Instruction *Inst = dyn_cast<Instruction>(V); 580 if (!Inst) 581 return false; 582 583 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 584 if (II) 585 switch (II->getIntrinsicID()) { 586 case Intrinsic::matrix_multiply: 587 case Intrinsic::matrix_transpose: 588 case Intrinsic::matrix_column_major_load: 589 case Intrinsic::matrix_column_major_store: 590 return true; 591 default: 592 return false; 593 } 594 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 595 } 596 597 /// Propagate the shape information of instructions to their users. 598 /// The work list contains instructions for which we can compute the shape, 599 /// either based on the information provided by matrix intrinsics or known 600 /// shapes of operands. 601 SmallVector<Instruction *, 32> 602 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 603 SmallVector<Instruction *, 32> NewWorkList; 604 // Pop an element for which we guaranteed to have at least one of the 605 // operand shapes. Add the shape for this and then add users to the work 606 // list. 607 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 608 while (!WorkList.empty()) { 609 Instruction *Inst = WorkList.pop_back_val(); 610 611 // New entry, set the value and insert operands 612 bool Propagate = false; 613 614 Value *MatrixA; 615 Value *MatrixB; 616 Value *M; 617 Value *N; 618 Value *K; 619 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 620 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 621 m_Value(N), m_Value(K)))) { 622 Propagate = setShapeInfo(Inst, {M, K}); 623 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 624 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 625 // Flip dimensions. 
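        // The transpose of an M x N matrix has shape N x M, e.g. transposing
        // a 2 x 3 operand yields a 3 x 2 result.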
626 Propagate = setShapeInfo(Inst, {N, M}); 627 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 628 m_Value(MatrixA), m_Value(), m_Value(), 629 m_Value(), m_Value(M), m_Value(N)))) { 630 Propagate = setShapeInfo(Inst, {N, M}); 631 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 632 m_Value(), m_Value(), m_Value(), m_Value(M), 633 m_Value(N)))) { 634 Propagate = setShapeInfo(Inst, {M, N}); 635 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 636 auto OpShape = ShapeMap.find(MatrixA); 637 if (OpShape != ShapeMap.end()) 638 setShapeInfo(Inst, OpShape->second); 639 continue; 640 } else if (isUniformShape(Inst)) { 641 // Find the first operand that has a known shape and use that. 642 for (auto &Op : Inst->operands()) { 643 auto OpShape = ShapeMap.find(Op.get()); 644 if (OpShape != ShapeMap.end()) { 645 Propagate |= setShapeInfo(Inst, OpShape->second); 646 break; 647 } 648 } 649 } 650 651 if (Propagate) { 652 NewWorkList.push_back(Inst); 653 for (auto *User : Inst->users()) 654 if (ShapeMap.count(User) == 0) 655 WorkList.push_back(cast<Instruction>(User)); 656 } 657 } 658 659 return NewWorkList; 660 } 661 662 /// Propagate the shape to operands of instructions with shape information. 663 /// \p Worklist contains the instruction for which we already know the shape. 664 SmallVector<Instruction *, 32> 665 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 666 SmallVector<Instruction *, 32> NewWorkList; 667 668 auto pushInstruction = [](Value *V, 669 SmallVectorImpl<Instruction *> &WorkList) { 670 Instruction *I = dyn_cast<Instruction>(V); 671 if (I) 672 WorkList.push_back(I); 673 }; 674 // Pop an element with known shape. Traverse the operands, if their shape 675 // derives from the result shape and is unknown, add it and add them to the 676 // worklist. 677 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 678 while (!WorkList.empty()) { 679 Value *V = WorkList.pop_back_val(); 680 681 size_t BeforeProcessingV = WorkList.size(); 682 if (!isa<Instruction>(V)) 683 continue; 684 685 Value *MatrixA; 686 Value *MatrixB; 687 Value *M; 688 Value *N; 689 Value *K; 690 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 691 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 692 m_Value(N), m_Value(K)))) { 693 if (setShapeInfo(MatrixA, {M, N})) 694 pushInstruction(MatrixA, WorkList); 695 696 if (setShapeInfo(MatrixB, {N, K})) 697 pushInstruction(MatrixB, WorkList); 698 699 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 700 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 701 // Flip dimensions. 702 if (setShapeInfo(MatrixA, {M, N})) 703 pushInstruction(MatrixA, WorkList); 704 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 705 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 706 m_Value(M), m_Value(N)))) { 707 if (setShapeInfo(MatrixA, {M, N})) { 708 pushInstruction(MatrixA, WorkList); 709 } 710 } else if (isa<LoadInst>(V) || 711 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 712 // Nothing to do, no matrix input. 713 } else if (isa<StoreInst>(V)) { 714 // Nothing to do. We forward-propagated to this so we would just 715 // backward propagate to an instruction with an already known shape. 716 } else if (isUniformShape(V)) { 717 // Propagate to all operands. 
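        // Element-wise ops (add, sub, mul, fneg) produce a result with the
        // same shape as their matrix operands, so the known result shape
        // applies to each operand as well.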
718 ShapeInfo Shape = ShapeMap[V]; 719 for (Use &U : cast<Instruction>(V)->operands()) { 720 if (setShapeInfo(U.get(), Shape)) 721 pushInstruction(U.get(), WorkList); 722 } 723 } 724 // After we discovered new shape info for new instructions in the 725 // worklist, we use their users as seeds for the next round of forward 726 // propagation. 727 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 728 for (User *U : WorkList[I]->users()) 729 if (isa<Instruction>(U) && V != U) 730 NewWorkList.push_back(cast<Instruction>(U)); 731 } 732 return NewWorkList; 733 } 734 735 /// (Op0 op Op1)^T -> Op0^T op Op1^T 736 /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use 737 /// them on both sides of \p Operation. 738 Instruction *distributeTransposes( 739 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1, 740 MatrixBuilder &Builder, 741 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)> 742 Operation) { 743 Value *T0 = Builder.CreateMatrixTranspose( 744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t"); 745 // We are being run after shape prop, add shape for newly created 746 // instructions so that we lower them later. 747 setShapeInfo(T0, Shape0.t()); 748 Value *T1 = Builder.CreateMatrixTranspose( 749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t"); 750 setShapeInfo(T1, Shape1.t()); 751 return Operation(T0, Shape0.t(), T1, Shape1.t()); 752 } 753 754 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { 755 // We need to remove Old from the ShapeMap otherwise RAUW will replace it 756 // with New. We should only add New it it supportsShapeInfo so we insert 757 // it conditionally instead. 758 auto S = ShapeMap.find(&Old); 759 if (S != ShapeMap.end()) { 760 ShapeMap.erase(S); 761 if (supportsShapeInfo(New)) 762 ShapeMap.insert({New, S->second}); 763 } 764 Old.replaceAllUsesWith(New); 765 } 766 767 /// Sink a top-level transpose inside matmuls and adds. 768 /// This creates and erases instructions as needed, and returns the newly 769 /// created instruction while updating the iterator to avoid invalidation. If 770 /// this returns nullptr, no new instruction was created. 
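  /// For example (a sketch; intrinsic type suffixes omitted), a transposed
  /// multiply of an RxK matrix %a with a KxC matrix %b
  ///   %p = call @llvm.matrix.multiply(%a, %b, i32 R, i32 K, i32 C)
  ///   %t = call @llvm.matrix.transpose(%p, i32 R, i32 C)
  /// is rewritten into a multiply of the transposed operands in swapped
  /// order:
  ///   %bt = call @llvm.matrix.transpose(%b, i32 K, i32 C)
  ///   %at = call @llvm.matrix.transpose(%a, i32 R, i32 K)
  ///   %t  = call @llvm.matrix.multiply(%bt, %at, i32 C, i32 K, i32 R)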
771 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) { 772 BasicBlock &BB = *I.getParent(); 773 IRBuilder<> IB(&I); 774 MatrixBuilder Builder(IB); 775 776 Value *TA, *TAMA, *TAMB; 777 ConstantInt *R, *K, *C; 778 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( 779 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) 780 return nullptr; 781 782 // Transpose of a transpose is a nop 783 Value *TATA; 784 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 785 updateShapeAndReplaceAllUsesWith(I, TATA); 786 eraseFromParentAndMove(&I, II, BB); 787 eraseFromParentAndMove(TA, II, BB); 788 return nullptr; 789 } 790 791 // k^T -> k 792 if (isSplat(TA)) { 793 updateShapeAndReplaceAllUsesWith(I, TA); 794 eraseFromParentAndMove(&I, II, BB); 795 return nullptr; 796 } 797 798 // (A * B)^t -> B^t * A^t 799 // RxK KxC CxK KxR 800 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 801 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 802 m_ConstantInt(K), m_ConstantInt(C)))) { 803 auto NewInst = distributeTransposes( 804 TAMB, {K, C}, TAMA, {R, K}, Builder, 805 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 806 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, 807 Shape0.NumColumns, 808 Shape1.NumColumns, "mmul"); 809 }); 810 updateShapeAndReplaceAllUsesWith(I, NewInst); 811 eraseFromParentAndMove(&I, II, BB); 812 eraseFromParentAndMove(TA, II, BB); 813 return NewInst; 814 } 815 816 // Same as above, but with a mul, which occurs when multiplied 817 // with a scalar. 818 // (A * k)^t -> A^t * k 819 // R x C RxC 820 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && 821 (isSplat(TAMA) || isSplat(TAMB))) { 822 IRBuilder<> LocalBuilder(&I); 823 // We know that the transposed operand is of shape RxC. 824 // An when multiplied with a scalar, the shape is preserved. 825 auto NewInst = distributeTransposes( 826 TAMA, {R, C}, TAMB, {R, C}, Builder, 827 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 828 bool IsFP = I.getType()->isFPOrFPVectorTy(); 829 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul") 830 : LocalBuilder.CreateMul(T0, T1, "mmul"); 831 auto *Result = cast<Instruction>(Mul); 832 setShapeInfo(Result, Shape0); 833 return Result; 834 }); 835 updateShapeAndReplaceAllUsesWith(I, NewInst); 836 eraseFromParentAndMove(&I, II, BB); 837 eraseFromParentAndMove(TA, II, BB); 838 return NewInst; 839 } 840 841 // (A + B)^t -> A^t + B^t 842 // RxC RxC CxR CxR 843 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { 844 IRBuilder<> LocalBuilder(&I); 845 auto NewInst = distributeTransposes( 846 TAMA, {R, C}, TAMB, {R, C}, Builder, 847 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 848 bool IsFP = I.getType()->isFPOrFPVectorTy(); 849 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd") 850 : LocalBuilder.CreateAdd(T0, T1, "madd"); 851 852 auto *Result = cast<Instruction>(Add); 853 setShapeInfo(Result, Shape0); 854 return Result; 855 }); 856 updateShapeAndReplaceAllUsesWith(I, NewInst); 857 eraseFromParentAndMove(&I, II, BB); 858 eraseFromParentAndMove(TA, II, BB); 859 return NewInst; 860 } 861 862 return nullptr; 863 } 864 865 void liftTranspose(Instruction &I) { 866 // Erase dead Instructions after lifting transposes from binops. 
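    // Operands are only erased once they have no remaining uses, and A is
    // compared against B so the same value is not erased twice.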
867 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { 868 if (T.use_empty()) 869 T.eraseFromParent(); 870 if (A->use_empty()) 871 cast<Instruction>(A)->eraseFromParent(); 872 if (A != B && B->use_empty()) 873 cast<Instruction>(B)->eraseFromParent(); 874 }; 875 876 Value *A, *B, *AT, *BT; 877 ConstantInt *R, *K, *C; 878 // A^t * B ^t -> (B * A)^t 879 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 880 m_Value(A), m_Value(B), m_ConstantInt(R), 881 m_ConstantInt(K), m_ConstantInt(C))) && 882 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 883 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 884 IRBuilder<> IB(&I); 885 MatrixBuilder Builder(IB); 886 Value *M = Builder.CreateMatrixMultiply( 887 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 888 setShapeInfo(M, {C, R}); 889 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), 890 R->getZExtValue()); 891 updateShapeAndReplaceAllUsesWith(I, NewInst); 892 CleanupBinOp(I, A, B); 893 } 894 // A^t + B ^t -> (A + B)^t 895 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && 896 match(A, m_Intrinsic<Intrinsic::matrix_transpose>( 897 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && 898 match(B, m_Intrinsic<Intrinsic::matrix_transpose>( 899 m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) { 900 IRBuilder<> Builder(&I); 901 Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); 902 setShapeInfo(Add, {C, R}); 903 MatrixBuilder MBuilder(Builder); 904 Instruction *NewInst = MBuilder.CreateMatrixTranspose( 905 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t"); 906 updateShapeAndReplaceAllUsesWith(I, NewInst); 907 CleanupBinOp(I, A, B); 908 } 909 } 910 911 /// Try moving transposes in order to fold them away or into multiplies. 912 void optimizeTransposes() { 913 // First sink all transposes inside matmuls and adds, hoping that we end up 914 // with NN, NT or TN variants. 915 for (BasicBlock &BB : reverse(Func)) { 916 for (auto II = BB.rbegin(); II != BB.rend();) { 917 Instruction &I = *II; 918 // We may remove II. By default continue on the next/prev instruction. 919 ++II; 920 if (Instruction *NewInst = sinkTranspose(I, II)) 921 II = std::next(BasicBlock::reverse_iterator(NewInst)); 922 } 923 } 924 925 // If we have a TT matmul or a TT add, lift the transpose. We may be able 926 // to fold into consuming multiply or add. 927 for (BasicBlock &BB : Func) { 928 for (Instruction &I : llvm::make_early_inc_range(BB)) { 929 liftTranspose(I); 930 } 931 } 932 } 933 934 bool Visit() { 935 SmallVector<Instruction *, 32> WorkList; 936 937 // Initially only the shape of matrix intrinsics is known. 938 // Initialize the work list with ops carrying shape information. 939 for (BasicBlock &BB : Func) 940 for (Instruction &Inst : BB) { 941 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 942 if (!II) 943 continue; 944 945 switch (II->getIntrinsicID()) { 946 case Intrinsic::matrix_multiply: 947 case Intrinsic::matrix_transpose: 948 case Intrinsic::matrix_column_major_load: 949 case Intrinsic::matrix_column_major_store: 950 WorkList.push_back(&Inst); 951 break; 952 default: 953 break; 954 } 955 } 956 957 // Avoid unnecessary work if there are no matrix intrinsics in the function. 958 if (WorkList.empty()) 959 return false; 960 961 // Propagate shapes until nothing changes any longer. 
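    // For example, given %s = fadd %a, %m where %m is the result of a
    // multiply intrinsic, the forward pass assigns %m's shape to %s and the
    // following backward pass propagates that shape to %a (if %a is itself an
    // instruction this pass can handle).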
962 while (!WorkList.empty()) { 963 WorkList = propagateShapeForward(WorkList); 964 WorkList = propagateShapeBackward(WorkList); 965 } 966 967 if (!isMinimal()) { 968 optimizeTransposes(); 969 if (PrintAfterTransposeOpt) { 970 dbgs() << "Dump after matrix transpose optimization:\n"; 971 Func.print(dbgs()); 972 } 973 } 974 975 bool Changed = false; 976 SmallVector<CallInst *, 16> MaybeFusableInsts; 977 SmallVector<Instruction *, 16> MatrixInsts; 978 979 // First, collect all instructions with shape information and candidates for 980 // fusion (currently only matrix multiplies). 981 ReversePostOrderTraversal<Function *> RPOT(&Func); 982 for (auto *BB : RPOT) 983 for (Instruction &I : *BB) { 984 if (ShapeMap.find(&I) == ShapeMap.end()) 985 continue; 986 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 987 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 988 MatrixInsts.push_back(&I); 989 } 990 991 // Second, try to lower any dot products 992 SmallPtrSet<Instruction *, 16> FusedInsts; 993 for (CallInst *CI : MaybeFusableInsts) 994 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); 995 996 // Third, try to fuse candidates. 997 for (CallInst *CI : MaybeFusableInsts) 998 LowerMatrixMultiplyFused(CI, FusedInsts); 999 1000 Changed = !FusedInsts.empty(); 1001 1002 // Fourth, lower remaining instructions with shape information. 1003 for (Instruction *Inst : MatrixInsts) { 1004 if (FusedInsts.count(Inst)) 1005 continue; 1006 1007 IRBuilder<> Builder(Inst); 1008 1009 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 1010 Changed |= VisitCallInst(CInst); 1011 1012 Value *Op1; 1013 Value *Op2; 1014 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1015 Changed |= VisitBinaryOperator(BinOp); 1016 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1017 Changed |= VisitUnaryOperator(UnOp); 1018 if (match(Inst, m_Load(m_Value(Op1)))) 1019 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 1020 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 1021 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 1022 } 1023 1024 if (ORE) { 1025 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 1026 RemarkGen.emitRemarks(); 1027 } 1028 1029 // Delete the instructions backwards, as it has a reduced likelihood of 1030 // having to update as many def-use and use-def chains. 1031 // 1032 // Because we add to ToRemove during fusion we can't guarantee that defs 1033 // are before uses. Change uses to poison temporarily as these should get 1034 // removed as well. 1035 // 1036 // For verification, we keep track of where we changed uses to poison in 1037 // PoisonedInsts and then check that we in fact remove them. 1038 SmallSet<Instruction *, 16> PoisonedInsts; 1039 for (auto *Inst : reverse(ToRemove)) { 1040 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1041 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 1042 PoisonedInsts.insert(Poisoned); 1043 U.set(PoisonValue::get(Inst->getType())); 1044 } 1045 Inst->eraseFromParent(); 1046 PoisonedInsts.erase(Inst); 1047 } 1048 if (!PoisonedInsts.empty()) { 1049 // If we didn't remove all poisoned instructions, it's a hard error. 
1050 dbgs() << "Poisoned but present instructions:\n"; 1051 for (auto *I : PoisonedInsts) 1052 dbgs() << *I << "\n"; 1053 llvm_unreachable("Poisoned but instruction not removed"); 1054 } 1055 1056 return Changed; 1057 } 1058 1059 /// Replace intrinsic calls 1060 bool VisitCallInst(CallInst *Inst) { 1061 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 1062 return false; 1063 1064 switch (Inst->getCalledFunction()->getIntrinsicID()) { 1065 case Intrinsic::matrix_multiply: 1066 LowerMultiply(Inst); 1067 break; 1068 case Intrinsic::matrix_transpose: 1069 LowerTranspose(Inst); 1070 break; 1071 case Intrinsic::matrix_column_major_load: 1072 LowerColumnMajorLoad(Inst); 1073 break; 1074 case Intrinsic::matrix_column_major_store: 1075 LowerColumnMajorStore(Inst); 1076 break; 1077 default: 1078 return false; 1079 } 1080 return true; 1081 } 1082 1083 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1084 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1085 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1086 /// non-ConstantInt strides, return the common alignment of the initial 1087 /// alignment and the element size in bytes. 1088 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 1089 MaybeAlign A) const { 1090 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 1091 if (Idx == 0) 1092 return InitialAlign; 1093 1094 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 1095 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 1096 uint64_t StrideInBytes = 1097 ConstStride->getZExtValue() * ElementSizeInBits / 8; 1098 return commonAlignment(InitialAlign, Idx * StrideInBytes); 1099 } 1100 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 1101 } 1102 1103 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1104 /// vectors. 1105 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 1106 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1107 auto *VType = cast<VectorType>(Ty); 1108 Type *EltTy = VType->getElementType(); 1109 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 1110 Value *EltPtr = Ptr; 1111 MatrixTy Result; 1112 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1113 Value *GEP = computeVectorAddr( 1114 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1115 Stride, Shape.getStride(), EltTy, Builder); 1116 Value *Vector = Builder.CreateAlignedLoad( 1117 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 1118 IsVolatile, "col.load"); 1119 1120 Result.addVector(Vector); 1121 } 1122 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 1123 Result.getNumVectors()); 1124 } 1125 1126 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1127 /// starting at \p MatrixPtr[I][J]. 
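  /// In column-major layout, element (\p I, \p J) of the outer matrix is at
  /// offset J * Stride + I of the flat array, with Stride being the number of
  /// rows of the outer matrix. E.g. for an 8 x 8 matrix, the tile starting at
  /// (1, 2) begins at element 2 * 8 + 1 = 17.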
1128 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 1129 ShapeInfo MatrixShape, Value *I, Value *J, 1130 ShapeInfo ResultShape, Type *EltTy, 1131 IRBuilder<> &Builder) { 1132 1133 Value *Offset = Builder.CreateAdd( 1134 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1135 1136 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1137 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1138 ResultShape.NumColumns); 1139 1140 return loadMatrix(TileTy, TileStart, Align, 1141 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1142 ResultShape, Builder); 1143 } 1144 1145 /// Lower a load instruction with shape information. 1146 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1147 bool IsVolatile, ShapeInfo Shape) { 1148 IRBuilder<> Builder(Inst); 1149 finalizeLowering(Inst, 1150 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1151 Shape, Builder), 1152 Builder); 1153 } 1154 1155 /// Lowers llvm.matrix.column.major.load. 1156 /// 1157 /// The intrinsic loads a matrix from memory using a stride between columns. 1158 void LowerColumnMajorLoad(CallInst *Inst) { 1159 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1160 "Intrinsic only supports column-major layout!"); 1161 Value *Ptr = Inst->getArgOperand(0); 1162 Value *Stride = Inst->getArgOperand(1); 1163 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1164 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1165 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1166 } 1167 1168 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1169 /// MatrixPtr[I][J]. 1170 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1171 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1172 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1173 Value *Offset = Builder.CreateAdd( 1174 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1175 1176 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1177 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1178 StoreVal.getNumColumns()); 1179 1180 storeMatrix(TileTy, StoreVal, TileStart, MAlign, 1181 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1182 } 1183 1184 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1185 /// vectors. 1186 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1187 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1188 IRBuilder<> &Builder) { 1189 auto VType = cast<VectorType>(Ty); 1190 Value *EltPtr = Ptr; 1191 for (auto Vec : enumerate(StoreVal.vectors())) { 1192 Value *GEP = computeVectorAddr( 1193 EltPtr, 1194 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1195 Vec.index()), 1196 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1197 Builder.CreateAlignedStore(Vec.value(), GEP, 1198 getAlignForIndex(Vec.index(), Stride, 1199 VType->getElementType(), 1200 MAlign), 1201 IsVolatile); 1202 } 1203 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1204 StoreVal.getNumVectors()); 1205 } 1206 1207 /// Lower a store instruction with shape information. 
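  /// The matrix operand is fetched (or split) as a set of column/row vectors
  /// via getMatrix and each vector is stored separately, using the address
  /// and alignment computed for its index.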
1208 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 1209 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 1210 IRBuilder<> Builder(Inst); 1211 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1212 finalizeLowering(Inst, 1213 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 1214 IsVolatile, Builder), 1215 Builder); 1216 } 1217 1218 /// Lowers llvm.matrix.column.major.store. 1219 /// 1220 /// The intrinsic store a matrix back memory using a stride between columns. 1221 void LowerColumnMajorStore(CallInst *Inst) { 1222 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1223 "Intrinsic only supports column-major layout!"); 1224 Value *Matrix = Inst->getArgOperand(0); 1225 Value *Ptr = Inst->getArgOperand(1); 1226 Value *Stride = Inst->getArgOperand(2); 1227 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1228 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1229 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1230 } 1231 1232 // Set elements I..I+NumElts-1 to Block 1233 Value *insertVector(Value *Col, unsigned I, Value *Block, 1234 IRBuilder<> &Builder) { 1235 1236 // First, bring Block to the same size as Col 1237 unsigned BlockNumElts = 1238 cast<FixedVectorType>(Block->getType())->getNumElements(); 1239 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1240 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1241 1242 Block = Builder.CreateShuffleVector( 1243 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1244 1245 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1246 // 8, 4, 5, 6 1247 SmallVector<int, 16> Mask; 1248 unsigned i; 1249 for (i = 0; i < I; i++) 1250 Mask.push_back(i); 1251 1252 unsigned VecNumElts = 1253 cast<FixedVectorType>(Col->getType())->getNumElements(); 1254 for (; i < I + BlockNumElts; i++) 1255 Mask.push_back(i - I + VecNumElts); 1256 1257 for (; i < VecNumElts; i++) 1258 Mask.push_back(i); 1259 1260 return Builder.CreateShuffleVector(Col, Block, Mask); 1261 } 1262 1263 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1264 IRBuilder<> &Builder, bool AllowContraction, 1265 unsigned &NumComputeOps) { 1266 NumComputeOps += getNumOps(A->getType()); 1267 if (!Sum) 1268 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1269 1270 if (UseFPOp) { 1271 if (AllowContraction) { 1272 // Use fmuladd for floating point operations and let the backend decide 1273 // if that's profitable. 1274 Function *FMulAdd = Intrinsic::getDeclaration( 1275 Func.getParent(), Intrinsic::fmuladd, A->getType()); 1276 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1277 } 1278 NumComputeOps += getNumOps(A->getType()); 1279 Value *Mul = Builder.CreateFMul(A, B); 1280 return Builder.CreateFAdd(Sum, Mul); 1281 } 1282 1283 NumComputeOps += getNumOps(A->getType()); 1284 Value *Mul = Builder.CreateMul(A, B); 1285 return Builder.CreateAdd(Sum, Mul); 1286 } 1287 1288 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1289 /// users with shape information, there's nothing to do: they will use the 1290 /// cached value when they are lowered. For other users, \p Matrix is 1291 /// flattened and the uses are updated to use it. Also marks \p Inst for 1292 /// deletion. 
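  /// For example, a user without shape information (such as a call to a
  /// function this pass does not know about) receives the concatenation of
  /// the result vectors as a single flat vector instead.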
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// vectors. Lowers to a vector reduction add instead of sequential adds if
  /// reassociation is enabled.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassociation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    auto CanBeFlattened = [this](Value *Op) {
      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
    // Returns the cost benefit of using \p Op with the dot product lowering. If
    // the returned cost is < 0, the argument is cheaper to use in the
    // dot-product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        InstructionCost EmbedCost(0);
        // Roughly estimate the cost for embedding the columns into a vector.
        for (unsigned I = 1; I < N; ++I)
          EmbedCost -=
              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                 std::nullopt, TTI::TCK_RecipThroughput);
        return EmbedCost;
      }

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        InstructionCost OriginalCost =
            TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
                                       EltTy) *
            N;
        InstructionCost NewCost = TTI.getArithmeticInstrCost(
            cast<Instruction>(Op)->getOpcode(), VecTy);
        return NewCost - OriginalCost;
      }

      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose can be skipped for the dot product lowering, roughly
        // estimate the savings as the cost of embedding the columns in a
        // vector.
1379 InstructionCost EmbedCost(0); 1380 for (unsigned I = 1; I < N; ++I) 1381 EmbedCost += 1382 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1383 std::nullopt, TTI::TCK_RecipThroughput); 1384 return EmbedCost; 1385 } 1386 1387 // Costs for loads. 1388 if (N == 1) 1389 return InstructionCost(0); 1390 1391 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - 1392 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); 1393 }; 1394 auto LHSCost = GetCostForArg(LHS, LShape.NumColumns); 1395 1396 // We compare the costs of a vector.reduce.add to sequential add. 1397 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; 1398 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; 1399 InstructionCost ReductionCost = 1400 TTI.getArithmeticReductionCost( 1401 AddOpCode, cast<VectorType>(LHS->getType()), 1402 IsIntVec ? std::nullopt : std::optional(FMF)) + 1403 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); 1404 InstructionCost SequentialAddCost = 1405 TTI.getArithmeticInstrCost(AddOpCode, ElementType) * 1406 (LShape.NumColumns - 1) + 1407 TTI.getArithmeticInstrCost(MulOpCode, ElementType) * 1408 (LShape.NumColumns); 1409 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) 1410 return; 1411 1412 FusedInsts.insert(MatMul); 1413 IRBuilder<> Builder(MatMul); 1414 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, 1415 this](Value *Op) -> Value * { 1416 // Matmul must be the only user of loads because we don't use LowerLoad 1417 // for row vectors (LowerLoad results in scalar loads and shufflevectors 1418 // instead of single vector load). 1419 if (!CanBeFlattened(Op)) 1420 return Op; 1421 1422 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { 1423 ShapeMap[Op] = ShapeMap[Op].t(); 1424 return Op; 1425 } 1426 1427 FusedInsts.insert(cast<Instruction>(Op)); 1428 // If vector uses the builtin load, lower to a LoadInst 1429 Value *Arg; 1430 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( 1431 m_Value(Arg)))) { 1432 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); 1433 Op->replaceAllUsesWith(NewLoad); 1434 cast<Instruction>(Op)->eraseFromParent(); 1435 return NewLoad; 1436 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( 1437 m_Value(Arg)))) { 1438 ToRemove.push_back(cast<Instruction>(Op)); 1439 return Arg; 1440 } 1441 1442 return Op; 1443 }; 1444 LHS = FlattenArg(LHS); 1445 1446 // Insert mul/fmul and llvm.vector.reduce.fadd 1447 Value *Mul = 1448 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS); 1449 1450 Value *Result; 1451 if (IsIntVec) 1452 Result = Builder.CreateAddReduce(Mul); 1453 else { 1454 Result = Builder.CreateFAddReduce( 1455 ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(), 1456 0.0), 1457 Mul); 1458 cast<Instruction>(Result)->setFastMathFlags(FMF); 1459 } 1460 1461 // pack scalar back into a matrix and then replace matmul inst 1462 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()), 1463 Result, uint64_t(0)); 1464 MatMul->replaceAllUsesWith(Result); 1465 FusedInsts.insert(MatMul); 1466 ToRemove.push_back(MatMul); 1467 } 1468 1469 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1470 /// addition. 1471 /// 1472 /// We can fold a transpose into the operand that is used to extract scalars. 1473 /// This is the first operands with row-major and the second with 1474 /// column-major. 
If \p IsScalarMatrixTransposed we assume the appropriate 1475 /// operand is transposed. 1476 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1477 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1478 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1479 const unsigned VF = std::max<unsigned>( 1480 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1481 .getFixedValue() / 1482 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1483 1U); 1484 unsigned R = Result.getNumRows(); 1485 unsigned C = Result.getNumColumns(); 1486 unsigned M = A.getNumColumns(); 1487 1488 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1489 assert(A.isColumnMajor() == B.isColumnMajor() && 1490 Result.isColumnMajor() == A.isColumnMajor() && 1491 "operands must agree on matrix layout"); 1492 unsigned NumComputeOps = 0; 1493 1494 Builder.setFastMathFlags(FMF); 1495 1496 if (A.isColumnMajor()) { 1497 // Multiply columns from the first operand with scalars from the second 1498 // operand. Then move along the K axes and accumulate the columns. With 1499 // this the adds can be vectorized without reassociation. 1500 for (unsigned J = 0; J < C; ++J) { 1501 unsigned BlockSize = VF; 1502 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1503 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1504 1505 for (unsigned I = 0; I < R; I += BlockSize) { 1506 // Gradually lower the vectorization factor to cover the remainder. 1507 while (I + BlockSize > R) 1508 BlockSize /= 2; 1509 1510 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1511 : nullptr; 1512 for (unsigned K = 0; K < M; ++K) { 1513 Value *L = A.extractVector(I, K, BlockSize, Builder); 1514 Value *RH = Builder.CreateExtractElement( 1515 B.getColumn(IsScalarMatrixTransposed ? K : J), 1516 IsScalarMatrixTransposed ? J : K); 1517 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1518 Sum = 1519 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1520 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1521 } 1522 Result.setVector(J, 1523 insertVector(Result.getVector(J), I, Sum, Builder)); 1524 } 1525 } 1526 } else { 1527 // Multiply rows from the second operand with scalars from the first 1528 // operand. Then move along the K axes and accumulate the rows. With this 1529 // the adds can be vectorized without reassociation. 1530 for (unsigned I = 0; I < R; ++I) { 1531 unsigned BlockSize = VF; 1532 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1533 for (unsigned J = 0; J < C; J += BlockSize) { 1534 // Gradually lower the vectorization factor to cover the remainder. 1535 while (J + BlockSize > C) 1536 BlockSize /= 2; 1537 1538 Value *Sum = nullptr; 1539 for (unsigned K = 0; K < M; ++K) { 1540 Value *R = B.extractVector(K, J, BlockSize, Builder); 1541 Value *LH = Builder.CreateExtractElement( 1542 A.getVector(IsScalarMatrixTransposed ? K : I), 1543 IsScalarMatrixTransposed ? I : K); 1544 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1545 Sum = 1546 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1547 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1548 } 1549 Result.setVector(I, 1550 insertVector(Result.getVector(I), J, Sum, Builder)); 1551 } 1552 } 1553 } 1554 Result.addNumComputeOps(NumComputeOps); 1555 } 1556 1557 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1558 /// copying it to a new location. 
The new location, or otherwise the original one,
1559   /// is returned.
1560   Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1561                                CallInst *MatMul) {
1562     MemoryLocation StoreLoc = MemoryLocation::get(Store);
1563     MemoryLocation LoadLoc = MemoryLocation::get(Load);
1564
1565     // If we can statically determine noalias we're good.
1566     if (AA->isNoAlias(LoadLoc, StoreLoc))
1567       return Load->getPointerOperand();
1568
1569     // Create code to check if the memory locations of the Load and Store
1570     // overlap and if they do, copy Load's operand to a new buffer.
1571
1572     // First, create new blocks for the 2nd part of the check and the copy.
1573     BasicBlock *Check0 = MatMul->getParent();
1574     // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1575     // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1576     // as we adjust Check0 and Check1's branches.
1577     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1578     for (BasicBlock *Succ : successors(Check0))
1579       DTUpdates.push_back({DT->Delete, Check0, Succ});
1580
1581     BasicBlock *Check1 =
1582         SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1583                    nullptr, "alias_cont");
1584     BasicBlock *Copy =
1585         SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1586                    nullptr, "copy");
1587     BasicBlock *Fusion =
1588         SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1589                    nullptr, "no_alias");
1590
1591     // Check if the loaded memory location begins before the end of the store
1592     // location. If the condition holds, they might overlap, otherwise they are
1593     // guaranteed to not overlap.
1594     IRBuilder<> Builder(MatMul);
1595     Check0->getTerminator()->eraseFromParent();
1596     Builder.SetInsertPoint(Check0);
1597     Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1598     Value *StoreBegin = Builder.CreatePtrToInt(
1599         const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1600     Value *StoreEnd = Builder.CreateAdd(
1601         StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1602         "store.end", true, true);
1603     Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1604                                               IntPtrTy, "load.begin");
1605     Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1606                          Fusion);
1607
1608     // Check if the store begins before the end of the load location. If the
1609     // condition holds, they alias, otherwise they are guaranteed to not
1610     // overlap.
1611     Check1->getTerminator()->eraseFromParent();
1612     Builder.SetInsertPoint(Check1, Check1->begin());
1613     Value *LoadEnd = Builder.CreateAdd(
1614         LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1615         "load.end", true, true);
1616     Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
1617                          Fusion);
1618
1619     // Copy load operand to new alloca.
1620     Builder.SetInsertPoint(Copy, Copy->begin());
1621     auto *VT = cast<FixedVectorType>(Load->getType());
1622     // Use an array type for the alloca, to avoid potentially huge alignment
1623     // requirements for large vector types.
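    // For example (numbers illustrative, not taken from this pass): a
    // <16 x double> operand may carry an ABI alignment as large as its
    // 128-byte size on some targets, while the equivalent [16 x double]
    // array only needs the 8-byte element alignment that any stack frame
    // can provide.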
1624 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1625 AllocaInst *Alloca = 1626 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1627 1628 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), 1629 Load->getAlign(), LoadLoc.Size.getValue()); 1630 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1631 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1632 PHI->addIncoming(Load->getPointerOperand(), Check0); 1633 PHI->addIncoming(Load->getPointerOperand(), Check1); 1634 PHI->addIncoming(Alloca, Copy); 1635 1636 // Adjust DT. 1637 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1638 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1639 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1640 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1641 DT->applyUpdates(DTUpdates); 1642 return PHI; 1643 } 1644 1645 bool isFusionProfitable(CallInst *MatMul) { 1646 if (ForceFusion) 1647 return true; 1648 1649 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1650 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1651 1652 const unsigned R = LShape.NumRows; 1653 const unsigned C = RShape.NumColumns; 1654 const unsigned M = LShape.NumColumns; 1655 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1656 1657 const unsigned VF = std::max<unsigned>( 1658 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1659 .getFixedValue() / 1660 EltType->getPrimitiveSizeInBits().getFixedValue(), 1661 1U); 1662 1663 // Cost model for tiling 1664 // 1665 // For tiling to be beneficial, we need reuse either along the R or 1666 // the C axis. We vectorize along the R axis so that means at least 1667 // 3 elements. 1668 // TODO: Also consider cost of copying if operands alias. 1669 if (R <= VF && C == 1) 1670 return false; 1671 // Then we need enough elements to exceed the number of vector 1672 // registers we have. Note that this is an oversimplification since 1673 // fusing also takes some extra loads which may exceed the number of 1674 // reloads necessary. 1675 unsigned Op0Regs = (R + VF - 1) / VF * M; 1676 unsigned Op1Regs = (M + VF - 1) / VF * C; 1677 return Op0Regs + Op1Regs > 1678 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1679 } 1680 1681 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1682 MatrixTy Res; 1683 auto *ColumType = FixedVectorType::get(EltType, R); 1684 for (unsigned I = 0; I < C; ++I) 1685 Res.addVector(ConstantAggregateZero::get(ColumType)); 1686 return Res; 1687 } 1688 1689 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1690 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1691 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1692 1693 // Create the main tiling loop nest. 1694 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1695 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1696 Instruction *InsertI = cast<Instruction>(MatMul); 1697 BasicBlock *Start = InsertI->getParent(); 1698 BasicBlock *End = 1699 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1700 IRBuilder<> Builder(MatMul); 1701 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1702 1703 Type *TileVecTy = 1704 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1705 MatrixTy TileResult; 1706 // Insert in the inner loop header. 
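    // The nest created by TI.CreateTiledLoops is, roughly (see TileInfo for
    // the authoritative structure):
    //   for (ColumnLoop.Index over result columns, step TileSize)
    //     for (RowLoop.Index over result rows, step TileSize)
    //       for (KLoop.Index over the inner dimension, step TileSize)
    // The accumulator PHIs created below live in KLoop's header, so the
    // partial sums of the current tile are carried across the K iterations.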
1707 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); 1708 // Create PHI nodes for the result columns to accumulate across iterations. 1709 SmallVector<PHINode *, 4> ColumnPhis; 1710 for (unsigned I = 0; I < TileSize; I++) { 1711 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1712 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1713 TI.RowLoop.Header->getSingleSuccessor()); 1714 TileResult.addVector(Phi); 1715 ColumnPhis.push_back(Phi); 1716 } 1717 1718 // Insert in the inner loop body, which computes 1719 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1720 Builder.SetInsertPoint(InnerBody->getTerminator()); 1721 // Load tiles of the operands. 1722 MatrixTy A = 1723 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, 1724 {TileSize, TileSize}, EltType, Builder); 1725 MatrixTy B = 1726 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, 1727 {TileSize, TileSize}, EltType, Builder); 1728 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1729 getFastMathFlags(MatMul)); 1730 // Store result after the inner loop is done. 1731 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); 1732 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1733 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1734 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); 1735 1736 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1737 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); 1738 1739 // Force unrolling of a few iterations of the inner loop, to make sure there 1740 // is enough work per iteration. 1741 // FIXME: The unroller should make this decision directly instead, but 1742 // currently the cost-model is not up to the task. 
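    // For example, with the default TileSize of 4 and a 64-column LHS this
    // requests an unroll count of min(10, 64 / 4) = 10 for the inner loop.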
1743 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1744 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header), 1745 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1746 } 1747 1748 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1749 StoreInst *Store, 1750 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1751 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1752 "Tiling only supported for column-major matrixes at the moment!"); 1753 if (!isFusionProfitable(MatMul)) 1754 return; 1755 1756 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1757 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1758 1759 const unsigned R = LShape.NumRows; 1760 const unsigned C = RShape.NumColumns; 1761 const unsigned M = LShape.NumColumns; 1762 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1763 1764 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1765 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1766 Value *CPtr = Store->getPointerOperand(); 1767 1768 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1769 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 1770 else { 1771 IRBuilder<> Builder(Store); 1772 for (unsigned J = 0; J < C; J += TileSize) 1773 for (unsigned I = 0; I < R; I += TileSize) { 1774 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1775 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1776 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1777 1778 for (unsigned K = 0; K < M; K += TileSize) { 1779 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1780 MatrixTy A = 1781 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1782 LShape, Builder.getInt64(I), Builder.getInt64(K), 1783 {TileR, TileM}, EltType, Builder); 1784 MatrixTy B = 1785 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1786 RShape, Builder.getInt64(K), Builder.getInt64(J), 1787 {TileM, TileC}, EltType, Builder); 1788 emitMatrixMultiply(Res, A, B, Builder, true, false, 1789 getFastMathFlags(MatMul)); 1790 } 1791 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1792 Builder.getInt64(I), Builder.getInt64(J), EltType, 1793 Builder); 1794 } 1795 } 1796 1797 // Mark eliminated instructions as fused and remove them. 1798 FusedInsts.insert(Store); 1799 FusedInsts.insert(MatMul); 1800 Store->eraseFromParent(); 1801 MatMul->eraseFromParent(); 1802 if (LoadOp0->hasNUses(0)) { 1803 FusedInsts.insert(LoadOp0); 1804 LoadOp0->eraseFromParent(); 1805 } 1806 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { 1807 FusedInsts.insert(LoadOp1); 1808 LoadOp1->eraseFromParent(); 1809 } 1810 } 1811 1812 /// Try to lower matrix multiply chains by fusing operations. 1813 /// 1814 /// Call finalizeLowering on lowered instructions. Instructions that are 1815 /// completely eliminated by fusion are added to \p FusedInsts. 1816 void LowerMatrixMultiplyFused(CallInst *MatMul, 1817 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1818 if (!FuseMatrix || !DT) 1819 return; 1820 1821 assert(AA && LI && "Analyses should be available"); 1822 1823 Value *A = MatMul->getArgOperand(0); 1824 Value *B = MatMul->getArgOperand(1); 1825 1826 // We can fold the transpose into the operand that is used to fetch scalars. 1827 Value *T; 1828 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1829 ? 
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T))) 1830 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) { 1831 IRBuilder<> Builder(MatMul); 1832 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1833 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1834 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1835 const unsigned R = LShape.NumRows; 1836 const unsigned M = LShape.NumColumns; 1837 const unsigned C = RShape.NumColumns; 1838 1839 MatrixTy MA; 1840 MatrixTy MB; 1841 1842 Value *Transpose; 1843 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) { 1844 MA = getMatrix(A, ShapeInfo(R, M), Builder); 1845 MB = getMatrix(T, ShapeInfo(C, M), Builder); 1846 Transpose = B; 1847 } else { 1848 MA = getMatrix(T, ShapeInfo(R, M), Builder); 1849 MB = getMatrix(B, ShapeInfo(C, M), Builder); 1850 Transpose = A; 1851 } 1852 1853 // Initialize the output 1854 MatrixTy Result(R, C, EltType); 1855 1856 emitMatrixMultiply(Result, MA, MB, Builder, false, true, 1857 getFastMathFlags(MatMul)); 1858 1859 FusedInsts.insert(MatMul); 1860 if (Transpose->hasOneUse()) { 1861 FusedInsts.insert(cast<Instruction>(Transpose)); 1862 ToRemove.push_back(cast<Instruction>(Transpose)); 1863 // TODO: add a fake entry for the folded instruction so that this is 1864 // included in the expression in the remark. 1865 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType); 1866 } 1867 finalizeLowering(MatMul, Result, Builder); 1868 return; 1869 } 1870 1871 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor) 1872 return; 1873 1874 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering 1875 // since the single store user will be lowered as part of this. 1876 auto *LoadOp0 = dyn_cast<LoadInst>(A); 1877 auto *LoadOp1 = dyn_cast<LoadInst>(B); 1878 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1879 if (LoadOp0 && LoadOp1 && Store) { 1880 // The store address must dominate the MatMul instruction, otherwise 1881 // we create invalid IR. 1882 SetVector<Value *> WorkList; 1883 WorkList.insert(Store->getOperand(1)); 1884 SmallVector<Instruction *> ToHoist; 1885 for (unsigned I = 0; I != WorkList.size(); ++I) { 1886 Value *Current = WorkList[I]; 1887 auto *CurrI = dyn_cast<Instruction>(Current); 1888 if (!CurrI) 1889 continue; 1890 if (isa<PHINode>(CurrI)) 1891 return; 1892 if (DT->dominates(CurrI, MatMul)) 1893 continue; 1894 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory()) 1895 return; 1896 ToHoist.push_back(CurrI); 1897 WorkList.insert(CurrI->op_begin(), CurrI->op_end()); 1898 } 1899 1900 sort(ToHoist, [this](Instruction *A, Instruction *B) { 1901 return DT->dominates(A, B); 1902 }); 1903 for (Instruction *I : ToHoist) 1904 I->moveBefore(MatMul); 1905 1906 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1907 return; 1908 } 1909 } 1910 1911 /// Lowers llvm.matrix.multiply. 
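  ///
  /// For example (shapes and element type illustrative), a 2x2 * 2x2 multiply
  ///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
  ///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
  /// is expanded into extracts of the operand vectors, splats of the scalar
  /// factors and fmul/fadd chains (or contracted fmuladds when allowed).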
1912 void LowerMultiply(CallInst *MatMul) { 1913 IRBuilder<> Builder(MatMul); 1914 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1915 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1916 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1917 1918 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1919 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1920 assert(Lhs.getElementType() == Rhs.getElementType() && 1921 "Matrix multiply argument element types do not match."); 1922 1923 const unsigned R = LShape.NumRows; 1924 const unsigned C = RShape.NumColumns; 1925 assert(LShape.NumColumns == RShape.NumRows); 1926 1927 // Initialize the output 1928 MatrixTy Result(R, C, EltType); 1929 assert(Lhs.getElementType() == Result.getElementType() && 1930 "Matrix multiply result element type does not match arguments."); 1931 1932 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 1933 getFastMathFlags(MatMul)); 1934 finalizeLowering(MatMul, Result, Builder); 1935 } 1936 1937 /// Lowers llvm.matrix.transpose. 1938 void LowerTranspose(CallInst *Inst) { 1939 MatrixTy Result; 1940 IRBuilder<> Builder(Inst); 1941 Value *InputVal = Inst->getArgOperand(0); 1942 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1943 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1944 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1945 1946 const unsigned NewNumVecs = 1947 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 1948 const unsigned NewNumElts = 1949 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1950 1951 for (unsigned I = 0; I < NewNumVecs; ++I) { 1952 // Build a single result vector. First initialize it. 1953 Value *ResultVector = PoisonValue::get( 1954 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1955 // Go through the old elements and insert it into the resulting vector. 1956 for (auto J : enumerate(InputMatrix.vectors())) { 1957 Value *Elt = Builder.CreateExtractElement(J.value(), I); 1958 // Row and column indices are transposed. 1959 ResultVector = 1960 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1961 } 1962 Result.addVector(ResultVector); 1963 } 1964 1965 // TODO: Improve estimate of operations needed for transposes. Currently we 1966 // just count the insertelement/extractelement instructions, but do not 1967 // account for later simplifications/combines. 1968 finalizeLowering( 1969 Inst, 1970 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 1971 .addNumExposedTransposes(1), 1972 Builder); 1973 } 1974 1975 /// Lower load instructions, if shape information is available. 1976 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1977 auto I = ShapeMap.find(Inst); 1978 if (I == ShapeMap.end()) 1979 return false; 1980 1981 LowerLoad(Inst, Ptr, Inst->getAlign(), 1982 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1983 I->second); 1984 return true; 1985 } 1986 1987 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1988 IRBuilder<> &Builder) { 1989 auto I = ShapeMap.find(StoredVal); 1990 if (I == ShapeMap.end()) 1991 return false; 1992 1993 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 1994 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1995 I->second); 1996 return true; 1997 } 1998 1999 /// Lower binary operators, if shape information is available. 
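  ///
  /// Both operands are split into their column (or row) vectors and the
  /// scalar opcode is re-emitted once per vector; e.g. an fadd of two 4x2
  /// column-major matrixes becomes two <4 x float> fadds (element type
  /// illustrative).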
2000 bool VisitBinaryOperator(BinaryOperator *Inst) { 2001 auto I = ShapeMap.find(Inst); 2002 if (I == ShapeMap.end()) 2003 return false; 2004 2005 Value *Lhs = Inst->getOperand(0); 2006 Value *Rhs = Inst->getOperand(1); 2007 2008 IRBuilder<> Builder(Inst); 2009 ShapeInfo &Shape = I->second; 2010 2011 MatrixTy Result; 2012 MatrixTy A = getMatrix(Lhs, Shape, Builder); 2013 MatrixTy B = getMatrix(Rhs, Shape, Builder); 2014 assert(A.isColumnMajor() == B.isColumnMajor() && 2015 Result.isColumnMajor() == A.isColumnMajor() && 2016 "operands must agree on matrix layout"); 2017 2018 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2019 2020 // Helper to perform binary op on vectors. 2021 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 2022 switch (Inst->getOpcode()) { 2023 case Instruction::Add: 2024 return Builder.CreateAdd(LHS, RHS); 2025 case Instruction::Mul: 2026 return Builder.CreateMul(LHS, RHS); 2027 case Instruction::Sub: 2028 return Builder.CreateSub(LHS, RHS); 2029 case Instruction::FAdd: 2030 return Builder.CreateFAdd(LHS, RHS); 2031 case Instruction::FMul: 2032 return Builder.CreateFMul(LHS, RHS); 2033 case Instruction::FSub: 2034 return Builder.CreateFSub(LHS, RHS); 2035 default: 2036 llvm_unreachable("Unsupported binary operator for matrix"); 2037 } 2038 }; 2039 2040 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2041 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 2042 2043 finalizeLowering(Inst, 2044 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2045 Result.getNumVectors()), 2046 Builder); 2047 return true; 2048 } 2049 2050 /// Lower unary operators, if shape information is available. 2051 bool VisitUnaryOperator(UnaryOperator *Inst) { 2052 auto I = ShapeMap.find(Inst); 2053 if (I == ShapeMap.end()) 2054 return false; 2055 2056 Value *Op = Inst->getOperand(0); 2057 2058 IRBuilder<> Builder(Inst); 2059 ShapeInfo &Shape = I->second; 2060 2061 MatrixTy Result; 2062 MatrixTy M = getMatrix(Op, Shape, Builder); 2063 2064 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2065 2066 // Helper to perform unary op on vectors. 2067 auto BuildVectorOp = [&Builder, Inst](Value *Op) { 2068 switch (Inst->getOpcode()) { 2069 case Instruction::FNeg: 2070 return Builder.CreateFNeg(Op); 2071 default: 2072 llvm_unreachable("Unsupported unary operator for matrix"); 2073 } 2074 }; 2075 2076 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2077 Result.addVector(BuildVectorOp(M.getVector(I))); 2078 2079 finalizeLowering(Inst, 2080 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2081 Result.getNumVectors()), 2082 Builder); 2083 return true; 2084 } 2085 2086 /// Helper to linearize a matrix expression tree into a string. Currently 2087 /// matrix expressions are linarized by starting at an expression leaf and 2088 /// linearizing bottom up. 2089 struct ExprLinearizer { 2090 unsigned LengthToBreak = 100; 2091 std::string Str; 2092 raw_string_ostream Stream; 2093 unsigned LineLength = 0; 2094 const DataLayout &DL; 2095 2096 /// Mapping from instructions to matrixes. It is used to identify 2097 /// matrix instructions. 2098 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2099 2100 /// Mapping from values to the leaves of all expressions that the value is 2101 /// part of. 2102 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 2103 2104 /// Set of matrix expressions in the scope of a given DISubprogram. 2105 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 2106 2107 /// Leaf node of the expression to linearize. 
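    /// (Typically a store or another matrix instruction without further
    /// matrix users in the subprogram; see
    /// RemarkGenerator::getExpressionLeaves.)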
2108 Value *Leaf; 2109 2110 /// Used to keep track of sub-expressions that get reused while linearizing 2111 /// the expression. Re-used sub-expressions are marked as (reused). 2112 SmallPtrSet<Value *, 8> ReusedExprs; 2113 2114 ExprLinearizer(const DataLayout &DL, 2115 const MapVector<Value *, MatrixTy> &Inst2Matrix, 2116 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2117 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2118 Value *Leaf) 2119 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 2120 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 2121 2122 void indent(unsigned N) { 2123 LineLength += N; 2124 for (unsigned i = 0; i < N; i++) 2125 Stream << " "; 2126 } 2127 2128 void lineBreak() { 2129 Stream << "\n"; 2130 LineLength = 0; 2131 } 2132 2133 void maybeIndent(unsigned Indent) { 2134 if (LineLength >= LengthToBreak) 2135 lineBreak(); 2136 2137 if (LineLength == 0) 2138 indent(Indent); 2139 } 2140 2141 void write(StringRef S) { 2142 LineLength += S.size(); 2143 Stream << S; 2144 } 2145 2146 Value *getUnderlyingObjectThroughLoads(Value *V) { 2147 if (Value *Ptr = getPointerOperand(V)) 2148 return getUnderlyingObjectThroughLoads(Ptr); 2149 else if (V->getType()->isPointerTy()) 2150 return getUnderlyingObject(V); 2151 return V; 2152 } 2153 2154 /// Returns true if \p V is a matrix value in the given subprogram. 2155 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 2156 2157 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 2158 /// \p SS. 2159 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 2160 auto M = Inst2Matrix.find(V); 2161 if (M == Inst2Matrix.end()) 2162 SS << "unknown"; 2163 else { 2164 SS << M->second.getNumRows(); 2165 SS << "x"; 2166 SS << M->second.getNumColumns(); 2167 } 2168 } 2169 2170 /// Write the called function name. Handles calls to llvm.matrix.* 2171 /// specially: we write the name, followed by the dimensions of the input 2172 /// matrixes, followed by the scalar type name. 2173 void writeFnName(CallInst *CI) { 2174 if (!CI->getCalledFunction()) 2175 write("<no called fn>"); 2176 else { 2177 StringRef Name = CI->getCalledFunction()->getName(); 2178 if (!Name.starts_with("llvm.matrix")) { 2179 write(Name); 2180 return; 2181 } 2182 auto *II = cast<IntrinsicInst>(CI); 2183 write(Intrinsic::getBaseName(II->getIntrinsicID()) 2184 .drop_front(StringRef("llvm.matrix.").size())); 2185 write("."); 2186 std::string Tmp; 2187 raw_string_ostream SS(Tmp); 2188 2189 switch (II->getIntrinsicID()) { 2190 case Intrinsic::matrix_multiply: 2191 prettyPrintMatrixType(II->getOperand(0), SS); 2192 SS << "."; 2193 prettyPrintMatrixType(II->getOperand(1), SS); 2194 SS << "." << *II->getType()->getScalarType(); 2195 break; 2196 case Intrinsic::matrix_transpose: 2197 prettyPrintMatrixType(II->getOperand(0), SS); 2198 SS << "." << *II->getType()->getScalarType(); 2199 break; 2200 case Intrinsic::matrix_column_major_load: 2201 prettyPrintMatrixType(II, SS); 2202 SS << "." << *II->getType()->getScalarType(); 2203 break; 2204 case Intrinsic::matrix_column_major_store: 2205 prettyPrintMatrixType(II->getOperand(0), SS); 2206 SS << "." 
<< *II->getOperand(0)->getType()->getScalarType(); 2207 break; 2208 default: 2209 llvm_unreachable("Unhandled case"); 2210 } 2211 SS.flush(); 2212 write(Tmp); 2213 } 2214 } 2215 2216 unsigned getNumShapeArgs(CallInst *CI) const { 2217 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 2218 switch (II->getIntrinsicID()) { 2219 case Intrinsic::matrix_multiply: 2220 return 3; 2221 case Intrinsic::matrix_transpose: 2222 return 2; 2223 case Intrinsic::matrix_column_major_load: 2224 case Intrinsic::matrix_column_major_store: 2225 return 3; 2226 default: 2227 return 0; 2228 } 2229 } 2230 return 0; 2231 } 2232 2233 /// Special printing for values: for pointers, we print if they refer to an 2234 /// (function) external address or a stack address, for other values we 2235 /// either print the constant or "scalar"/"matrix" for other values. 2236 void write(Value *V) { 2237 V = getUnderlyingObjectThroughLoads(V); 2238 if (V->getType()->isPointerTy()) { 2239 if (isa<AllocaInst>(V)) { 2240 Stream << "stack addr"; 2241 LineLength += StringRef("stack addr").size(); 2242 } else { 2243 Stream << "addr"; 2244 LineLength += StringRef("addr").size(); 2245 } 2246 if (!V->getName().empty()) { 2247 Stream << " %" << V->getName() << ""; 2248 LineLength += V->getName().size() + 2; 2249 } 2250 return; 2251 } 2252 2253 std::string Tmp; 2254 raw_string_ostream TmpStream(Tmp); 2255 2256 if (auto *CI = dyn_cast<ConstantInt>(V)) 2257 TmpStream << CI->getValue(); 2258 else if (isa<Constant>(V)) 2259 TmpStream << "constant"; 2260 else { 2261 if (isMatrix(V)) 2262 TmpStream << "matrix"; 2263 else 2264 TmpStream << "scalar"; 2265 } 2266 TmpStream.flush(); 2267 Tmp = std::string(StringRef(Tmp).trim()); 2268 LineLength += Tmp.size(); 2269 Stream << Tmp; 2270 } 2271 2272 /// Linearize expression \p Expr starting at an indentation of \p Indent. 2273 /// Expressions that are re-used multiple times are prefixed with (reused) 2274 /// at the re-used root instruction. 2275 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 2276 bool ParentShared) { 2277 auto *I = cast<Instruction>(Expr); 2278 maybeIndent(Indent); 2279 SmallVector<Value *, 8> Ops; 2280 2281 // Is Expr shared with other expression leaves? 2282 bool ExprShared = false; 2283 2284 // Deal with shared subtrees. Mark them as shared, if required. 2285 if (!ParentShared) { 2286 auto SI = Shared.find(Expr); 2287 assert(SI != Shared.end() && SI->second.count(Leaf)); 2288 2289 for (Value *S : SI->second) { 2290 if (S == Leaf) 2291 continue; 2292 DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 2293 write("shared with remark at line " + std::to_string(DL.getLine()) + 2294 " column " + std::to_string(DL.getCol()) + " ("); 2295 } 2296 ExprShared = SI->second.size() > 1; 2297 } 2298 2299 bool Reused = !ReusedExprs.insert(Expr).second; 2300 if (Reused && !ParentReused) 2301 write("(reused) "); 2302 2303 if (auto *CI = dyn_cast<CallInst>(I)) { 2304 writeFnName(CI); 2305 2306 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 2307 } else if (isa<BitCastInst>(Expr)) { 2308 // Special case bitcasts, which are used to materialize matrixes from 2309 // non-matrix ops. 
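        // (For instance, a plain vector produced outside the lowered
        // expression and bitcast before feeding a matrix intrinsic; it is
        // simply printed as "matrix".)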
2310         write("matrix");
2311         return;
2312       } else {
2313         Ops.append(I->value_op_begin(), I->value_op_end());
2314         write(std::string(I->getOpcodeName()));
2315       }
2316
2317       write(std::string("("));
2318
2319       unsigned NumOpsToBreak = 1;
2320       if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2321         NumOpsToBreak = 2;
2322
2323       for (Value *Op : Ops) {
2324         if (Ops.size() > NumOpsToBreak)
2325           lineBreak();
2326
2327         maybeIndent(Indent + 1);
2328         if (isMatrix(Op))
2329           linearizeExpr(Op, Indent + 1, Reused, ExprShared);
2330         else
2331           write(Op);
2332         if (Op != Ops.back())
2333           write(", ");
2334       }
2335
2336       write(")");
2337     }
2338
2339     const std::string &getResult() {
2340       Stream.flush();
2341       return Str;
2342     }
2343   };
2344
2345   /// Generate remarks for matrix operations in a function. To generate remarks
2346   /// for matrix expressions, the following approach is used:
2347   /// 1. Use the inlined-at debug information to group matrix operations to the
2348   ///    DISubprograms they are contained in.
2349   /// 2. Collect leaves of matrix expressions (done in
2350   ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
2351   ///    mapping. Leaves are lowered matrix instructions without other matrix
2352   ///    users (like stores) in the current subprogram.
2353   /// 3. For each leaf, create a remark containing a linearized version of the
2354   ///    matrix expression. The expression is linearized by a recursive
2355   ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
2356   ///    that multiple leaves can share sub-expressions. Shared subexpressions
2357   ///    are explicitly marked as shared().
2358   struct RemarkGenerator {
2359     const MapVector<Value *, MatrixTy> &Inst2Matrix;
2360     OptimizationRemarkEmitter &ORE;
2361     Function &Func;
2362     const DataLayout &DL;
2363
2364     RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2365                     OptimizationRemarkEmitter &ORE, Function &Func)
2366         : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2367           DL(Func.getParent()->getDataLayout()) {}
2368
2369     /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2370     /// instructions in Inst2Matrix returning void or without any users in
2371     /// \p ExprsInSubprogram. Currently that should only include stores.
2372     SmallVector<Value *, 4>
2373     getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2374       SmallVector<Value *, 4> Leaves;
2375       for (auto *Expr : ExprsInSubprogram)
2376         if (Expr->getType()->isVoidTy() ||
2377             !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
2378               return ExprsInSubprogram.count(U);
2379             }))
2380           Leaves.push_back(Expr);
2381       return Leaves;
2382     }
2383
2384     /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2385     /// to all visited expressions in \p Shared. Limit the matrix operations to
2386     /// the ones in \p ExprsInSubprogram.
2387     void collectSharedInfo(Value *Leaf, Value *V,
2388                            const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2389                            DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2390
2391       if (!ExprsInSubprogram.count(V))
2392         return;
2393
2394       auto I = Shared.insert({V, {}});
2395       I.first->second.insert(Leaf);
2396
2397       for (Value *Op : cast<Instruction>(V)->operand_values())
2398         collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2399     }
2400
2401     /// Calculate the number of exclusive and shared op counts for the
2402     /// expression starting at \p Root. Expressions used multiple times are
2403     /// counted once. Limit the matrix operations to the ones in \p ExprsInSubprogram.
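    ///
    /// E.g. if a loaded matrix feeds multiplies that end in two different
    /// leaves, its ops are attributed to the shared counts rather than to
    /// either leaf's exclusive counts.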
2404     std::pair<OpInfoTy, OpInfoTy>
2405     sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2406                const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2407                DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2408       if (!ExprsInSubprogram.count(Root))
2409         return {};
2410
2411       // Already counted this expression. Stop.
2412       if (!ReusedExprs.insert(Root).second)
2413         return {};
2414
2415       OpInfoTy SharedCount;
2416       OpInfoTy Count;
2417
2418       auto I = Shared.find(Root);
2419       auto CM = Inst2Matrix.find(Root);
2420       if (I->second.size() == 1)
2421         Count = CM->second.getOpInfo();
2422       else
2423         SharedCount = CM->second.getOpInfo();
2424
2425       for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2426         auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2427         Count += C.first;
2428         SharedCount += C.second;
2429       }
2430       return {Count, SharedCount};
2431     }
2432
2433     void emitRemarks() {
2434       if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2435         return;
2436
2437       // Map matrix operations to their containing subprograms, by traversing
2438       // the inlinedAt chain. If the function does not have a DISubprogram, we
2439       // only map them to the containing function.
2440       MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2441       for (const auto &KV : Inst2Matrix) {
2442         if (Func.getSubprogram()) {
2443           auto *I = cast<Instruction>(KV.first);
2444           DILocation *Context = I->getDebugLoc();
2445           while (Context) {
2446             auto I =
2447                 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2448             I.first->second.push_back(KV.first);
2449             Context = DebugLoc(Context).getInlinedAt();
2450           }
2451         } else {
2452           auto I = Subprog2Exprs.insert({nullptr, {}});
2453           I.first->second.push_back(KV.first);
2454         }
2455       }
2456       for (auto &KV : Subprog2Exprs) {
2457         SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2458                                                       KV.second.end());
2459         auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2460
2461         DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2462         for (Value *Leaf : Leaves)
2463           collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2464
2465         // Generate remarks for each leaf.
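        // Each one reads roughly like (numbers and location illustrative):
        //   remark: test.cpp:35:3: Lowered with 4 stores, 8 loads, 48 compute
        //   ops, 0 exposed transposes
        // followed by the linearized expression tree.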
2466 for (auto *L : Leaves) { 2467 2468 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 2469 DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 2470 while (Context) { 2471 if (getSubprogram(Context->getScope()) == KV.first) { 2472 Loc = Context; 2473 break; 2474 } 2475 Context = DebugLoc(Context).getInlinedAt(); 2476 } 2477 2478 SmallPtrSet<Value *, 8> ReusedExprs; 2479 OpInfoTy Counts, SharedCounts; 2480 std::tie(Counts, SharedCounts) = 2481 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 2482 2483 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 2484 cast<Instruction>(L)->getParent()); 2485 2486 Rem << "Lowered with "; 2487 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 2488 << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 2489 << ore::NV("NumComputeOps", Counts.NumComputeOps) 2490 << " compute ops, " 2491 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2492 << " exposed transposes"; 2493 2494 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 2495 SharedCounts.NumComputeOps > 0) { 2496 Rem << ",\nadditionally " 2497 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 2498 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 2499 << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 2500 << " compute ops" 2501 << " are shared with other expressions"; 2502 } 2503 2504 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 2505 ORE.emit(Rem); 2506 } 2507 } 2508 } 2509 2510 std::string 2511 linearize(Value *L, 2512 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2513 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2514 const DataLayout &DL) { 2515 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 2516 Lin.linearizeExpr(L, 0, false, false); 2517 return Lin.getResult(); 2518 } 2519 }; 2520 }; 2521 } // namespace 2522 2523 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2524 FunctionAnalysisManager &AM) { 2525 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2526 OptimizationRemarkEmitter *ORE = nullptr; 2527 AAResults *AA = nullptr; 2528 DominatorTree *DT = nullptr; 2529 LoopInfo *LI = nullptr; 2530 2531 if (!Minimal) { 2532 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 2533 AA = &AM.getResult<AAManager>(F); 2534 DT = &AM.getResult<DominatorTreeAnalysis>(F); 2535 LI = &AM.getResult<LoopAnalysis>(F); 2536 } 2537 2538 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 2539 if (LMT.Visit()) { 2540 PreservedAnalyses PA; 2541 if (!Minimal) { 2542 PA.preserve<LoopAnalysis>(); 2543 PA.preserve<DominatorTreeAnalysis>(); 2544 } 2545 return PA; 2546 } 2547 return PreservedAnalyses::all(); 2548 } 2549 2550 void LowerMatrixIntrinsicsPass::printPipeline( 2551 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2552 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2553 OS, MapClassName2PassName); 2554 OS << '<'; 2555 if (Minimal) 2556 OS << "minimal"; 2557 OS << '>'; 2558 } 2559