//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//  * Improve cost-modeling, e.g. choose different number of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
162 // -> returns Base + (1 * 4) 163 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) 164 // -> returns Base + (2 * 4) 165 // 166 // The graphic below illustrates the number of elements in a column (marked 167 // with |) and the number of skipped elements (marked with }). 168 // 169 // v_0_0 v_0_1 {v_0_2 {v_0_3 170 // Base Col 1 Col 2 171 // | | | 172 // v_1_0 |v_1_1 |v_1_2 |v_1_3 173 // v_2_0 |v_2_1 |v_2_2 |v_2_3 174 // v_3_0 {v_3_1 {v_3_2 v_3_3 175 // 176 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 177 unsigned NumElements, Type *EltType, 178 IRBuilder<> &Builder) { 179 180 assert((!isa<ConstantInt>(Stride) || 181 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 182 "Stride must be >= the number of elements in the result vector."); 183 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 184 185 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 186 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 187 188 // Get pointer to the start of the selected vector. Skip GEP creation, 189 // if we select vector 0. 190 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 191 VecStart = BasePtr; 192 else 193 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 194 195 // Cast elementwise vector start pointer to a pointer to a vector 196 // (EltType x NumElements)*. 197 auto *VecType = FixedVectorType::get(EltType, NumElements); 198 Type *VecPtrType = PointerType::get(VecType, AS); 199 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 200 } 201 202 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 203 /// 204 /// Currently, the lowering for each matrix intrinsic is done as follows: 205 /// 1. Propagate the shape information from intrinsics to connected 206 /// instructions. 207 /// 2. Lower instructions with shape information (assuming column-major layout). 208 /// The lowering works similarly using row-major layout. 209 /// 2.1. Get column vectors for each argument. If we already lowered the 210 /// definition of an argument, use the produced column vectors directly. 211 /// If not, split the operand vector containing an embedded matrix into 212 /// a set of column vectors, 213 /// 2.2. Lower the instruction in terms of column major operations, which 214 /// yields a set of column vectors containing result matrix. Note that we 215 /// lower all instructions that have shape information. Besides the 216 /// intrinsics, this includes stores for example. 217 /// 2.3. Update uses of the lowered instruction. If we have shape information 218 /// for a user, there is nothing to do, as we will look up the result 219 /// column matrix when lowering the user. For other uses, we embed the 220 /// result matrix in a flat vector and update the use. 221 /// 2.4. Cache the result column matrix for the instruction we lowered 222 /// 3. After we lowered all instructions in a function, remove the now 223 /// obsolete instructions. 224 /// 225 class LowerMatrixIntrinsics { 226 Function &Func; 227 const DataLayout &DL; 228 const TargetTransformInfo &TTI; 229 AliasAnalysis *AA; 230 DominatorTree *DT; 231 LoopInfo *LI; 232 OptimizationRemarkEmitter *ORE; 233 234 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 235 struct OpInfoTy { 236 /// Number of stores emitted to generate this matrix. 
237 unsigned NumStores = 0; 238 /// Number of loads emitted to generate this matrix. 239 unsigned NumLoads = 0; 240 /// Number of compute operations emitted to generate this matrix. 241 unsigned NumComputeOps = 0; 242 /// Most of the time transposes can be fused with matrix multiplies or can 243 /// be folded away via algebraic simplifications. This is the number of 244 /// transposes that we failed to make "free" via such optimizations. 245 unsigned NumExposedTransposes = 0; 246 247 OpInfoTy &operator+=(const OpInfoTy &RHS) { 248 NumStores += RHS.NumStores; 249 NumLoads += RHS.NumLoads; 250 NumComputeOps += RHS.NumComputeOps; 251 NumExposedTransposes += RHS.NumExposedTransposes; 252 return *this; 253 } 254 }; 255 256 /// Wrapper class representing a matrix as a set of vectors, either in row or 257 /// column major layout. All vectors must have the same vector type. 258 class MatrixTy { 259 SmallVector<Value *, 16> Vectors; 260 261 OpInfoTy OpInfo; 262 263 bool IsColumnMajor = true; 264 265 public: 266 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 267 MatrixTy(ArrayRef<Value *> Vectors) 268 : Vectors(Vectors.begin(), Vectors.end()), 269 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 270 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 271 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 272 273 unsigned D = isColumnMajor() ? NumColumns : NumRows; 274 for (unsigned J = 0; J < D; ++J) 275 addVector(PoisonValue::get(FixedVectorType::get( 276 EltTy, isColumnMajor() ? NumRows : NumColumns))); 277 } 278 279 Value *getVector(unsigned i) const { return Vectors[i]; } 280 Value *getColumn(unsigned i) const { 281 assert(isColumnMajor() && "only supported for column-major matrixes"); 282 return Vectors[i]; 283 } 284 Value *getRow(unsigned i) const { 285 assert(!isColumnMajor() && "only supported for row-major matrixes"); 286 return Vectors[i]; 287 } 288 289 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 290 291 Type *getElementType() const { return getVectorTy()->getElementType(); } 292 293 unsigned getNumVectors() const { 294 if (isColumnMajor()) 295 return getNumColumns(); 296 return getNumRows(); 297 } 298 299 unsigned getNumColumns() const { 300 if (isColumnMajor()) 301 return Vectors.size(); 302 else { 303 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 304 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 305 } 306 } 307 unsigned getNumRows() const { 308 if (isColumnMajor()) { 309 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 310 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 311 } else 312 return Vectors.size(); 313 } 314 315 void addVector(Value *V) { Vectors.push_back(V); } 316 VectorType *getColumnTy() { 317 assert(isColumnMajor() && "only supported for column-major matrixes"); 318 return getVectorTy(); 319 } 320 321 VectorType *getVectorTy() const { 322 return cast<VectorType>(Vectors[0]->getType()); 323 } 324 325 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 326 assert(isColumnMajor() && 327 "columns() only supported for column-major matrixes"); 328 return make_range(Vectors.begin(), Vectors.end()); 329 } 330 331 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 332 return make_range(Vectors.begin(), Vectors.end()); 333 } 334 335 /// Embed the vectors of the matrix into a flat vector by concatenating 336 /// them. 
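    /// For a column-major MxN matrix this yields a flat vector of M * N
    /// elements, with the columns laid out one after the other.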
337 Value *embedInVector(IRBuilder<> &Builder) const { 338 return Vectors.size() == 1 ? Vectors[0] 339 : concatenateVectors(Builder, Vectors); 340 } 341 342 MatrixTy &addNumLoads(unsigned N) { 343 OpInfo.NumLoads += N; 344 return *this; 345 } 346 347 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 348 349 MatrixTy &addNumStores(unsigned N) { 350 OpInfo.NumStores += N; 351 return *this; 352 } 353 354 MatrixTy &addNumExposedTransposes(unsigned N) { 355 OpInfo.NumExposedTransposes += N; 356 return *this; 357 } 358 359 MatrixTy &addNumComputeOps(unsigned N) { 360 OpInfo.NumComputeOps += N; 361 return *this; 362 } 363 364 unsigned getNumStores() const { return OpInfo.NumStores; } 365 unsigned getNumLoads() const { return OpInfo.NumLoads; } 366 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 367 368 const OpInfoTy &getOpInfo() const { return OpInfo; } 369 370 bool isColumnMajor() const { return IsColumnMajor; } 371 372 unsigned getStride() const { 373 if (isColumnMajor()) 374 return getNumRows(); 375 return getNumColumns(); 376 } 377 378 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 379 /// matrix is column-major, the result vector is extracted from a column 380 /// vector, otherwise from a row vector. 381 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 382 IRBuilder<> &Builder) const { 383 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 384 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 385 NumElts && 386 "Extracted vector will contain poison values"); 387 return Builder.CreateShuffleVector( 388 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 389 "block"); 390 } 391 }; 392 393 struct ShapeInfo { 394 unsigned NumRows; 395 unsigned NumColumns; 396 397 bool IsColumnMajor; 398 399 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 400 : NumRows(NumRows), NumColumns(NumColumns), 401 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 402 403 ShapeInfo(Value *NumRows, Value *NumColumns) 404 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 405 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 406 407 bool operator==(const ShapeInfo &other) { 408 return NumRows == other.NumRows && NumColumns == other.NumColumns; 409 } 410 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 411 412 /// Returns true if shape-information is defined, meaning both dimensions 413 /// are != 0. 414 operator bool() const { 415 assert(NumRows == 0 || NumColumns != 0); 416 return NumRows != 0; 417 } 418 419 unsigned getStride() const { 420 if (IsColumnMajor) 421 return NumRows; 422 return NumColumns; 423 } 424 425 unsigned getNumVectors() const { 426 if (IsColumnMajor) 427 return NumColumns; 428 return NumRows; 429 } 430 431 /// Returns the transposed shape. 432 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } 433 }; 434 435 /// Maps instructions to their shape information. The shape information 436 /// describes the shape to be used while lowering. This matches the shape of 437 /// the result value of the instruction, with the only exceptions being store 438 /// instructions and the matrix_column_major_store intrinsics. For those, the 439 /// shape information indicates that those instructions should be lowered 440 /// using shape information as well. A ValueMap is used so that when 441 /// sub-passes like optimizeTransposes performs RAUW the map stays 442 /// up-to-date. 443 ValueMap<Value *, ShapeInfo> ShapeMap; 444 445 /// List of instructions to remove. 
While lowering, we are not replacing all 446 /// users of a lowered instruction, if shape information is available and 447 /// those need to be removed after we finished lowering. 448 SmallVector<Instruction *, 16> ToRemove; 449 450 /// Map from instructions to their produced column matrix. 451 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 452 453 private: 454 static FastMathFlags getFastMathFlags(Instruction *Inst) { 455 FastMathFlags FMF; 456 457 if (isa<FPMathOperator>(*Inst)) 458 FMF = Inst->getFastMathFlags(); 459 460 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 461 462 return FMF; 463 } 464 465 public: 466 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 467 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 468 OptimizationRemarkEmitter *ORE) 469 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 470 LI(LI), ORE(ORE) {} 471 472 unsigned getNumOps(Type *VT) { 473 assert(isa<VectorType>(VT) && "Expected vector type"); 474 return getNumOps(VT->getScalarType(), 475 cast<FixedVectorType>(VT)->getNumElements()); 476 } 477 478 /// Is this the minimal version executed in the backend pipelines. 479 bool isMinimal() const { 480 return !DT; 481 } 482 483 /// Return the estimated number of vector ops required for an operation on 484 /// \p VT * N. 485 unsigned getNumOps(Type *ST, unsigned N) { 486 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / 487 double(TTI.getRegisterBitWidth( 488 TargetTransformInfo::RGK_FixedWidthVector) 489 .getFixedValue())); 490 } 491 492 /// Return the set of vectors that a matrix value is lowered to. 493 /// 494 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 495 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 496 /// into vectors. 497 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 498 IRBuilder<> &Builder) { 499 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 500 assert(VType && "MatrixVal must be a vector type"); 501 assert(cast<FixedVectorType>(VType)->getNumElements() == 502 SI.NumRows * SI.NumColumns && 503 "The vector size must match the number of matrix elements"); 504 505 // Check if we lowered MatrixVal using shape information. In that case, 506 // return the existing matrix, if it matches the requested shape 507 // information. If there is a mis-match, embed the result in a flat 508 // vector and split it later. 509 auto Found = Inst2ColumnMatrix.find(MatrixVal); 510 if (Found != Inst2ColumnMatrix.end()) { 511 MatrixTy &M = Found->second; 512 // Return the found matrix, if its shape matches the requested shape 513 // information 514 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 515 return M; 516 517 MatrixVal = M.embedInVector(Builder); 518 } 519 520 // Otherwise split MatrixVal. 521 SmallVector<Value *, 16> SplitVecs; 522 for (unsigned MaskStart = 0; 523 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 524 MaskStart += SI.getStride()) { 525 Value *V = Builder.CreateShuffleVector( 526 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 527 "split"); 528 SplitVecs.push_back(V); 529 } 530 531 return {SplitVecs}; 532 } 533 534 /// If \p V already has a known shape return false. Otherwise set the shape 535 /// for instructions that support it. 
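  /// Returns true only if new shape information was recorded for \p V.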
536 bool setShapeInfo(Value *V, ShapeInfo Shape) { 537 assert(Shape && "Shape not set"); 538 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 539 return false; 540 541 auto SIter = ShapeMap.find(V); 542 if (SIter != ShapeMap.end()) { 543 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || 544 SIter->second.NumColumns != Shape.NumColumns)) { 545 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" 546 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" 547 << Shape.NumColumns << ") for " << *V << "\n"; 548 report_fatal_error( 549 "Matrix shape verification failed, compilation aborted!"); 550 } 551 552 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 553 << SIter->second.NumRows << " " 554 << SIter->second.NumColumns << " for " << *V << "\n"); 555 return false; 556 } 557 558 ShapeMap.insert({V, Shape}); 559 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 560 << " for " << *V << "\n"); 561 return true; 562 } 563 564 bool isUniformShape(Value *V) { 565 Instruction *I = dyn_cast<Instruction>(V); 566 if (!I) 567 return true; 568 569 switch (I->getOpcode()) { 570 case Instruction::FAdd: 571 case Instruction::FSub: 572 case Instruction::FMul: // Scalar multiply. 573 case Instruction::FNeg: 574 case Instruction::Add: 575 case Instruction::Mul: 576 case Instruction::Sub: 577 return true; 578 default: 579 return false; 580 } 581 } 582 583 /// Returns true if shape information can be used for \p V. The supported 584 /// instructions must match the instructions that can be lowered by this pass. 585 bool supportsShapeInfo(Value *V) { 586 Instruction *Inst = dyn_cast<Instruction>(V); 587 if (!Inst) 588 return false; 589 590 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 591 if (II) 592 switch (II->getIntrinsicID()) { 593 case Intrinsic::matrix_multiply: 594 case Intrinsic::matrix_transpose: 595 case Intrinsic::matrix_column_major_load: 596 case Intrinsic::matrix_column_major_store: 597 return true; 598 default: 599 return false; 600 } 601 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 602 } 603 604 /// Propagate the shape information of instructions to their users. 605 /// The work list contains instructions for which we can compute the shape, 606 /// either based on the information provided by matrix intrinsics or known 607 /// shapes of operands. 608 SmallVector<Instruction *, 32> 609 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 610 SmallVector<Instruction *, 32> NewWorkList; 611 // Pop an element for which we guaranteed to have at least one of the 612 // operand shapes. Add the shape for this and then add users to the work 613 // list. 614 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 615 while (!WorkList.empty()) { 616 Instruction *Inst = WorkList.pop_back_val(); 617 618 // New entry, set the value and insert operands 619 bool Propagate = false; 620 621 Value *MatrixA; 622 Value *MatrixB; 623 Value *M; 624 Value *N; 625 Value *K; 626 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 627 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 628 m_Value(N), m_Value(K)))) { 629 Propagate = setShapeInfo(Inst, {M, K}); 630 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 631 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 632 // Flip dimensions. 
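        // The transpose of an M x N operand is N x M, hence the swapped
        // dimensions below.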
633 Propagate = setShapeInfo(Inst, {N, M}); 634 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 635 m_Value(MatrixA), m_Value(), m_Value(), 636 m_Value(), m_Value(M), m_Value(N)))) { 637 Propagate = setShapeInfo(Inst, {N, M}); 638 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 639 m_Value(), m_Value(), m_Value(), m_Value(M), 640 m_Value(N)))) { 641 Propagate = setShapeInfo(Inst, {M, N}); 642 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 643 auto OpShape = ShapeMap.find(MatrixA); 644 if (OpShape != ShapeMap.end()) 645 setShapeInfo(Inst, OpShape->second); 646 continue; 647 } else if (isUniformShape(Inst)) { 648 // Find the first operand that has a known shape and use that. 649 for (auto &Op : Inst->operands()) { 650 auto OpShape = ShapeMap.find(Op.get()); 651 if (OpShape != ShapeMap.end()) { 652 Propagate |= setShapeInfo(Inst, OpShape->second); 653 break; 654 } 655 } 656 } 657 658 if (Propagate) { 659 NewWorkList.push_back(Inst); 660 for (auto *User : Inst->users()) 661 if (ShapeMap.count(User) == 0) 662 WorkList.push_back(cast<Instruction>(User)); 663 } 664 } 665 666 return NewWorkList; 667 } 668 669 /// Propagate the shape to operands of instructions with shape information. 670 /// \p Worklist contains the instruction for which we already know the shape. 671 SmallVector<Instruction *, 32> 672 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 673 SmallVector<Instruction *, 32> NewWorkList; 674 675 auto pushInstruction = [](Value *V, 676 SmallVectorImpl<Instruction *> &WorkList) { 677 Instruction *I = dyn_cast<Instruction>(V); 678 if (I) 679 WorkList.push_back(I); 680 }; 681 // Pop an element with known shape. Traverse the operands, if their shape 682 // derives from the result shape and is unknown, add it and add them to the 683 // worklist. 684 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 685 while (!WorkList.empty()) { 686 Value *V = WorkList.pop_back_val(); 687 688 size_t BeforeProcessingV = WorkList.size(); 689 if (!isa<Instruction>(V)) 690 continue; 691 692 Value *MatrixA; 693 Value *MatrixB; 694 Value *M; 695 Value *N; 696 Value *K; 697 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 698 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 699 m_Value(N), m_Value(K)))) { 700 if (setShapeInfo(MatrixA, {M, N})) 701 pushInstruction(MatrixA, WorkList); 702 703 if (setShapeInfo(MatrixB, {N, K})) 704 pushInstruction(MatrixB, WorkList); 705 706 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 707 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 708 // Flip dimensions. 709 if (setShapeInfo(MatrixA, {M, N})) 710 pushInstruction(MatrixA, WorkList); 711 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 712 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 713 m_Value(M), m_Value(N)))) { 714 if (setShapeInfo(MatrixA, {M, N})) { 715 pushInstruction(MatrixA, WorkList); 716 } 717 } else if (isa<LoadInst>(V) || 718 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 719 // Nothing to do, no matrix input. 720 } else if (isa<StoreInst>(V)) { 721 // Nothing to do. We forward-propagated to this so we would just 722 // backward propagate to an instruction with an already known shape. 723 } else if (isUniformShape(V)) { 724 // Propagate to all operands. 
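        // Uniform-shape (elementwise) ops produce a result with the same
        // shape as each operand, so the result's shape applies to all of them.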
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop; add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      // Copy the shape before erasing, as erase invalidates the iterator.
      ShapeInfo Shape = S->second;
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, Shape});
    }
    Old.replaceAllUsesWith(New);
  }

  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
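  /// Patterns handled below include (B^T)^T, splat^T, (A * B)^T,
  /// (A * splat)^T and (A + B)^T.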
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    //  RxK  KxC     CxK   KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  R x C        RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    //  RxC  RxC     CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }

  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
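    // Once a binop's result has been replaced, the binop and its transpose
    // operands may be left without uses; drop them so they are not revisited.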
874 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { 875 if (T.use_empty()) 876 T.eraseFromParent(); 877 if (A->use_empty()) 878 cast<Instruction>(A)->eraseFromParent(); 879 if (A != B && B->use_empty()) 880 cast<Instruction>(B)->eraseFromParent(); 881 }; 882 883 Value *A, *B, *AT, *BT; 884 ConstantInt *R, *K, *C; 885 // A^t * B ^t -> (B * A)^t 886 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 887 m_Value(A), m_Value(B), m_ConstantInt(R), 888 m_ConstantInt(K), m_ConstantInt(C))) && 889 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 890 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 891 IRBuilder<> IB(&I); 892 MatrixBuilder Builder(IB); 893 Value *M = Builder.CreateMatrixMultiply( 894 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 895 setShapeInfo(M, {C, R}); 896 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), 897 R->getZExtValue()); 898 updateShapeAndReplaceAllUsesWith(I, NewInst); 899 CleanupBinOp(I, A, B); 900 } 901 // A^t + B ^t -> (A + B)^t 902 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && 903 match(A, m_Intrinsic<Intrinsic::matrix_transpose>( 904 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && 905 match(B, m_Intrinsic<Intrinsic::matrix_transpose>( 906 m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) { 907 IRBuilder<> Builder(&I); 908 Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); 909 setShapeInfo(Add, {C, R}); 910 MatrixBuilder MBuilder(Builder); 911 Instruction *NewInst = MBuilder.CreateMatrixTranspose( 912 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t"); 913 updateShapeAndReplaceAllUsesWith(I, NewInst); 914 CleanupBinOp(I, A, B); 915 } 916 } 917 918 /// Try moving transposes in order to fold them away or into multiplies. 919 void optimizeTransposes() { 920 // First sink all transposes inside matmuls and adds, hoping that we end up 921 // with NN, NT or TN variants. 922 for (BasicBlock &BB : reverse(Func)) { 923 for (auto II = BB.rbegin(); II != BB.rend();) { 924 Instruction &I = *II; 925 // We may remove II. By default continue on the next/prev instruction. 926 ++II; 927 if (Instruction *NewInst = sinkTranspose(I, II)) 928 II = std::next(BasicBlock::reverse_iterator(NewInst)); 929 } 930 } 931 932 // If we have a TT matmul or a TT add, lift the transpose. We may be able 933 // to fold into consuming multiply or add. 934 for (BasicBlock &BB : Func) { 935 for (Instruction &I : llvm::make_early_inc_range(BB)) { 936 liftTranspose(I); 937 } 938 } 939 } 940 941 bool Visit() { 942 SmallVector<Instruction *, 32> WorkList; 943 944 // Initially only the shape of matrix intrinsics is known. 945 // Initialize the work list with ops carrying shape information. 946 for (BasicBlock &BB : Func) 947 for (Instruction &Inst : BB) { 948 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 949 if (!II) 950 continue; 951 952 switch (II->getIntrinsicID()) { 953 case Intrinsic::matrix_multiply: 954 case Intrinsic::matrix_transpose: 955 case Intrinsic::matrix_column_major_load: 956 case Intrinsic::matrix_column_major_store: 957 WorkList.push_back(&Inst); 958 break; 959 default: 960 break; 961 } 962 } 963 964 // Avoid unnecessary work if there are no matrix intrinsics in the function. 965 if (WorkList.empty()) 966 return false; 967 968 // Propagate shapes until nothing changes any longer. 
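    // Forward propagation can expose values whose shape is only determined by
    // their users (and vice versa), so alternate both directions until the
    // worklist is empty.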
969 while (!WorkList.empty()) { 970 WorkList = propagateShapeForward(WorkList); 971 WorkList = propagateShapeBackward(WorkList); 972 } 973 974 if (!isMinimal()) { 975 optimizeTransposes(); 976 if (PrintAfterTransposeOpt) { 977 dbgs() << "Dump after matrix transpose optimization:\n"; 978 Func.print(dbgs()); 979 } 980 } 981 982 bool Changed = false; 983 SmallVector<CallInst *, 16> MaybeFusableInsts; 984 SmallVector<Instruction *, 16> MatrixInsts; 985 986 // First, collect all instructions with shape information and candidates for 987 // fusion (currently only matrix multiplies). 988 ReversePostOrderTraversal<Function *> RPOT(&Func); 989 for (auto *BB : RPOT) 990 for (Instruction &I : *BB) { 991 if (ShapeMap.find(&I) == ShapeMap.end()) 992 continue; 993 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 994 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 995 MatrixInsts.push_back(&I); 996 } 997 998 // Second, try to lower any dot products 999 SmallPtrSet<Instruction *, 16> FusedInsts; 1000 for (CallInst *CI : MaybeFusableInsts) 1001 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); 1002 1003 // Third, try to fuse candidates. 1004 for (CallInst *CI : MaybeFusableInsts) 1005 LowerMatrixMultiplyFused(CI, FusedInsts); 1006 1007 Changed = !FusedInsts.empty(); 1008 1009 // Fourth, lower remaining instructions with shape information. 1010 for (Instruction *Inst : MatrixInsts) { 1011 if (FusedInsts.count(Inst)) 1012 continue; 1013 1014 IRBuilder<> Builder(Inst); 1015 1016 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 1017 Changed |= VisitCallInst(CInst); 1018 1019 Value *Op1; 1020 Value *Op2; 1021 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1022 Changed |= VisitBinaryOperator(BinOp); 1023 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1024 Changed |= VisitUnaryOperator(UnOp); 1025 if (match(Inst, m_Load(m_Value(Op1)))) 1026 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 1027 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 1028 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 1029 } 1030 1031 if (ORE) { 1032 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 1033 RemarkGen.emitRemarks(); 1034 } 1035 1036 // Delete the instructions backwards, as it has a reduced likelihood of 1037 // having to update as many def-use and use-def chains. 1038 // 1039 // Because we add to ToRemove during fusion we can't guarantee that defs 1040 // are before uses. Change uses to poison temporarily as these should get 1041 // removed as well. 1042 // 1043 // For verification, we keep track of where we changed uses to poison in 1044 // PoisonedInsts and then check that we in fact remove them. 1045 SmallSet<Instruction *, 16> PoisonedInsts; 1046 for (auto *Inst : reverse(ToRemove)) { 1047 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1048 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 1049 PoisonedInsts.insert(Poisoned); 1050 U.set(PoisonValue::get(Inst->getType())); 1051 } 1052 Inst->eraseFromParent(); 1053 PoisonedInsts.erase(Inst); 1054 } 1055 if (!PoisonedInsts.empty()) { 1056 // If we didn't remove all poisoned instructions, it's a hard error. 1057 dbgs() << "Poisoned but present instructions:\n"; 1058 for (auto *I : PoisonedInsts) 1059 dbgs() << *I << "\n"; 1060 llvm_unreachable("Poisoned but instruction not removed"); 1061 } 1062 1063 return Changed; 1064 } 1065 1066 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
1067 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 1068 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 1069 Type *EltPtrType = PointerType::get(EltType, AS); 1070 return Builder.CreatePointerCast(BasePtr, EltPtrType); 1071 } 1072 1073 /// Replace intrinsic calls 1074 bool VisitCallInst(CallInst *Inst) { 1075 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 1076 return false; 1077 1078 switch (Inst->getCalledFunction()->getIntrinsicID()) { 1079 case Intrinsic::matrix_multiply: 1080 LowerMultiply(Inst); 1081 break; 1082 case Intrinsic::matrix_transpose: 1083 LowerTranspose(Inst); 1084 break; 1085 case Intrinsic::matrix_column_major_load: 1086 LowerColumnMajorLoad(Inst); 1087 break; 1088 case Intrinsic::matrix_column_major_store: 1089 LowerColumnMajorStore(Inst); 1090 break; 1091 default: 1092 return false; 1093 } 1094 return true; 1095 } 1096 1097 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1098 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1099 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1100 /// non-ConstantInt strides, return the common alignment of the initial 1101 /// alignment and the element size in bytes. 1102 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 1103 MaybeAlign A) const { 1104 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 1105 if (Idx == 0) 1106 return InitialAlign; 1107 1108 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 1109 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 1110 uint64_t StrideInBytes = 1111 ConstStride->getZExtValue() * ElementSizeInBits / 8; 1112 return commonAlignment(InitialAlign, Idx * StrideInBytes); 1113 } 1114 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 1115 } 1116 1117 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1118 /// vectors. 1119 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 1120 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1121 auto *VType = cast<VectorType>(Ty); 1122 Type *EltTy = VType->getElementType(); 1123 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 1124 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 1125 MatrixTy Result; 1126 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1127 Value *GEP = computeVectorAddr( 1128 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1129 Stride, Shape.getStride(), EltTy, Builder); 1130 Value *Vector = Builder.CreateAlignedLoad( 1131 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 1132 IsVolatile, "col.load"); 1133 1134 Result.addVector(Vector); 1135 } 1136 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 1137 Result.getNumVectors()); 1138 } 1139 1140 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1141 /// starting at \p MatrixPtr[I][J]. 
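  /// The element offset of the tile start is J * Stride + I for the
  /// column-major layout handled here.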
1142 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 1143 ShapeInfo MatrixShape, Value *I, Value *J, 1144 ShapeInfo ResultShape, Type *EltTy, 1145 IRBuilder<> &Builder) { 1146 1147 Value *Offset = Builder.CreateAdd( 1148 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1149 1150 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1151 Value *EltPtr = 1152 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1153 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1154 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1155 ResultShape.NumColumns); 1156 Type *TilePtrTy = PointerType::get(TileTy, AS); 1157 Value *TilePtr = 1158 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1159 1160 return loadMatrix(TileTy, TilePtr, Align, 1161 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1162 ResultShape, Builder); 1163 } 1164 1165 /// Lower a load instruction with shape information. 1166 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1167 bool IsVolatile, ShapeInfo Shape) { 1168 IRBuilder<> Builder(Inst); 1169 finalizeLowering(Inst, 1170 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1171 Shape, Builder), 1172 Builder); 1173 } 1174 1175 /// Lowers llvm.matrix.column.major.load. 1176 /// 1177 /// The intrinsic loads a matrix from memory using a stride between columns. 1178 void LowerColumnMajorLoad(CallInst *Inst) { 1179 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1180 "Intrinsic only supports column-major layout!"); 1181 Value *Ptr = Inst->getArgOperand(0); 1182 Value *Stride = Inst->getArgOperand(1); 1183 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1184 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1185 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1186 } 1187 1188 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1189 /// MatrixPtr[I][J]. 1190 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1191 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1192 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1193 Value *Offset = Builder.CreateAdd( 1194 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1195 1196 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1197 Value *EltPtr = 1198 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1199 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1200 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1201 StoreVal.getNumColumns()); 1202 Type *TilePtrTy = PointerType::get(TileTy, AS); 1203 Value *TilePtr = 1204 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1205 1206 storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 1207 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1208 } 1209 1210 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1211 /// vectors. 
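  /// Returns a MatrixTy carrying only the store count for cost accounting;
  /// it holds no value vectors.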
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }

  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
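        // The emitted call has the following form (illustrative; element type
        // and width depend on the operands):
        //   %sum = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> %a,
        //                                               <4 x float> %b,
        //                                               <4 x float> %acc)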
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// vectors. Lowers to vector reduction add instead of sequential add if
  /// reassociation is enabled.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassociation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    auto CanBeFlattened = [this](Value *Op) {
      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
    // Returns the cost benefit of using \p Op with the dot product lowering. If
    // the returned cost is < 0, the argument is cheaper to use in the
    // dot-product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        InstructionCost EmbedCost(0);
        // Roughly estimate the cost for embedding the columns into a vector.
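        // One shuffle per additional column, i.e. N - 1 shuffles in total.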
1384 for (unsigned I = 1; I < N; ++I) 1385 EmbedCost -= 1386 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1387 std::nullopt, TTI::TCK_RecipThroughput); 1388 return EmbedCost; 1389 } 1390 1391 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { 1392 InstructionCost OriginalCost = 1393 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(), 1394 EltTy) * 1395 N; 1396 InstructionCost NewCost = TTI.getArithmeticInstrCost( 1397 cast<Instruction>(Op)->getOpcode(), VecTy); 1398 return NewCost - OriginalCost; 1399 } 1400 1401 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) { 1402 // The transpose can be skipped for the dot product lowering, roughly 1403 // estimate the savings as the cost of embedding the columns in a 1404 // vector. 1405 InstructionCost EmbedCost(0); 1406 for (unsigned I = 1; I < N; ++I) 1407 EmbedCost += 1408 TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1409 std::nullopt, TTI::TCK_RecipThroughput); 1410 return EmbedCost; 1411 } 1412 1413 // Costs for loads. 1414 if (N == 1) 1415 return InstructionCost(0); 1416 1417 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - 1418 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); 1419 }; 1420 auto LHSCost = GetCostForArg(LHS, LShape.NumColumns); 1421 1422 // We compare the costs of a vector.reduce.add to sequential add. 1423 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; 1424 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; 1425 InstructionCost ReductionCost = 1426 TTI.getArithmeticReductionCost( 1427 AddOpCode, cast<VectorType>(LHS->getType()), 1428 IsIntVec ? std::nullopt : std::optional(FMF)) + 1429 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); 1430 InstructionCost SequentialAddCost = 1431 TTI.getArithmeticInstrCost(AddOpCode, ElementType) * 1432 (LShape.NumColumns - 1) + 1433 TTI.getArithmeticInstrCost(MulOpCode, ElementType) * 1434 (LShape.NumColumns); 1435 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) 1436 return; 1437 1438 FusedInsts.insert(MatMul); 1439 IRBuilder<> Builder(MatMul); 1440 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, 1441 this](Value *Op) -> Value * { 1442 // Matmul must be the only user of loads because we don't use LowerLoad 1443 // for row vectors (LowerLoad results in scalar loads and shufflevectors 1444 // instead of single vector load). 1445 if (!CanBeFlattened(Op)) 1446 return Op; 1447 1448 if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { 1449 ShapeMap[Op] = ShapeMap[Op].t(); 1450 return Op; 1451 } 1452 1453 FusedInsts.insert(cast<Instruction>(Op)); 1454 // If vector uses the builtin load, lower to a LoadInst 1455 Value *Arg; 1456 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( 1457 m_Value(Arg)))) { 1458 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); 1459 Op->replaceAllUsesWith(NewLoad); 1460 cast<Instruction>(Op)->eraseFromParent(); 1461 return NewLoad; 1462 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( 1463 m_Value(Arg)))) { 1464 ToRemove.push_back(cast<Instruction>(Op)); 1465 return Arg; 1466 } 1467 1468 return Op; 1469 }; 1470 LHS = FlattenArg(LHS); 1471 1472 // Insert mul/fmul and llvm.vector.reduce.fadd 1473 Value *Mul = 1474 IsIntVec ? 
Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);

    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else {
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);
      cast<Instruction>(Result)->setFastMathFlags(FMF);
    }

    // Pack the scalar back into a matrix and then replace the matmul inst.
    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                         Result, uint64_t(0));
    MatMul->replaceAllUsesWith(Result);
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }

  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition.
  ///
  /// We can fold a transpose into the operand that is used to extract scalars.
  /// This is the first operand with row-major and the second with
  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
  /// operand is transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axis and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axis and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
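      // This is the row-major mirror of the column-major loop above.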
1556 for (unsigned I = 0; I < R; ++I) {
1557 unsigned BlockSize = VF;
1558 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
1559 for (unsigned J = 0; J < C; J += BlockSize) {
1560 // Gradually lower the vectorization factor to cover the remainder.
1561 while (J + BlockSize > C)
1562 BlockSize /= 2;
1563
1564 Value *Sum = nullptr;
1565 for (unsigned K = 0; K < M; ++K) {
1566 Value *R = B.extractVector(K, J, BlockSize, Builder);
1567 Value *LH = Builder.CreateExtractElement(
1568 A.getVector(IsScalarMatrixTransposed ? K : I),
1569 IsScalarMatrixTransposed ? I : K);
1570 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
1571 Sum =
1572 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1573 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1574 }
1575 Result.setVector(I,
1576 insertVector(Result.getVector(I), J, Sum, Builder));
1577 }
1578 }
1579 }
1580 Result.addNumComputeOps(NumComputeOps);
1581 }
1582
1583 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1584 /// copying it to a new location. The new location, or otherwise the original
1585 /// one, is returned.
1586 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1587 CallInst *MatMul) {
1588 MemoryLocation StoreLoc = MemoryLocation::get(Store);
1589 MemoryLocation LoadLoc = MemoryLocation::get(Load);
1590
1591 // If we can statically determine noalias, we're good.
1592 if (AA->isNoAlias(LoadLoc, StoreLoc))
1593 return Load->getPointerOperand();
1594
1595 // Create code to check if the memory locations of the Load and Store
1596 // overlap and if they do, copy Load's operand to a new buffer.
1597
1598 // First, create new blocks for each part of the check and the copy.
1599 BasicBlock *Check0 = MatMul->getParent();
1600 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1601 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1602 // as we adjust Check0 and Check1's branches.
1603 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1604 for (BasicBlock *Succ : successors(Check0))
1605 DTUpdates.push_back({DT->Delete, Check0, Succ});
1606
1607 BasicBlock *Check1 =
1608 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1609 nullptr, "alias_cont");
1610 BasicBlock *Copy =
1611 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1612 nullptr, "copy");
1613 BasicBlock *Fusion =
1614 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
1615 nullptr, "no_alias");
1616
1617 // Check if the loaded memory location begins before the end of the store
1618 // location. If the condition holds, they might overlap, otherwise they are
1619 // guaranteed to not overlap.
1620 IRBuilder<> Builder(MatMul);
1621 Check0->getTerminator()->eraseFromParent();
1622 Builder.SetInsertPoint(Check0);
1623 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
1624 Value *StoreBegin = Builder.CreatePtrToInt(
1625 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1626 Value *StoreEnd = Builder.CreateAdd(
1627 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1628 "store.end", true, true);
1629 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
1630 IntPtrTy, "load.begin");
1631 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
1632 Fusion);
1633
1634 // Check if the store begins before the end of the load location.
If the 1635 // condition holds, they alias, otherwise they are guaranteed to not 1636 // overlap. 1637 Check1->getTerminator()->eraseFromParent(); 1638 Builder.SetInsertPoint(Check1, Check1->begin()); 1639 Value *LoadEnd = Builder.CreateAdd( 1640 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1641 "load.end", true, true); 1642 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1643 Fusion); 1644 1645 // Copy load operand to new alloca. 1646 Builder.SetInsertPoint(Copy, Copy->begin()); 1647 auto *VT = cast<FixedVectorType>(Load->getType()); 1648 // Use an array type for the alloca, to avoid potentially huge alignment 1649 // requirements for large vector types. 1650 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1651 AllocaInst *Alloca = 1652 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1653 1654 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), 1655 Load->getAlign(), LoadLoc.Size.getValue()); 1656 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1657 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1658 PHI->addIncoming(Load->getPointerOperand(), Check0); 1659 PHI->addIncoming(Load->getPointerOperand(), Check1); 1660 PHI->addIncoming(Alloca, Copy); 1661 1662 // Adjust DT. 1663 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1664 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1665 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1666 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1667 DT->applyUpdates(DTUpdates); 1668 return PHI; 1669 } 1670 1671 bool isFusionProfitable(CallInst *MatMul) { 1672 if (ForceFusion) 1673 return true; 1674 1675 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1676 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1677 1678 const unsigned R = LShape.NumRows; 1679 const unsigned C = RShape.NumColumns; 1680 const unsigned M = LShape.NumColumns; 1681 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1682 1683 const unsigned VF = std::max<unsigned>( 1684 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1685 .getFixedValue() / 1686 EltType->getPrimitiveSizeInBits().getFixedValue(), 1687 1U); 1688 1689 // Cost model for tiling 1690 // 1691 // For tiling to be beneficial, we need reuse either along the R or 1692 // the C axis. We vectorize along the R axis so that means at least 1693 // 3 elements. 1694 // TODO: Also consider cost of copying if operands alias. 1695 if (R <= VF && C == 1) 1696 return false; 1697 // Then we need enough elements to exceed the number of vector 1698 // registers we have. Note that this is an oversimplification since 1699 // fusing also takes some extra loads which may exceed the number of 1700 // reloads necessary. 
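// As a rough worked example (illustrative shapes, not a measured heuristic):
// with R = C = M = 16 and VF = 4, the estimate below gives
//   Op0Regs = ceil(16 / 4) * 16 = 64 and Op1Regs = ceil(16 / 4) * 16 = 64,
// so the 128 estimated vector registers exceed the register file of typical
// targets and fusion/tiling is considered profitable.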
1701 unsigned Op0Regs = (R + VF - 1) / VF * M; 1702 unsigned Op1Regs = (M + VF - 1) / VF * C; 1703 return Op0Regs + Op1Regs > 1704 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1705 } 1706 1707 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1708 MatrixTy Res; 1709 auto *ColumType = FixedVectorType::get(EltType, R); 1710 for (unsigned I = 0; I < C; ++I) 1711 Res.addVector(ConstantAggregateZero::get(ColumType)); 1712 return Res; 1713 } 1714 1715 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1716 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1717 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1718 1719 // Create the main tiling loop nest. 1720 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1721 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1722 Instruction *InsertI = cast<Instruction>(MatMul); 1723 BasicBlock *Start = InsertI->getParent(); 1724 BasicBlock *End = 1725 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1726 IRBuilder<> Builder(MatMul); 1727 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1728 1729 Type *TileVecTy = 1730 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1731 MatrixTy TileResult; 1732 // Insert in the inner loop header. 1733 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); 1734 // Create PHI nodes for the result columns to accumulate across iterations. 1735 SmallVector<PHINode *, 4> ColumnPhis; 1736 for (unsigned I = 0; I < TileSize; I++) { 1737 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1738 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1739 TI.RowLoop.Header->getSingleSuccessor()); 1740 TileResult.addVector(Phi); 1741 ColumnPhis.push_back(Phi); 1742 } 1743 1744 // Insert in the inner loop body, which computes 1745 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1746 Builder.SetInsertPoint(InnerBody->getTerminator()); 1747 // Load tiles of the operands. 1748 MatrixTy A = 1749 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, 1750 {TileSize, TileSize}, EltType, Builder); 1751 MatrixTy B = 1752 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, 1753 {TileSize, TileSize}, EltType, Builder); 1754 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1755 getFastMathFlags(MatMul)); 1756 // Store result after the inner loop is done. 1757 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); 1758 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1759 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1760 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); 1761 1762 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1763 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); 1764 1765 // Force unrolling of a few iterations of the inner loop, to make sure there 1766 // is enough work per iteration. 1767 // FIXME: The unroller should make this decision directly instead, but 1768 // currently the cost-model is not up to the task. 
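// For example (illustrative sizes): with LShape.NumColumns = 64 and the
// default tile size of 4, the inner loop runs 16 times per tile and the
// metadata below requests an unroll count of min(10, 16) = 10.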
1769 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1770 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header), 1771 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1772 } 1773 1774 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1775 StoreInst *Store, 1776 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1777 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1778 "Tiling only supported for column-major matrixes at the moment!"); 1779 if (!isFusionProfitable(MatMul)) 1780 return; 1781 1782 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1783 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1784 1785 const unsigned R = LShape.NumRows; 1786 const unsigned C = RShape.NumColumns; 1787 const unsigned M = LShape.NumColumns; 1788 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1789 1790 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1791 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1792 Value *CPtr = Store->getPointerOperand(); 1793 1794 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 1795 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 1796 else { 1797 IRBuilder<> Builder(Store); 1798 for (unsigned J = 0; J < C; J += TileSize) 1799 for (unsigned I = 0; I < R; I += TileSize) { 1800 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1801 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1802 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1803 1804 for (unsigned K = 0; K < M; K += TileSize) { 1805 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1806 MatrixTy A = 1807 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1808 LShape, Builder.getInt64(I), Builder.getInt64(K), 1809 {TileR, TileM}, EltType, Builder); 1810 MatrixTy B = 1811 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 1812 RShape, Builder.getInt64(K), Builder.getInt64(J), 1813 {TileM, TileC}, EltType, Builder); 1814 emitMatrixMultiply(Res, A, B, Builder, true, false, 1815 getFastMathFlags(MatMul)); 1816 } 1817 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1818 Builder.getInt64(I), Builder.getInt64(J), EltType, 1819 Builder); 1820 } 1821 } 1822 1823 // Mark eliminated instructions as fused and remove them. 1824 FusedInsts.insert(Store); 1825 FusedInsts.insert(MatMul); 1826 Store->eraseFromParent(); 1827 MatMul->eraseFromParent(); 1828 if (LoadOp0->hasNUses(0)) { 1829 FusedInsts.insert(LoadOp0); 1830 LoadOp0->eraseFromParent(); 1831 } 1832 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { 1833 FusedInsts.insert(LoadOp1); 1834 LoadOp1->eraseFromParent(); 1835 } 1836 } 1837 1838 /// Try to lower matrix multiply chains by fusing operations. 1839 /// 1840 /// Call finalizeLowering on lowered instructions. Instructions that are 1841 /// completely eliminated by fusion are added to \p FusedInsts. 1842 void LowerMatrixMultiplyFused(CallInst *MatMul, 1843 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1844 if (!FuseMatrix || !DT) 1845 return; 1846 1847 assert(AA && LI && "Analyses should be available"); 1848 1849 Value *A = MatMul->getArgOperand(0); 1850 Value *B = MatMul->getArgOperand(1); 1851 1852 // We can fold the transpose into the operand that is used to fetch scalars. 1853 Value *T; 1854 if (MatrixLayout == MatrixLayoutTy::ColumnMajor 1855 ? 
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T))) 1856 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) { 1857 IRBuilder<> Builder(MatMul); 1858 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1859 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1860 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1861 const unsigned R = LShape.NumRows; 1862 const unsigned M = LShape.NumColumns; 1863 const unsigned C = RShape.NumColumns; 1864 1865 MatrixTy MA; 1866 MatrixTy MB; 1867 1868 Value *Transpose; 1869 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) { 1870 MA = getMatrix(A, ShapeInfo(R, M), Builder); 1871 MB = getMatrix(T, ShapeInfo(C, M), Builder); 1872 Transpose = B; 1873 } else { 1874 MA = getMatrix(T, ShapeInfo(R, M), Builder); 1875 MB = getMatrix(B, ShapeInfo(C, M), Builder); 1876 Transpose = A; 1877 } 1878 1879 // Initialize the output 1880 MatrixTy Result(R, C, EltType); 1881 1882 emitMatrixMultiply(Result, MA, MB, Builder, false, true, 1883 getFastMathFlags(MatMul)); 1884 1885 FusedInsts.insert(MatMul); 1886 if (Transpose->hasOneUse()) { 1887 FusedInsts.insert(cast<Instruction>(Transpose)); 1888 ToRemove.push_back(cast<Instruction>(Transpose)); 1889 // TODO: add a fake entry for the folded instruction so that this is 1890 // included in the expression in the remark. 1891 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType); 1892 } 1893 finalizeLowering(MatMul, Result, Builder); 1894 return; 1895 } 1896 1897 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor) 1898 return; 1899 1900 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering 1901 // since the single store user will be lowered as part of this. 1902 auto *LoadOp0 = dyn_cast<LoadInst>(A); 1903 auto *LoadOp1 = dyn_cast<LoadInst>(B); 1904 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1905 if (LoadOp0 && LoadOp1 && Store) { 1906 // The store address must dominate the MatMul instruction, otherwise 1907 // we create invalid IR. 1908 SetVector<Value *> WorkList; 1909 WorkList.insert(Store->getOperand(1)); 1910 SmallVector<Instruction *> ToHoist; 1911 for (unsigned I = 0; I != WorkList.size(); ++I) { 1912 Value *Current = WorkList[I]; 1913 auto *CurrI = dyn_cast<Instruction>(Current); 1914 if (!CurrI) 1915 continue; 1916 if (isa<PHINode>(CurrI)) 1917 return; 1918 if (DT->dominates(CurrI, MatMul)) 1919 continue; 1920 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory()) 1921 return; 1922 ToHoist.push_back(CurrI); 1923 WorkList.insert(CurrI->op_begin(), CurrI->op_end()); 1924 } 1925 1926 sort(ToHoist, [this](Instruction *A, Instruction *B) { 1927 return DT->dominates(A, B); 1928 }); 1929 for (Instruction *I : ToHoist) 1930 I->moveBefore(MatMul); 1931 1932 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1933 return; 1934 } 1935 } 1936 1937 /// Lowers llvm.matrix.multiply. 
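/// As an illustration, a 2x2 * 2x2 double multiply such as
///
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
///
/// is expanded into straight-line vector code over the 2-element columns of
/// %a and %b (shuffles/extracts, splats and fmul/fadd or fmuladd), leaving no
/// matrix intrinsic behind.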
1938 void LowerMultiply(CallInst *MatMul) { 1939 IRBuilder<> Builder(MatMul); 1940 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1941 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1942 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1943 1944 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1945 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1946 assert(Lhs.getElementType() == Rhs.getElementType() && 1947 "Matrix multiply argument element types do not match."); 1948 1949 const unsigned R = LShape.NumRows; 1950 const unsigned C = RShape.NumColumns; 1951 assert(LShape.NumColumns == RShape.NumRows); 1952 1953 // Initialize the output 1954 MatrixTy Result(R, C, EltType); 1955 assert(Lhs.getElementType() == Result.getElementType() && 1956 "Matrix multiply result element type does not match arguments."); 1957 1958 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 1959 getFastMathFlags(MatMul)); 1960 finalizeLowering(MatMul, Result, Builder); 1961 } 1962 1963 /// Lowers llvm.matrix.transpose. 1964 void LowerTranspose(CallInst *Inst) { 1965 MatrixTy Result; 1966 IRBuilder<> Builder(Inst); 1967 Value *InputVal = Inst->getArgOperand(0); 1968 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1969 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1970 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1971 1972 const unsigned NewNumVecs = 1973 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 1974 const unsigned NewNumElts = 1975 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1976 1977 for (unsigned I = 0; I < NewNumVecs; ++I) { 1978 // Build a single result vector. First initialize it. 1979 Value *ResultVector = PoisonValue::get( 1980 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1981 // Go through the old elements and insert it into the resulting vector. 1982 for (auto J : enumerate(InputMatrix.vectors())) { 1983 Value *Elt = Builder.CreateExtractElement(J.value(), I); 1984 // Row and column indices are transposed. 1985 ResultVector = 1986 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1987 } 1988 Result.addVector(ResultVector); 1989 } 1990 1991 // TODO: Improve estimate of operations needed for transposes. Currently we 1992 // just count the insertelement/extractelement instructions, but do not 1993 // account for later simplifications/combines. 1994 finalizeLowering( 1995 Inst, 1996 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 1997 .addNumExposedTransposes(1), 1998 Builder); 1999 } 2000 2001 /// Lower load instructions, if shape information is available. 2002 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 2003 auto I = ShapeMap.find(Inst); 2004 if (I == ShapeMap.end()) 2005 return false; 2006 2007 LowerLoad(Inst, Ptr, Inst->getAlign(), 2008 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 2009 I->second); 2010 return true; 2011 } 2012 2013 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 2014 IRBuilder<> &Builder) { 2015 auto I = ShapeMap.find(StoredVal); 2016 if (I == ShapeMap.end()) 2017 return false; 2018 2019 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 2020 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 2021 I->second); 2022 return true; 2023 } 2024 2025 /// Lower binary operators, if shape information is available. 
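/// For example (illustrative shapes), an fadd of two values known to have a
/// 2x2 column-major shape is lowered to two <2 x double> fadds, one per
/// column vector, and the pieces are recombined by finalizeLowering.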
2026 bool VisitBinaryOperator(BinaryOperator *Inst) { 2027 auto I = ShapeMap.find(Inst); 2028 if (I == ShapeMap.end()) 2029 return false; 2030 2031 Value *Lhs = Inst->getOperand(0); 2032 Value *Rhs = Inst->getOperand(1); 2033 2034 IRBuilder<> Builder(Inst); 2035 ShapeInfo &Shape = I->second; 2036 2037 MatrixTy Result; 2038 MatrixTy A = getMatrix(Lhs, Shape, Builder); 2039 MatrixTy B = getMatrix(Rhs, Shape, Builder); 2040 assert(A.isColumnMajor() == B.isColumnMajor() && 2041 Result.isColumnMajor() == A.isColumnMajor() && 2042 "operands must agree on matrix layout"); 2043 2044 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2045 2046 // Helper to perform binary op on vectors. 2047 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 2048 switch (Inst->getOpcode()) { 2049 case Instruction::Add: 2050 return Builder.CreateAdd(LHS, RHS); 2051 case Instruction::Mul: 2052 return Builder.CreateMul(LHS, RHS); 2053 case Instruction::Sub: 2054 return Builder.CreateSub(LHS, RHS); 2055 case Instruction::FAdd: 2056 return Builder.CreateFAdd(LHS, RHS); 2057 case Instruction::FMul: 2058 return Builder.CreateFMul(LHS, RHS); 2059 case Instruction::FSub: 2060 return Builder.CreateFSub(LHS, RHS); 2061 default: 2062 llvm_unreachable("Unsupported binary operator for matrix"); 2063 } 2064 }; 2065 2066 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2067 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 2068 2069 finalizeLowering(Inst, 2070 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2071 Result.getNumVectors()), 2072 Builder); 2073 return true; 2074 } 2075 2076 /// Lower unary operators, if shape information is available. 2077 bool VisitUnaryOperator(UnaryOperator *Inst) { 2078 auto I = ShapeMap.find(Inst); 2079 if (I == ShapeMap.end()) 2080 return false; 2081 2082 Value *Op = Inst->getOperand(0); 2083 2084 IRBuilder<> Builder(Inst); 2085 ShapeInfo &Shape = I->second; 2086 2087 MatrixTy Result; 2088 MatrixTy M = getMatrix(Op, Shape, Builder); 2089 2090 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2091 2092 // Helper to perform unary op on vectors. 2093 auto BuildVectorOp = [&Builder, Inst](Value *Op) { 2094 switch (Inst->getOpcode()) { 2095 case Instruction::FNeg: 2096 return Builder.CreateFNeg(Op); 2097 default: 2098 llvm_unreachable("Unsupported unary operator for matrix"); 2099 } 2100 }; 2101 2102 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 2103 Result.addVector(BuildVectorOp(M.getVector(I))); 2104 2105 finalizeLowering(Inst, 2106 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2107 Result.getNumVectors()), 2108 Builder); 2109 return true; 2110 } 2111 2112 /// Helper to linearize a matrix expression tree into a string. Currently 2113 /// matrix expressions are linarized by starting at an expression leaf and 2114 /// linearizing bottom up. 2115 struct ExprLinearizer { 2116 unsigned LengthToBreak = 100; 2117 std::string Str; 2118 raw_string_ostream Stream; 2119 unsigned LineLength = 0; 2120 const DataLayout &DL; 2121 2122 /// Mapping from instructions to matrixes. It is used to identify 2123 /// matrix instructions. 2124 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2125 2126 /// Mapping from values to the leaves of all expressions that the value is 2127 /// part of. 2128 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 2129 2130 /// Set of matrix expressions in the scope of a given DISubprogram. 2131 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 2132 2133 /// Leaf node of the expression to linearize. 
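/// Leaves are typically stores or other matrix instructions without matrix
/// users in the same subprogram (see RemarkGenerator::getExpressionLeaves).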
2134 Value *Leaf; 2135 2136 /// Used to keep track of sub-expressions that get reused while linearizing 2137 /// the expression. Re-used sub-expressions are marked as (reused). 2138 SmallPtrSet<Value *, 8> ReusedExprs; 2139 2140 ExprLinearizer(const DataLayout &DL, 2141 const MapVector<Value *, MatrixTy> &Inst2Matrix, 2142 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2143 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2144 Value *Leaf) 2145 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 2146 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 2147 2148 void indent(unsigned N) { 2149 LineLength += N; 2150 for (unsigned i = 0; i < N; i++) 2151 Stream << " "; 2152 } 2153 2154 void lineBreak() { 2155 Stream << "\n"; 2156 LineLength = 0; 2157 } 2158 2159 void maybeIndent(unsigned Indent) { 2160 if (LineLength >= LengthToBreak) 2161 lineBreak(); 2162 2163 if (LineLength == 0) 2164 indent(Indent); 2165 } 2166 2167 void write(StringRef S) { 2168 LineLength += S.size(); 2169 Stream << S; 2170 } 2171 2172 Value *getUnderlyingObjectThroughLoads(Value *V) { 2173 if (Value *Ptr = getPointerOperand(V)) 2174 return getUnderlyingObjectThroughLoads(Ptr); 2175 else if (V->getType()->isPointerTy()) 2176 return getUnderlyingObject(V); 2177 return V; 2178 } 2179 2180 /// Returns true if \p V is a matrix value in the given subprogram. 2181 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 2182 2183 /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to 2184 /// \p SS. 2185 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 2186 auto M = Inst2Matrix.find(V); 2187 if (M == Inst2Matrix.end()) 2188 SS << "unknown"; 2189 else { 2190 SS << M->second.getNumRows(); 2191 SS << "x"; 2192 SS << M->second.getNumColumns(); 2193 } 2194 } 2195 2196 /// Write the called function name. Handles calls to llvm.matrix.* 2197 /// specially: we write the name, followed by the dimensions of the input 2198 /// matrixes, followed by the scalar type name. 2199 void writeFnName(CallInst *CI) { 2200 if (!CI->getCalledFunction()) 2201 write("<no called fn>"); 2202 else { 2203 StringRef Name = CI->getCalledFunction()->getName(); 2204 if (!Name.startswith("llvm.matrix")) { 2205 write(Name); 2206 return; 2207 } 2208 auto *II = cast<IntrinsicInst>(CI); 2209 write(Intrinsic::getBaseName(II->getIntrinsicID()) 2210 .drop_front(StringRef("llvm.matrix.").size())); 2211 write("."); 2212 std::string Tmp; 2213 raw_string_ostream SS(Tmp); 2214 2215 switch (II->getIntrinsicID()) { 2216 case Intrinsic::matrix_multiply: 2217 prettyPrintMatrixType(II->getOperand(0), SS); 2218 SS << "."; 2219 prettyPrintMatrixType(II->getOperand(1), SS); 2220 SS << "." << *II->getType()->getScalarType(); 2221 break; 2222 case Intrinsic::matrix_transpose: 2223 prettyPrintMatrixType(II->getOperand(0), SS); 2224 SS << "." << *II->getType()->getScalarType(); 2225 break; 2226 case Intrinsic::matrix_column_major_load: 2227 prettyPrintMatrixType(II, SS); 2228 SS << "." << *II->getType()->getScalarType(); 2229 break; 2230 case Intrinsic::matrix_column_major_store: 2231 prettyPrintMatrixType(II->getOperand(0), SS); 2232 SS << "." 
<< *II->getOperand(0)->getType()->getScalarType(); 2233 break; 2234 default: 2235 llvm_unreachable("Unhandled case"); 2236 } 2237 SS.flush(); 2238 write(Tmp); 2239 } 2240 } 2241 2242 unsigned getNumShapeArgs(CallInst *CI) const { 2243 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 2244 switch (II->getIntrinsicID()) { 2245 case Intrinsic::matrix_multiply: 2246 return 3; 2247 case Intrinsic::matrix_transpose: 2248 return 2; 2249 case Intrinsic::matrix_column_major_load: 2250 case Intrinsic::matrix_column_major_store: 2251 return 3; 2252 default: 2253 return 0; 2254 } 2255 } 2256 return 0; 2257 } 2258 2259 /// Special printing for values: for pointers, we print if they refer to an 2260 /// (function) external address or a stack address, for other values we 2261 /// either print the constant or "scalar"/"matrix" for other values. 2262 void write(Value *V) { 2263 V = getUnderlyingObjectThroughLoads(V); 2264 if (V->getType()->isPointerTy()) { 2265 if (isa<AllocaInst>(V)) { 2266 Stream << "stack addr"; 2267 LineLength += StringRef("stack addr").size(); 2268 } else { 2269 Stream << "addr"; 2270 LineLength += StringRef("addr").size(); 2271 } 2272 if (!V->getName().empty()) { 2273 Stream << " %" << V->getName() << ""; 2274 LineLength += V->getName().size() + 2; 2275 } 2276 return; 2277 } 2278 2279 std::string Tmp; 2280 raw_string_ostream TmpStream(Tmp); 2281 2282 if (auto *CI = dyn_cast<ConstantInt>(V)) 2283 TmpStream << CI->getValue(); 2284 else if (isa<Constant>(V)) 2285 TmpStream << "constant"; 2286 else { 2287 if (isMatrix(V)) 2288 TmpStream << "matrix"; 2289 else 2290 TmpStream << "scalar"; 2291 } 2292 TmpStream.flush(); 2293 Tmp = std::string(StringRef(Tmp).trim()); 2294 LineLength += Tmp.size(); 2295 Stream << Tmp; 2296 } 2297 2298 /// Linearize expression \p Expr starting at an indentation of \p Indent. 2299 /// Expressions that are re-used multiple times are prefixed with (reused) 2300 /// at the re-used root instruction. 2301 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 2302 bool ParentShared) { 2303 auto *I = cast<Instruction>(Expr); 2304 maybeIndent(Indent); 2305 SmallVector<Value *, 8> Ops; 2306 2307 // Is Expr shared with other expression leaves? 2308 bool ExprShared = false; 2309 2310 // Deal with shared subtrees. Mark them as shared, if required. 2311 if (!ParentShared) { 2312 auto SI = Shared.find(Expr); 2313 assert(SI != Shared.end() && SI->second.count(Leaf)); 2314 2315 for (Value *S : SI->second) { 2316 if (S == Leaf) 2317 continue; 2318 DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 2319 write("shared with remark at line " + std::to_string(DL.getLine()) + 2320 " column " + std::to_string(DL.getCol()) + " ("); 2321 } 2322 ExprShared = SI->second.size() > 1; 2323 } 2324 2325 bool Reused = !ReusedExprs.insert(Expr).second; 2326 if (Reused && !ParentReused) 2327 write("(reused) "); 2328 2329 if (auto *CI = dyn_cast<CallInst>(I)) { 2330 writeFnName(CI); 2331 2332 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 2333 } else if (isa<BitCastInst>(Expr)) { 2334 // Special case bitcasts, which are used to materialize matrixes from 2335 // non-matrix ops. 
2336 write("matrix"); 2337 return; 2338 } else { 2339 Ops.append(I->value_op_begin(), I->value_op_end()); 2340 write(std::string(I->getOpcodeName())); 2341 } 2342 2343 write(std::string("(")); 2344 2345 unsigned NumOpsToBreak = 1; 2346 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 2347 NumOpsToBreak = 2; 2348 2349 for (Value *Op : Ops) { 2350 if (Ops.size() > NumOpsToBreak) 2351 lineBreak(); 2352 2353 maybeIndent(Indent + 1); 2354 if (isMatrix(Op)) 2355 linearizeExpr(Op, Indent + 1, Reused, ExprShared); 2356 else 2357 write(Op); 2358 if (Op != Ops.back()) 2359 write(", "); 2360 } 2361 2362 write(")"); 2363 } 2364 2365 const std::string &getResult() { 2366 Stream.flush(); 2367 return Str; 2368 } 2369 }; 2370 2371 /// Generate remarks for matrix operations in a function. To generate remarks 2372 /// for matrix expressions, the following approach is used: 2373 /// 1. Use the inlined-at debug information to group matrix operations to the 2374 /// DISubprograms they are contained in. 2375 /// 2. Collect leaves of matrix expressions (done in 2376 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 2377 // mapping. Leaves are lowered matrix instructions without other matrix 2378 // users (like stores) in the current subprogram. 2379 /// 3. For each leaf, create a remark containing a linearizied version of the 2380 /// matrix expression. The expression is linearized by a recursive 2381 /// bottom-up traversal of the matrix operands, starting at a leaf. Note 2382 /// that multiple leaves can share sub-expressions. Shared subexpressions 2383 /// are explicitly marked as shared(). 2384 struct RemarkGenerator { 2385 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2386 OptimizationRemarkEmitter &ORE; 2387 Function &Func; 2388 const DataLayout &DL; 2389 2390 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 2391 OptimizationRemarkEmitter &ORE, Function &Func) 2392 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 2393 DL(Func.getParent()->getDataLayout()) {} 2394 2395 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 2396 /// instructions in Inst2Matrix returning void or without any users in 2397 /// \p ExprsInSubprogram. Currently that should only include stores. 2398 SmallVector<Value *, 4> 2399 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { 2400 SmallVector<Value *, 4> Leaves; 2401 for (auto *Expr : ExprsInSubprogram) 2402 if (Expr->getType()->isVoidTy() || 2403 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { 2404 return ExprsInSubprogram.count(U); 2405 })) 2406 Leaves.push_back(Expr); 2407 return Leaves; 2408 } 2409 2410 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 2411 /// to all visited expressions in \p Shared. Limit the matrix operations to 2412 /// the ones in \p ExprsInSubprogram. 2413 void collectSharedInfo(Value *Leaf, Value *V, 2414 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2415 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { 2416 2417 if (!ExprsInSubprogram.count(V)) 2418 return; 2419 2420 auto I = Shared.insert({V, {}}); 2421 I.first->second.insert(Leaf); 2422 2423 for (Value *Op : cast<Instruction>(V)->operand_values()) 2424 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); 2425 } 2426 2427 /// Calculate the number of exclusive and shared op counts for expression 2428 /// starting at \p V. Expressions used multiple times are counted once. 2429 /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 
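/// For instance, if a loaded matrix feeds two multiplies that end in
/// different store leaves, its ops are attributed to the shared count (the
/// second member of the returned pair) rather than to either leaf's
/// exclusive count, so the per-leaf remarks do not double-count them.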
2430 std::pair<OpInfoTy, OpInfoTy>
2431 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2432 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2433 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2434 if (!ExprsInSubprogram.count(Root))
2435 return {};
2436
2437 // Already counted this expression. Stop.
2438 if (!ReusedExprs.insert(Root).second)
2439 return {};
2440
2441 OpInfoTy SharedCount;
2442 OpInfoTy Count;
2443
2444 auto I = Shared.find(Root);
2445 auto CM = Inst2Matrix.find(Root);
2446 if (I->second.size() == 1)
2447 Count = CM->second.getOpInfo();
2448 else
2449 SharedCount = CM->second.getOpInfo();
2450
2451 for (Value *Op : cast<Instruction>(Root)->operand_values()) {
2452 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2453 Count += C.first;
2454 SharedCount += C.second;
2455 }
2456 return {Count, SharedCount};
2457 }
2458
2459 void emitRemarks() {
2460 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2461 return;
2462
2463 // Map matrix operations to their containing subprograms, by traversing
2464 // the inlinedAt chain. If the function does not have a DISubprogram, we
2465 // only map them to the containing function.
2466 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2467 for (const auto &KV : Inst2Matrix) {
2468 if (Func.getSubprogram()) {
2469 auto *I = cast<Instruction>(KV.first);
2470 DILocation *Context = I->getDebugLoc();
2471 while (Context) {
2472 auto I =
2473 Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
2474 I.first->second.push_back(KV.first);
2475 Context = DebugLoc(Context).getInlinedAt();
2476 }
2477 } else {
2478 auto I = Subprog2Exprs.insert({nullptr, {}});
2479 I.first->second.push_back(KV.first);
2480 }
2481 }
2482 for (auto &KV : Subprog2Exprs) {
2483 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2484 KV.second.end());
2485 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2486
2487 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2488 for (Value *Leaf : Leaves)
2489 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2490
2491 // Generate remarks for each leaf.
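// The remark emitted below reads roughly like the following (illustrative
// numbers, names and source location):
//
//   remark: test.cpp:35:3: Lowered with 4 stores, 8 loads, 24 compute ops,
//   0 exposed transposes
//   store(
//    multiply.2x2.2x2.double(load(addr %A), load(addr %B)),
//    addr %C)
//
// where the linearized expression on the trailing lines is produced by
// ExprLinearizer.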
2492 for (auto *L : Leaves) { 2493 2494 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 2495 DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 2496 while (Context) { 2497 if (getSubprogram(Context->getScope()) == KV.first) { 2498 Loc = Context; 2499 break; 2500 } 2501 Context = DebugLoc(Context).getInlinedAt(); 2502 } 2503 2504 SmallPtrSet<Value *, 8> ReusedExprs; 2505 OpInfoTy Counts, SharedCounts; 2506 std::tie(Counts, SharedCounts) = 2507 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 2508 2509 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 2510 cast<Instruction>(L)->getParent()); 2511 2512 Rem << "Lowered with "; 2513 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 2514 << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 2515 << ore::NV("NumComputeOps", Counts.NumComputeOps) 2516 << " compute ops, " 2517 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2518 << " exposed transposes"; 2519 2520 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 2521 SharedCounts.NumComputeOps > 0) { 2522 Rem << ",\nadditionally " 2523 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 2524 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 2525 << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 2526 << " compute ops" 2527 << " are shared with other expressions"; 2528 } 2529 2530 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 2531 ORE.emit(Rem); 2532 } 2533 } 2534 } 2535 2536 std::string 2537 linearize(Value *L, 2538 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2539 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2540 const DataLayout &DL) { 2541 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 2542 Lin.linearizeExpr(L, 0, false, false); 2543 return Lin.getResult(); 2544 } 2545 }; 2546 }; 2547 } // namespace 2548 2549 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2550 FunctionAnalysisManager &AM) { 2551 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2552 OptimizationRemarkEmitter *ORE = nullptr; 2553 AAResults *AA = nullptr; 2554 DominatorTree *DT = nullptr; 2555 LoopInfo *LI = nullptr; 2556 2557 if (!Minimal) { 2558 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 2559 AA = &AM.getResult<AAManager>(F); 2560 DT = &AM.getResult<DominatorTreeAnalysis>(F); 2561 LI = &AM.getResult<LoopAnalysis>(F); 2562 } 2563 2564 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); 2565 if (LMT.Visit()) { 2566 PreservedAnalyses PA; 2567 if (!Minimal) { 2568 PA.preserve<LoopAnalysis>(); 2569 PA.preserve<DominatorTreeAnalysis>(); 2570 } 2571 return PA; 2572 } 2573 return PreservedAnalyses::all(); 2574 } 2575 2576 void LowerMatrixIntrinsicsPass::printPipeline( 2577 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2578 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2579 OS, MapClassName2PassName); 2580 OS << '<'; 2581 if (Minimal) 2582 OS << "minimal"; 2583 OS << '>'; 2584 } 2585
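// Note: with the printPipeline implementation above, the pass renders as e.g.
// "lower-matrix-intrinsics<>" or "lower-matrix-intrinsics<minimal>" in
// -print-pipeline-passes output; the base name comes from the pass registry,
// so treat the exact spelling here as illustrative.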