//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//  * Improve cost-modeling, e.g. choose different numbers of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

STATISTIC(FlattenedMatrices, "Number of matrix flattenings");
STATISTIC(ReshapedMatrices, "Number of matrix reshapes");
STATISTIC(SplitMatrices, "Number of matrix splits");

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

/// Helper function to either return Scope, if it is a subprogram, or the
/// subprogram attached to a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//           0      1      2      3
// 0     v_0_0  v_0_1  v_0_2  v_0_3
// 1     v_1_0  v_1_1  v_1_2  v_1_3
// 2     v_2_0  v_2_1  v_2_2  v_2_3
// 3     v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
163 // 164 // v_0_0 v_0_1 {v_0_2 {v_0_3 165 // Base Col 1 Col 2 166 // | | | 167 // v_1_0 |v_1_1 |v_1_2 |v_1_3 168 // v_2_0 |v_2_1 |v_2_2 |v_2_3 169 // v_3_0 {v_3_1 {v_3_2 v_3_3 170 // 171 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 172 unsigned NumElements, Type *EltType, 173 IRBuilder<> &Builder) { 174 175 assert((!isa<ConstantInt>(Stride) || 176 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 177 "Stride must be >= the number of elements in the result vector."); 178 179 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 180 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 181 182 // Get pointer to the start of the selected vector. Skip GEP creation, 183 // if we select vector 0. 184 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 185 VecStart = BasePtr; 186 else 187 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 188 189 return VecStart; 190 } 191 192 namespace { 193 struct ShapeInfo { 194 unsigned NumRows; 195 unsigned NumColumns; 196 197 bool IsColumnMajor; 198 199 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 200 : NumRows(NumRows), NumColumns(NumColumns), 201 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 202 203 ShapeInfo(Value *NumRows, Value *NumColumns) 204 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 205 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 206 207 bool operator==(const ShapeInfo &other) { 208 return NumRows == other.NumRows && NumColumns == other.NumColumns; 209 } 210 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 211 212 /// Returns true if shape-information is defined, meaning both dimensions 213 /// are != 0. 214 operator bool() const { 215 assert(NumRows == 0 || NumColumns != 0); 216 return NumRows != 0; 217 } 218 219 unsigned getStride() const { 220 if (IsColumnMajor) 221 return NumRows; 222 return NumColumns; 223 } 224 225 unsigned getNumVectors() const { 226 if (IsColumnMajor) 227 return NumColumns; 228 return NumRows; 229 } 230 231 /// Returns the transposed shape. 
232 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } 233 234 friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI); 235 236 LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; } 237 }; 238 239 raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) { 240 return OS << SI.NumRows << 'x' << SI.NumColumns; 241 } 242 243 } // namespace 244 245 static bool isUniformShape(Value *V) { 246 Instruction *I = dyn_cast<Instruction>(V); 247 if (!I) 248 return true; 249 250 if (I->isBinaryOp()) 251 return true; 252 253 if (auto *Cast = dyn_cast<CastInst>(V)) { 254 switch (Cast->getOpcode()) { 255 case llvm::Instruction::Trunc: 256 case llvm::Instruction::ZExt: 257 case llvm::Instruction::SExt: 258 case llvm::Instruction::FPToUI: 259 case llvm::Instruction::FPToSI: 260 case llvm::Instruction::UIToFP: 261 case llvm::Instruction::SIToFP: 262 case llvm::Instruction::FPTrunc: 263 case llvm::Instruction::FPExt: 264 return true; 265 case llvm::Instruction::AddrSpaceCast: 266 case CastInst::PtrToInt: 267 case CastInst::IntToPtr: 268 return false; 269 case CastInst::BitCast: { 270 if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy())) 271 if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy())) 272 return SrcVTy->getNumElements() == DestVTy->getNumElements(); 273 return false; 274 } 275 case llvm::Instruction::CastOpsEnd: 276 llvm_unreachable("not an actual cast op"); 277 } 278 llvm_unreachable("unhandled cast opcode"); 279 } 280 281 if (auto *II = dyn_cast<IntrinsicInst>(V)) 282 switch (II->getIntrinsicID()) { 283 case Intrinsic::abs: 284 case Intrinsic::fabs: 285 return true; 286 default: 287 return false; 288 } 289 290 switch (I->getOpcode()) { 291 case Instruction::PHI: 292 case Instruction::FNeg: 293 return true; 294 default: 295 return false; 296 } 297 } 298 299 /// Return the ShapeInfo for the result of \p I, it it can be determined. 300 static std::optional<ShapeInfo> 301 computeShapeInfoForInst(Instruction *I, 302 const DenseMap<Value *, ShapeInfo> &ShapeMap) { 303 Value *M; 304 Value *N; 305 Value *K; 306 if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>( 307 m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K)))) 308 return ShapeInfo(M, K); 309 if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M), 310 m_Value(N)))) { 311 // Flip dimensions. 312 return ShapeInfo(N, M); 313 } 314 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>( 315 m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M), 316 m_Value(N)))) 317 return ShapeInfo(N, M); 318 if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>( 319 m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N)))) 320 return ShapeInfo(M, N); 321 Value *MatrixA; 322 if (match(I, m_Store(m_Value(MatrixA), m_Value()))) { 323 auto OpShape = ShapeMap.find(MatrixA); 324 if (OpShape != ShapeMap.end()) 325 return OpShape->second; 326 } 327 328 if (isUniformShape(I) || isa<SelectInst>(I)) { 329 auto Ops = I->operands(); 330 auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops; 331 // Find the first operand that has a known shape and use that. 332 for (auto &Op : ShapedOps) { 333 auto OpShape = ShapeMap.find(Op.get()); 334 if (OpShape != ShapeMap.end()) 335 return OpShape->second; 336 } 337 } 338 return std::nullopt; 339 } 340 341 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 342 /// 343 /// Currently, the lowering for each matrix intrinsic is done as follows: 344 /// 1. 
Propagate the shape information from intrinsics to connected 345 /// instructions. 346 /// 2. Lower instructions with shape information (assuming column-major layout). 347 /// The lowering works similarly using row-major layout. 348 /// 2.1. Get column vectors for each argument. If we already lowered the 349 /// definition of an argument, use the produced column vectors directly. 350 /// If not, split the operand vector containing an embedded matrix into 351 /// a set of column vectors, 352 /// 2.2. Lower the instruction in terms of column major operations, which 353 /// yields a set of column vectors containing result matrix. Note that we 354 /// lower all instructions that have shape information. Besides the 355 /// intrinsics, this includes stores for example. 356 /// 2.3. Update uses of the lowered instruction. If we have shape information 357 /// for a user, there is nothing to do, as we will look up the result 358 /// column matrix when lowering the user. For other uses, we embed the 359 /// result matrix in a flat vector and update the use. 360 /// 2.4. Cache the result column matrix for the instruction we lowered 361 /// 3. After we lowered all instructions in a function, remove the now 362 /// obsolete instructions. 363 /// 364 class LowerMatrixIntrinsics { 365 Function &Func; 366 const DataLayout &DL; 367 const TargetTransformInfo &TTI; 368 FunctionAnalysisManager *AM; 369 AliasAnalysis *AA = nullptr; 370 DominatorTree *DT = nullptr; 371 LoopInfo *LI = nullptr; 372 OptimizationRemarkEmitter *ORE = nullptr; 373 374 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 375 struct OpInfoTy { 376 /// Number of stores emitted to generate this matrix. 377 unsigned NumStores = 0; 378 /// Number of loads emitted to generate this matrix. 379 unsigned NumLoads = 0; 380 /// Number of compute operations emitted to generate this matrix. 381 unsigned NumComputeOps = 0; 382 /// Most of the time transposes can be fused with matrix multiplies or can 383 /// be folded away via algebraic simplifications. This is the number of 384 /// transposes that we failed to make "free" via such optimizations. 385 unsigned NumExposedTransposes = 0; 386 387 OpInfoTy &operator+=(const OpInfoTy &RHS) { 388 NumStores += RHS.NumStores; 389 NumLoads += RHS.NumLoads; 390 NumComputeOps += RHS.NumComputeOps; 391 NumExposedTransposes += RHS.NumExposedTransposes; 392 return *this; 393 } 394 }; 395 396 /// Wrapper class representing a matrix as a set of vectors, either in row or 397 /// column major layout. All vectors must have the same vector type. 398 class MatrixTy { 399 SmallVector<Value *, 16> Vectors; 400 401 OpInfoTy OpInfo; 402 403 bool IsColumnMajor = true; 404 405 public: 406 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 407 MatrixTy(ArrayRef<Value *> Vectors) 408 : Vectors(Vectors), 409 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 410 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 411 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 412 413 unsigned D = isColumnMajor() ? NumColumns : NumRows; 414 for (unsigned J = 0; J < D; ++J) 415 addVector(PoisonValue::get(FixedVectorType::get( 416 EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 417 } 418 419 Value *getVector(unsigned i) const { return Vectors[i]; } 420 Value *getColumn(unsigned i) const { 421 assert(isColumnMajor() && "only supported for column-major matrixes"); 422 return Vectors[i]; 423 } 424 Value *getRow(unsigned i) const { 425 assert(!isColumnMajor() && "only supported for row-major matrixes"); 426 return Vectors[i]; 427 } 428 429 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 430 431 Type *getElementType() const { return getVectorTy()->getElementType(); } 432 433 unsigned getNumVectors() const { 434 if (isColumnMajor()) 435 return getNumColumns(); 436 return getNumRows(); 437 } 438 439 unsigned getNumColumns() const { 440 if (isColumnMajor()) 441 return Vectors.size(); 442 else { 443 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 444 return getVectorTy()->getNumElements(); 445 } 446 } 447 unsigned getNumRows() const { 448 if (isColumnMajor()) { 449 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 450 return getVectorTy()->getNumElements(); 451 } else 452 return Vectors.size(); 453 } 454 455 void addVector(Value *V) { Vectors.push_back(V); } 456 FixedVectorType *getColumnTy() { 457 assert(isColumnMajor() && "only supported for column-major matrixes"); 458 return getVectorTy(); 459 } 460 461 FixedVectorType *getVectorTy() const { 462 return cast<FixedVectorType>(Vectors[0]->getType()); 463 } 464 465 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 466 assert(isColumnMajor() && 467 "columns() only supported for column-major matrixes"); 468 return make_range(Vectors.begin(), Vectors.end()); 469 } 470 471 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 472 return make_range(Vectors.begin(), Vectors.end()); 473 } 474 475 /// Embed the vectors of the matrix into a flat vector by concatenating 476 /// them. 477 Value *embedInVector(IRBuilder<> &Builder) const { 478 return Vectors.size() == 1 ? Vectors[0] 479 : concatenateVectors(Builder, Vectors); 480 } 481 482 MatrixTy &addNumLoads(unsigned N) { 483 OpInfo.NumLoads += N; 484 return *this; 485 } 486 487 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 488 489 MatrixTy &addNumStores(unsigned N) { 490 OpInfo.NumStores += N; 491 return *this; 492 } 493 494 MatrixTy &addNumExposedTransposes(unsigned N) { 495 OpInfo.NumExposedTransposes += N; 496 return *this; 497 } 498 499 MatrixTy &addNumComputeOps(unsigned N) { 500 OpInfo.NumComputeOps += N; 501 return *this; 502 } 503 504 unsigned getNumStores() const { return OpInfo.NumStores; } 505 unsigned getNumLoads() const { return OpInfo.NumLoads; } 506 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 507 508 const OpInfoTy &getOpInfo() const { return OpInfo; } 509 510 bool isColumnMajor() const { return IsColumnMajor; } 511 512 unsigned getStride() const { 513 if (isColumnMajor()) 514 return getNumRows(); 515 return getNumColumns(); 516 } 517 518 ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; } 519 520 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 521 /// matrix is column-major, the result vector is extracted from a column 522 /// vector, otherwise from a row vector. 523 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 524 IRBuilder<> &Builder) const { 525 Value *Vec = isColumnMajor() ? 
getColumn(J) : getRow(I); 526 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 527 NumElts && 528 "Extracted vector will contain poison values"); 529 return Builder.CreateShuffleVector( 530 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 531 "block"); 532 } 533 }; 534 535 /// Maps instructions to their shape information. The shape information 536 /// describes the shape to be used while lowering. This matches the shape of 537 /// the result value of the instruction, with the only exceptions being store 538 /// instructions and the matrix_column_major_store intrinsics. For those, the 539 /// shape information indicates that those instructions should be lowered 540 /// using shape information as well. Note that extra care is needed when 541 /// erasing or RAUW'ing a value that is present in ShapeMap. If the 542 /// replacement is also a matrix operation, use 543 /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to 544 /// ShapeMap. We don't use ValueMap, as there are also cases where we do not 545 /// want to add shape information for a replacement instruction. When directly 546 /// erasing a value with an entry in ShapeMap, use 547 /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated 548 /// accordingly. 549 DenseMap<Value *, ShapeInfo> ShapeMap; 550 551 /// List of instructions to remove. While lowering, we are not replacing all 552 /// users of a lowered instruction, if shape information is available and 553 /// those need to be removed after we finished lowering. 554 SmallVector<Instruction *, 16> ToRemove; 555 556 /// Map from instructions to their produced column matrix. 557 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 558 559 private: 560 static FastMathFlags getFastMathFlags(Instruction *Inst) { 561 FastMathFlags FMF; 562 563 if (isa<FPMathOperator>(*Inst)) 564 FMF = Inst->getFastMathFlags(); 565 566 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 567 568 return FMF; 569 } 570 571 public: 572 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 573 FunctionAnalysisManager *AM) 574 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {} 575 576 unsigned getNumOps(Type *VT) { 577 assert(isa<FixedVectorType>(VT) && "Expected vector type"); 578 return getNumOps(VT->getScalarType(), 579 cast<FixedVectorType>(VT)->getNumElements()); 580 } 581 582 /// Is this the minimal version executed in the backend pipelines. 583 bool isMinimal() const { 584 return !DT; 585 } 586 587 /// Return the estimated number of vector ops required for an operation on 588 /// \p VT * N. 589 unsigned getNumOps(Type *ST, unsigned N) { 590 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / 591 double(TTI.getRegisterBitWidth( 592 TargetTransformInfo::RGK_FixedWidthVector) 593 .getFixedValue())); 594 } 595 596 /// Return the set of vectors that a matrix value is lowered to. 597 /// 598 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 599 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 600 /// into vectors. 601 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 602 IRBuilder<> &Builder) { 603 FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType()); 604 assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && 605 "The vector size must match the number of matrix elements"); 606 607 // Check if we lowered MatrixVal using shape information. 
In that case, 608 // return the existing matrix, if it matches the requested shape 609 // information. If there is a mis-match, embed the result in a flat 610 // vector and split it later. 611 auto Found = Inst2ColumnMatrix.find(MatrixVal); 612 if (Found != Inst2ColumnMatrix.end()) { 613 MatrixTy &M = Found->second; 614 // Return the found matrix, if its shape matches the requested shape 615 // information 616 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 617 return M; 618 619 MatrixVal = M.embedInVector(Builder); 620 } 621 622 // Otherwise split MatrixVal. 623 SmallVector<Value *, 16> SplitVecs; 624 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); 625 MaskStart += SI.getStride()) { 626 Value *V = Builder.CreateShuffleVector( 627 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 628 "split"); 629 SplitVecs.push_back(V); 630 } 631 632 if (Instruction *Inst = dyn_cast<Instruction>(MatrixVal)) { 633 if (Found != Inst2ColumnMatrix.end()) { 634 // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles 635 // that embedInVector created. 636 LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape() 637 << " to " << SI << " using at least " 638 << SplitVecs.size() << " shuffles on behalf of:\n" 639 << *Inst << '\n'); 640 ReshapedMatrices++; 641 } else if (!ShapeMap.contains(MatrixVal)) { 642 LLVM_DEBUG( 643 dbgs() 644 << "splitting a " << SI << " matrix with " << SplitVecs.size() 645 << " shuffles beacuse we do not have a shape-aware lowering for " 646 "its def:\n" 647 << *Inst << '\n'); 648 (void)Inst; 649 SplitMatrices++; 650 } else { 651 // The ShapeMap has it, so it's a case where we're being lowered 652 // before the def, and we expect that InstCombine will clean things up 653 // afterward. 654 } 655 } 656 657 return {SplitVecs}; 658 } 659 660 /// If \p V already has a known shape return false. Otherwise set the shape 661 /// for instructions that support it. 662 bool setShapeInfo(Value *V, ShapeInfo Shape) { 663 assert(Shape && "Shape not set"); 664 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 665 return false; 666 667 auto SIter = ShapeMap.find(V); 668 if (SIter != ShapeMap.end()) { 669 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || 670 SIter->second.NumColumns != Shape.NumColumns)) { 671 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" 672 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" 673 << Shape.NumColumns << ") for " << *V << "\n"; 674 report_fatal_error( 675 "Matrix shape verification failed, compilation aborted!"); 676 } 677 678 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 679 << SIter->second.NumRows << " " 680 << SIter->second.NumColumns << " for " << *V << "\n"); 681 return false; 682 } 683 684 ShapeMap.insert({V, Shape}); 685 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 686 << " for " << *V << "\n"); 687 return true; 688 } 689 690 /// Returns true if shape information can be used for \p V. The supported 691 /// instructions must match the instructions that can be lowered by this pass. 
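  /// For example (summarizing the checks below, not an exhaustive spec): the
  /// matrix intrinsics themselves, loads, stores, selects, and element-wise
  /// operations with uniform shape such as binary operators, fneg, phis,
  /// abs/fabs and value-preserving casts.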
692 bool supportsShapeInfo(Value *V) { 693 Instruction *Inst = dyn_cast<Instruction>(V); 694 if (!Inst) 695 return false; 696 697 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 698 if (II) 699 switch (II->getIntrinsicID()) { 700 case Intrinsic::matrix_multiply: 701 case Intrinsic::matrix_transpose: 702 case Intrinsic::matrix_column_major_load: 703 case Intrinsic::matrix_column_major_store: 704 return true; 705 default: 706 return isUniformShape(II); 707 } 708 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || 709 isa<SelectInst>(V); 710 } 711 712 /// Propagate the shape information of instructions to their users. 713 /// The work list contains instructions for which we can compute the shape, 714 /// either based on the information provided by matrix intrinsics or known 715 /// shapes of operands. 716 SmallVector<Instruction *, 32> 717 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 718 SmallVector<Instruction *, 32> NewWorkList; 719 // Pop an element for which we guaranteed to have at least one of the 720 // operand shapes. Add the shape for this and then add users to the work 721 // list. 722 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 723 while (!WorkList.empty()) { 724 Instruction *Inst = WorkList.pop_back_val(); 725 726 // New entry, set the value and insert operands 727 bool Propagate = false; 728 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap)) 729 Propagate = setShapeInfo(Inst, *SI); 730 731 if (Propagate) { 732 NewWorkList.push_back(Inst); 733 for (auto *User : Inst->users()) 734 if (ShapeMap.count(User) == 0) 735 WorkList.push_back(cast<Instruction>(User)); 736 } 737 } 738 739 return NewWorkList; 740 } 741 742 /// Propagate the shape to operands of instructions with shape information. 743 /// \p Worklist contains the instruction for which we already know the shape. 744 SmallVector<Instruction *, 32> 745 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 746 SmallVector<Instruction *, 32> NewWorkList; 747 748 auto pushInstruction = [](Value *V, 749 SmallVectorImpl<Instruction *> &WorkList) { 750 Instruction *I = dyn_cast<Instruction>(V); 751 if (I) 752 WorkList.push_back(I); 753 }; 754 // Pop an element with known shape. Traverse the operands, if their shape 755 // derives from the result shape and is unknown, add it and add them to the 756 // worklist. 757 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 758 while (!WorkList.empty()) { 759 Value *V = WorkList.pop_back_val(); 760 761 size_t BeforeProcessingV = WorkList.size(); 762 if (!isa<Instruction>(V)) 763 continue; 764 765 Value *MatrixA; 766 Value *MatrixB; 767 Value *M; 768 Value *N; 769 Value *K; 770 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 771 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 772 m_Value(N), m_Value(K)))) { 773 if (setShapeInfo(MatrixA, {M, N})) 774 pushInstruction(MatrixA, WorkList); 775 776 if (setShapeInfo(MatrixB, {N, K})) 777 pushInstruction(MatrixB, WorkList); 778 779 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 780 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 781 // Flip dimensions. 
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
        auto Ops = cast<Instruction>(V)->operands();
        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : ShapedOps) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
  /// itself.
  void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
    ShapeMap.erase(Inst);
    Inst->eraseFromParent();
  }

  /// Erase \p V from \p BB and move \p II forward to avoid invalidating
  /// iterators.
  void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                              BasicBlock &BB) {
    auto *Inst = cast<Instruction>(V);
    // Still used, don't erase.
    if (!Inst->use_empty())
      return;
    if (II != BB.rend() && Inst == &*II)
      ++II;
    eraseFromParentAndRemoveFromShapeMap(Inst);
  }

  /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
  /// entry for \p Old and replace all uses of \p Old with \p New.
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo so we insert
    // it conditionally instead.
862 auto S = ShapeMap.find(&Old); 863 if (S != ShapeMap.end()) { 864 ShapeMap.erase(S); 865 if (supportsShapeInfo(New)) 866 ShapeMap.insert({New, S->second}); 867 } 868 Old.replaceAllUsesWith(New); 869 } 870 871 /// Sink a top-level transpose inside matmuls and adds. 872 /// This creates and erases instructions as needed, and returns the newly 873 /// created instruction while updating the iterator to avoid invalidation. If 874 /// this returns nullptr, no new instruction was created. 875 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II, 876 bool &Changed) { 877 BasicBlock &BB = *I.getParent(); 878 IRBuilder<> IB(&I); 879 MatrixBuilder Builder(IB); 880 881 Value *TA, *TAMA, *TAMB; 882 ConstantInt *R, *K, *C; 883 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( 884 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) 885 return nullptr; 886 887 // Transpose of a transpose is a nop when the shapes match. 888 Value *TATA; 889 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>( 890 m_Value(TATA), m_Specific(C), m_Specific(R)))) { 891 updateShapeAndReplaceAllUsesWith(I, TATA); 892 eraseFromParentAndMove(&I, II, BB); 893 eraseFromParentAndMove(TA, II, BB); 894 Changed = true; 895 return nullptr; 896 } 897 898 // k^T -> k 899 if (isSplat(TA)) { 900 updateShapeAndReplaceAllUsesWith(I, TA); 901 eraseFromParentAndMove(&I, II, BB); 902 Changed = true; 903 return nullptr; 904 } 905 906 // (A * B)^t -> B^t * A^t 907 // RxK KxC CxK KxR 908 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 909 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 910 m_ConstantInt(K), m_ConstantInt(C)))) { 911 auto NewInst = distributeTransposes( 912 TAMB, {K, C}, TAMA, {R, K}, Builder, 913 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 914 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, 915 Shape0.NumColumns, 916 Shape1.NumColumns, "mmul"); 917 }); 918 updateShapeAndReplaceAllUsesWith(I, NewInst); 919 eraseFromParentAndMove(&I, II, BB); 920 eraseFromParentAndMove(TA, II, BB); 921 Changed = true; 922 return NewInst; 923 } 924 925 // Same as above, but with a mul, which occurs when multiplied 926 // with a scalar. 927 // (A * k)^t -> A^t * k 928 // R x C RxC 929 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && 930 (isSplat(TAMA) || isSplat(TAMB))) { 931 IRBuilder<> LocalBuilder(&I); 932 // We know that the transposed operand is of shape RxC. 933 // An when multiplied with a scalar, the shape is preserved. 934 auto NewInst = distributeTransposes( 935 TAMA, {R, C}, TAMB, {R, C}, Builder, 936 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 937 bool IsFP = I.getType()->isFPOrFPVectorTy(); 938 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul") 939 : LocalBuilder.CreateMul(T0, T1, "mmul"); 940 auto *Result = cast<Instruction>(Mul); 941 setShapeInfo(Result, Shape0); 942 return Result; 943 }); 944 updateShapeAndReplaceAllUsesWith(I, NewInst); 945 eraseFromParentAndMove(&I, II, BB); 946 eraseFromParentAndMove(TA, II, BB); 947 Changed = true; 948 return NewInst; 949 } 950 951 // (A + B)^t -> A^t + B^t 952 // RxC RxC CxR CxR 953 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { 954 IRBuilder<> LocalBuilder(&I); 955 auto NewInst = distributeTransposes( 956 TAMA, {R, C}, TAMB, {R, C}, Builder, 957 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 958 bool IsFP = I.getType()->isFPOrFPVectorTy(); 959 auto *Add = IsFP ? 
LocalBuilder.CreateFAdd(T0, T1, "madd") 960 : LocalBuilder.CreateAdd(T0, T1, "madd"); 961 962 auto *Result = cast<Instruction>(Add); 963 setShapeInfo(Result, Shape0); 964 return Result; 965 }); 966 updateShapeAndReplaceAllUsesWith(I, NewInst); 967 eraseFromParentAndMove(&I, II, BB); 968 eraseFromParentAndMove(TA, II, BB); 969 Changed = true; 970 return NewInst; 971 } 972 973 return nullptr; 974 } 975 976 bool liftTranspose(Instruction &I) { 977 // Erase dead Instructions after lifting transposes from binops. 978 auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) { 979 if (T.use_empty()) 980 eraseFromParentAndRemoveFromShapeMap(&T); 981 if (A->use_empty()) 982 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A)); 983 if (A != B && B->use_empty()) 984 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B)); 985 }; 986 987 Value *A, *B, *AT, *BT; 988 ConstantInt *R, *K, *C; 989 // A^t * B ^t -> (B * A)^t 990 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 991 m_Value(A), m_Value(B), m_ConstantInt(R), 992 m_ConstantInt(K), m_ConstantInt(C))) && 993 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 994 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 995 IRBuilder<> IB(&I); 996 MatrixBuilder Builder(IB); 997 Value *M = Builder.CreateMatrixMultiply( 998 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 999 setShapeInfo(M, {C, R}); 1000 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), 1001 R->getZExtValue()); 1002 updateShapeAndReplaceAllUsesWith(I, NewInst); 1003 CleanupBinOp(I, A, B); 1004 return true; 1005 } 1006 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If 1007 // the shape of the second transpose is different, there's a shape conflict 1008 // which gets resolved by picking the shape of the first operand. 1009 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && 1010 match(A, m_Intrinsic<Intrinsic::matrix_transpose>( 1011 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && 1012 match(B, m_Intrinsic<Intrinsic::matrix_transpose>( 1013 m_Value(BT), m_ConstantInt(), m_ConstantInt()))) { 1014 IRBuilder<> Builder(&I); 1015 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd"); 1016 MatrixBuilder MBuilder(Builder); 1017 Instruction *NewInst = MBuilder.CreateMatrixTranspose( 1018 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t"); 1019 updateShapeAndReplaceAllUsesWith(I, NewInst); 1020 assert(computeShapeInfoForInst(NewInst, ShapeMap) == 1021 computeShapeInfoForInst(&I, ShapeMap) && 1022 "Shape of new instruction doesn't match original shape."); 1023 CleanupBinOp(I, A, B); 1024 if (auto *AddI = dyn_cast<Instruction>(Add)) { 1025 setShapeInfo(AddI, {R, C}); 1026 assert( 1027 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) == 1028 ShapeMap[AddI] && 1029 "Shape of updated addition doesn't match cached shape."); 1030 } 1031 return true; 1032 } 1033 return false; 1034 } 1035 1036 /// Try moving transposes in order to fold them away or into multiplies. 1037 bool optimizeTransposes() { 1038 bool Changed = false; 1039 // First sink all transposes inside matmuls and adds, hoping that we end up 1040 // with NN, NT or TN variants. 1041 for (BasicBlock &BB : reverse(Func)) { 1042 for (auto II = BB.rbegin(); II != BB.rend();) { 1043 Instruction &I = *II; 1044 // We may remove II. By default continue on the next/prev instruction. 
1045 ++II; 1046 if (Instruction *NewInst = sinkTranspose(I, II, Changed)) 1047 II = std::next(BasicBlock::reverse_iterator(NewInst)); 1048 } 1049 } 1050 1051 // If we have a TT matmul or a TT add, lift the transpose. We may be able 1052 // to fold into consuming multiply or add. 1053 for (BasicBlock &BB : Func) { 1054 for (Instruction &I : llvm::make_early_inc_range(BB)) { 1055 Changed |= liftTranspose(I); 1056 } 1057 } 1058 return Changed; 1059 } 1060 1061 bool Visit() { 1062 SmallVector<Instruction *, 32> WorkList; 1063 1064 // Initially only the shape of matrix intrinsics is known. 1065 // Initialize the work list with ops carrying shape information. 1066 for (BasicBlock &BB : Func) 1067 for (Instruction &Inst : BB) { 1068 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 1069 if (!II) 1070 continue; 1071 1072 switch (II->getIntrinsicID()) { 1073 case Intrinsic::matrix_multiply: 1074 case Intrinsic::matrix_transpose: 1075 case Intrinsic::matrix_column_major_load: 1076 case Intrinsic::matrix_column_major_store: 1077 WorkList.push_back(&Inst); 1078 break; 1079 default: 1080 break; 1081 } 1082 } 1083 1084 // Avoid unnecessary work if there are no matrix intrinsics in the function. 1085 if (WorkList.empty()) 1086 return false; 1087 1088 if (AM) { 1089 ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func); 1090 AA = &AM->getResult<AAManager>(Func); 1091 DT = &AM->getResult<DominatorTreeAnalysis>(Func); 1092 LI = &AM->getResult<LoopAnalysis>(Func); 1093 } 1094 1095 // Propagate shapes until nothing changes any longer. 1096 while (!WorkList.empty()) { 1097 WorkList = propagateShapeForward(WorkList); 1098 WorkList = propagateShapeBackward(WorkList); 1099 } 1100 1101 bool Changed = false; 1102 if (!isMinimal()) { 1103 Changed |= optimizeTransposes(); 1104 if (PrintAfterTransposeOpt) { 1105 dbgs() << "Dump after matrix transpose optimization:\n"; 1106 Func.print(dbgs()); 1107 } 1108 } 1109 1110 SmallVector<CallInst *, 16> MaybeFusableInsts; 1111 SmallVector<Instruction *, 16> MatrixInsts; 1112 SmallVector<IntrinsicInst *, 16> LifetimeEnds; 1113 1114 // First, collect all instructions with shape information and candidates for 1115 // fusion (currently only matrix multiplies). 1116 ReversePostOrderTraversal<Function *> RPOT(&Func); 1117 for (auto *BB : RPOT) 1118 for (Instruction &I : *BB) { 1119 if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>())) 1120 LifetimeEnds.push_back(cast<IntrinsicInst>(&I)); 1121 if (!ShapeMap.contains(&I)) 1122 continue; 1123 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 1124 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 1125 MatrixInsts.push_back(&I); 1126 } 1127 1128 // Second, try to lower any dot products 1129 SmallPtrSet<Instruction *, 16> FusedInsts; 1130 for (CallInst *CI : MaybeFusableInsts) 1131 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); 1132 1133 // Third, try to fuse candidates. 1134 for (CallInst *CI : MaybeFusableInsts) 1135 if (!FusedInsts.contains(CI)) 1136 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); 1137 1138 Changed |= !FusedInsts.empty(); 1139 1140 // Fourth, pre-process all the PHINode's. The incoming values will be 1141 // assigned later in VisitPHI. 
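    // Illustrative example (assuming the default column-major layout): a phi
    // over a flat <8 x double> value with shape 4x2 is pre-split here into two
    // <4 x double> phis, one per column; VisitPHI fills in the per-column
    // incoming values later.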
1142 for (Instruction *Inst : MatrixInsts) { 1143 if (FusedInsts.count(Inst)) 1144 continue; 1145 1146 auto *PHI = dyn_cast<PHINode>(Inst); 1147 if (!PHI) 1148 continue; 1149 1150 const ShapeInfo &SI = ShapeMap.at(Inst); 1151 auto *EltTy = cast<FixedVectorType>(PHI->getType())->getElementType(); 1152 MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy); 1153 1154 IRBuilder<> Builder(Inst); 1155 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) 1156 PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(), 1157 PHI->getNumIncomingValues(), 1158 PHI->getName())); 1159 assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?"); 1160 Inst2ColumnMatrix[PHI] = PhiM; 1161 } 1162 1163 // Fifth, lower remaining instructions with shape information. 1164 for (Instruction *Inst : MatrixInsts) { 1165 if (FusedInsts.count(Inst)) 1166 continue; 1167 1168 const ShapeInfo &SI = ShapeMap.at(Inst); 1169 1170 Value *Op1; 1171 Value *Op2; 1172 MatrixTy Result; 1173 IRBuilder<> Builder(Inst); 1174 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1175 Result = VisitBinaryOperator(BinOp, SI, Builder); 1176 else if (auto *Cast = dyn_cast<CastInst>(Inst)) 1177 Result = VisitCastInstruction(Cast, SI, Builder); 1178 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1179 Result = VisitUnaryOperator(UnOp, SI, Builder); 1180 else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst)) 1181 Result = VisitIntrinsicInst(Intr, SI, Builder); 1182 else if (auto *Select = dyn_cast<SelectInst>(Inst)) 1183 Result = VisitSelectInst(Select, SI, Builder); 1184 else if (match(Inst, m_Load(m_Value(Op1)))) 1185 Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder); 1186 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 1187 Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder); 1188 else if (auto *PHI = dyn_cast<PHINode>(Inst)) 1189 Result = VisitPHI(PHI, SI, Builder); 1190 else 1191 continue; 1192 1193 finalizeLowering(Inst, Result, Builder); 1194 Changed = true; 1195 } 1196 1197 if (ORE) { 1198 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 1199 RemarkGen.emitRemarks(); 1200 } 1201 1202 // Delete the instructions backwards, as it has a reduced likelihood of 1203 // having to update as many def-use and use-def chains. 1204 // 1205 // Because we add to ToRemove during fusion we can't guarantee that defs 1206 // are before uses. Change uses to poison temporarily as these should get 1207 // removed as well. 1208 // 1209 // For verification, we keep track of where we changed uses to poison in 1210 // PoisonedInsts and then check that we in fact remove them. 1211 SmallSet<Instruction *, 16> PoisonedInsts; 1212 for (auto *Inst : reverse(ToRemove)) { 1213 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1214 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 1215 PoisonedInsts.insert(Poisoned); 1216 U.set(PoisonValue::get(Inst->getType())); 1217 } 1218 Inst->eraseFromParent(); 1219 PoisonedInsts.erase(Inst); 1220 } 1221 if (!PoisonedInsts.empty()) { 1222 // If we didn't remove all poisoned instructions, it's a hard error. 1223 dbgs() << "Poisoned but present instructions:\n"; 1224 for (auto *I : PoisonedInsts) 1225 dbgs() << *I << "\n"; 1226 llvm_unreachable("Poisoned but instruction not removed"); 1227 } 1228 1229 return Changed; 1230 } 1231 1232 /// Replace intrinsic calls. 
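  ///
  /// The matrix intrinsics (multiply, transpose, column-major load/store) get
  /// dedicated lowerings below; element-wise intrinsics such as abs/fabs are
  /// lowered by applying them to each row/column vector individually.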
1233 MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI, 1234 IRBuilder<> &Builder) { 1235 assert(Inst->getCalledFunction() && 1236 Inst->getCalledFunction()->isIntrinsic()); 1237 1238 switch (Inst->getCalledFunction()->getIntrinsicID()) { 1239 case Intrinsic::matrix_multiply: 1240 return LowerMultiply(Inst, Builder); 1241 case Intrinsic::matrix_transpose: 1242 return LowerTranspose(Inst, Builder); 1243 case Intrinsic::matrix_column_major_load: 1244 return LowerColumnMajorLoad(Inst, Builder); 1245 case Intrinsic::matrix_column_major_store: 1246 return LowerColumnMajorStore(Inst, Builder); 1247 case Intrinsic::abs: 1248 case Intrinsic::fabs: { 1249 MatrixTy Result; 1250 MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder); 1251 Builder.setFastMathFlags(getFastMathFlags(Inst)); 1252 1253 for (auto *Vector : M.vectors()) { 1254 switch (Inst->getIntrinsicID()) { 1255 case Intrinsic::abs: 1256 Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector, 1257 Inst->getOperand(1))); 1258 continue; 1259 case Intrinsic::fabs: 1260 Result.addVector( 1261 Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector)); 1262 continue; 1263 default: 1264 llvm_unreachable("unexpected intrinsic"); 1265 } 1266 } 1267 1268 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1269 Result.getNumVectors()); 1270 } 1271 default: 1272 break; 1273 } 1274 llvm_unreachable( 1275 "only intrinsics supporting shape info should be seen here"); 1276 } 1277 1278 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1279 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1280 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1281 /// non-ConstantInt strides, return the common alignment of the initial 1282 /// alignment and the element size in bytes. 1283 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 1284 MaybeAlign A) const { 1285 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 1286 if (Idx == 0) 1287 return InitialAlign; 1288 1289 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 1290 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 1291 uint64_t StrideInBytes = 1292 ConstStride->getZExtValue() * ElementSizeInBits / 8; 1293 return commonAlignment(InitialAlign, Idx * StrideInBytes); 1294 } 1295 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 1296 } 1297 1298 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1299 /// vectors. 1300 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 1301 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1302 auto *VType = cast<FixedVectorType>(Ty); 1303 Type *EltTy = VType->getElementType(); 1304 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 1305 Value *EltPtr = Ptr; 1306 MatrixTy Result; 1307 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1308 Value *GEP = computeVectorAddr( 1309 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1310 Stride, Shape.getStride(), EltTy, Builder); 1311 Value *Vector = Builder.CreateAlignedLoad( 1312 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 1313 IsVolatile, "col.load"); 1314 1315 Result.addVector(Vector); 1316 } 1317 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 1318 Result.getNumVectors()); 1319 } 1320 1321 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1322 /// starting at \p MatrixPtr[I][J]. 
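  ///
  /// For example (illustrative, assuming column-major layout): loading a 2x2
  /// tile at (I, J) = (1, 1) from a 4x4 matrix computes the tile start as
  /// MatrixPtr + (1 * 4 + 1) elements and then loads two 2-element column
  /// vectors that are 4 elements (the outer matrix stride) apart.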
1323 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 1324 ShapeInfo MatrixShape, Value *I, Value *J, 1325 ShapeInfo ResultShape, Type *EltTy, 1326 IRBuilder<> &Builder) { 1327 Value *Offset = Builder.CreateAdd( 1328 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1329 1330 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1331 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1332 ResultShape.NumColumns); 1333 1334 return loadMatrix(TileTy, TileStart, Align, 1335 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1336 ResultShape, Builder); 1337 } 1338 1339 /// Lower a load instruction with shape information. 1340 MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, 1341 Value *Stride, bool IsVolatile, ShapeInfo Shape, 1342 IRBuilder<> &Builder) { 1343 return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape, 1344 Builder); 1345 } 1346 1347 /// Lowers llvm.matrix.column.major.load. 1348 /// 1349 /// The intrinsic loads a matrix from memory using a stride between columns. 1350 MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) { 1351 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1352 "Intrinsic only supports column-major layout!"); 1353 Value *Ptr = Inst->getArgOperand(0); 1354 Value *Stride = Inst->getArgOperand(1); 1355 return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1356 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1357 {Inst->getArgOperand(3), Inst->getArgOperand(4)}, Builder); 1358 } 1359 1360 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1361 /// MatrixPtr[I][J]. 1362 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1363 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1364 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1365 Value *Offset = Builder.CreateAdd( 1366 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1367 1368 Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); 1369 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1370 StoreVal.getNumColumns()); 1371 1372 storeMatrix(TileTy, StoreVal, TileStart, MAlign, 1373 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1374 } 1375 1376 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1377 /// vectors. 1378 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1379 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1380 IRBuilder<> &Builder) { 1381 auto *VType = cast<FixedVectorType>(Ty); 1382 Value *EltPtr = Ptr; 1383 for (auto Vec : enumerate(StoreVal.vectors())) { 1384 Value *GEP = computeVectorAddr( 1385 EltPtr, 1386 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1387 Vec.index()), 1388 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1389 Builder.CreateAlignedStore(Vec.value(), GEP, 1390 getAlignForIndex(Vec.index(), Stride, 1391 VType->getElementType(), 1392 MAlign), 1393 IsVolatile); 1394 } 1395 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1396 StoreVal.getNumVectors()); 1397 } 1398 1399 /// Lower a store instruction with shape information. 
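  ///
  /// For example (illustrative): storing a 4x2 column-major matrix emits two
  /// <4 x ty> stores whose addresses are one column stride apart.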
1400 MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, 1401 MaybeAlign A, Value *Stride, bool IsVolatile, 1402 ShapeInfo Shape, IRBuilder<> &Builder) { 1403 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1404 return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile, 1405 Builder); 1406 } 1407 1408 /// Lowers llvm.matrix.column.major.store. 1409 /// 1410 /// The intrinsic store a matrix back memory using a stride between columns. 1411 MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) { 1412 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1413 "Intrinsic only supports column-major layout!"); 1414 Value *Matrix = Inst->getArgOperand(0); 1415 Value *Ptr = Inst->getArgOperand(1); 1416 Value *Stride = Inst->getArgOperand(2); 1417 return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1418 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1419 {Inst->getArgOperand(4), Inst->getArgOperand(5)}, 1420 Builder); 1421 } 1422 1423 // Set elements I..I+NumElts-1 to Block 1424 Value *insertVector(Value *Col, unsigned I, Value *Block, 1425 IRBuilder<> &Builder) { 1426 1427 // First, bring Block to the same size as Col 1428 unsigned BlockNumElts = 1429 cast<FixedVectorType>(Block->getType())->getNumElements(); 1430 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1431 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1432 1433 Block = Builder.CreateShuffleVector( 1434 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1435 1436 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1437 // 8, 4, 5, 6 1438 SmallVector<int, 16> Mask; 1439 unsigned i; 1440 for (i = 0; i < I; i++) 1441 Mask.push_back(i); 1442 1443 unsigned VecNumElts = 1444 cast<FixedVectorType>(Col->getType())->getNumElements(); 1445 for (; i < I + BlockNumElts; i++) 1446 Mask.push_back(i - I + VecNumElts); 1447 1448 for (; i < VecNumElts; i++) 1449 Mask.push_back(i); 1450 1451 return Builder.CreateShuffleVector(Col, Block, Mask); 1452 } 1453 1454 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1455 IRBuilder<> &Builder, bool AllowContraction, 1456 unsigned &NumComputeOps) { 1457 NumComputeOps += getNumOps(A->getType()); 1458 if (!Sum) 1459 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1460 1461 if (UseFPOp) { 1462 if (AllowContraction) { 1463 // Use fmuladd for floating point operations and let the backend decide 1464 // if that's profitable. 1465 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(), 1466 {A, B, Sum}); 1467 } 1468 NumComputeOps += getNumOps(A->getType()); 1469 Value *Mul = Builder.CreateFMul(A, B); 1470 return Builder.CreateFAdd(Sum, Mul); 1471 } 1472 1473 NumComputeOps += getNumOps(A->getType()); 1474 Value *Mul = Builder.CreateMul(A, B); 1475 return Builder.CreateAdd(Sum, Mul); 1476 } 1477 1478 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1479 /// users with shape information, there's nothing to do: they will use the 1480 /// cached value when they are lowered. For other users, \p Matrix is 1481 /// flattened and the uses are updated to use it. Also marks \p Inst for 1482 /// deletion. 
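  ///
  /// For example (illustrative): if the columns of a lowered 2x2 multiply are
  /// only consumed by a call without shape information, they are concatenated
  /// back into a flat <4 x ty> value (embedInVector) and that flattened value
  /// replaces the matrix operand of the call.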
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert((inserted.second || isa<PHINode>(Inst)) &&
           "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.contains(U.getUser()))
        continue;

      if (!Flattened) {
        Flattened = Matrix.embedInVector(Builder);
        LLVM_DEBUG(
            if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs()
            << "flattening a " << Matrix.shape() << " matrix:\n"
            << *Inst
            << "\nbecause we do not have a shape-aware lowering for its "
               "user:\n"
            << *User << '\n';);
        FlattenedMatrices++;
      }
      U.set(Flattened);
    }
  }

  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// vectors. Lowers to a vector reduction add instead of sequential adds if
  /// reassociation is enabled.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassociation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    auto CanBeFlattened = [](Value *Op) {
      if (match(Op, m_BinOp()))
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
    // Returns the cost benefit of using \p Op with the dot product lowering. If
    // the returned cost is < 0, the argument is cheaper to use in the
    // dot-product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!ShapeMap.contains(Op))
        return InstructionCost::getInvalid();

      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        InstructionCost EmbedCost(0);
        // Roughly estimate the cost for embedding the columns into a vector.
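        // (Modeled as N - 1 splice shuffles, one per extra column; this is a
        // rough proxy rather than an exact cost.)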
1562 for (unsigned I = 1; I < N; ++I) 1563 EmbedCost += TTI.getShuffleCost( 1564 TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1565 FixedVectorType::get(EltTy, 1), {}, TTI::TCK_RecipThroughput); 1566 return EmbedCost; 1567 } 1568 1569 if (match(Op, m_BinOp()) && ShapeMap.contains(Op)) { 1570 InstructionCost OriginalCost = 1571 TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(), 1572 EltTy) * 1573 N; 1574 InstructionCost NewCost = TTI.getArithmeticInstrCost( 1575 cast<Instruction>(Op)->getOpcode(), VecTy); 1576 return NewCost - OriginalCost; 1577 } 1578 1579 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) { 1580 // The transpose can be skipped for the dot product lowering, roughly 1581 // estimate the savings as the cost of embedding the columns in a 1582 // vector. 1583 InstructionCost EmbedCost(0); 1584 for (unsigned I = 1; I < N; ++I) 1585 EmbedCost -= TTI.getShuffleCost( 1586 TTI::SK_Splice, FixedVectorType::get(EltTy, 1), 1587 FixedVectorType::get(EltTy, 1), {}, TTI::TCK_RecipThroughput); 1588 return EmbedCost; 1589 } 1590 1591 // Costs for loads. 1592 if (N == 1) 1593 return InstructionCost(0); 1594 1595 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - 1596 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); 1597 }; 1598 1599 // Iterate over LHS and operations feeding LHS and check if it is profitable 1600 // to flatten the visited ops. For each op, we compute the difference 1601 // between the flattened and matrix versions. 1602 SmallPtrSet<Value *, 4> Seen; 1603 SmallVector<Value *> WorkList; 1604 SmallVector<Value *> ToFlatten; 1605 WorkList.push_back(LHS); 1606 InstructionCost LHSCost(0); 1607 while (!WorkList.empty()) { 1608 Value *Op = WorkList.pop_back_val(); 1609 if (!Seen.insert(Op).second) 1610 continue; 1611 1612 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns); 1613 if (OpCost + LHSCost >= LHSCost) 1614 continue; 1615 1616 LHSCost += OpCost; 1617 ToFlatten.push_back(Op); 1618 if (auto *I = dyn_cast<Instruction>(Op)) 1619 WorkList.append(I->op_begin(), I->op_end()); 1620 } 1621 1622 // We compare the costs of a vector.reduce.add to sequential add. 1623 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; 1624 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; 1625 InstructionCost ReductionCost = 1626 TTI.getArithmeticReductionCost( 1627 AddOpCode, cast<FixedVectorType>(LHS->getType()), 1628 IsIntVec ? std::nullopt : std::optional(FMF)) + 1629 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); 1630 InstructionCost SequentialAddCost = 1631 TTI.getArithmeticInstrCost(AddOpCode, ElementType) * 1632 (LShape.NumColumns - 1) + 1633 TTI.getArithmeticInstrCost(MulOpCode, ElementType) * 1634 (LShape.NumColumns); 1635 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) 1636 return; 1637 1638 FusedInsts.insert(MatMul); 1639 IRBuilder<> Builder(MatMul); 1640 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, 1641 this](Value *Op) { 1642 // Matmul must be the only user of loads because we don't use LowerLoad 1643 // for row vectors (LowerLoad results in scalar loads and shufflevectors 1644 // instead of single vector load). 
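      // Summary of the cases handled below: binary ops keep their vector form
      // and only get their shape entry transposed; a column-major load with
      // stride 1 is rewritten to a plain vector LoadInst; a transpose is
      // dropped entirely, since the dot product consumes the elements linearly
      // either way.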
1645 if (!CanBeFlattened(Op)) 1646 return; 1647 1648 if (match(Op, m_BinOp())) { 1649 auto It = ShapeMap.find(Op); 1650 if (It != ShapeMap.end()) { 1651 It->second = It->second.t(); 1652 return; 1653 } 1654 } 1655 1656 FusedInsts.insert(cast<Instruction>(Op)); 1657 // If vector uses the builtin load, lower to a LoadInst 1658 Value *Arg; 1659 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( 1660 m_Value(Arg)))) { 1661 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); 1662 Op->replaceAllUsesWith(NewLoad); 1663 eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op)); 1664 return; 1665 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( 1666 m_Value(Arg)))) { 1667 ToRemove.push_back(cast<Instruction>(Op)); 1668 Op->replaceAllUsesWith(Arg); 1669 return; 1670 } 1671 }; 1672 1673 for (auto *V : ToFlatten) 1674 FlattenArg(V); 1675 1676 LHS = MatMul->getArgOperand(0); 1677 1678 // Insert mul/fmul and llvm.vector.reduce.fadd 1679 Value *Mul = 1680 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS); 1681 1682 Value *Result; 1683 if (IsIntVec) 1684 Result = Builder.CreateAddReduce(Mul); 1685 else { 1686 Result = Builder.CreateFAddReduce( 1687 ConstantFP::get( 1688 cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0), 1689 Mul); 1690 cast<Instruction>(Result)->setFastMathFlags(FMF); 1691 } 1692 1693 // pack scalar back into a matrix and then replace matmul inst 1694 Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()), 1695 Result, uint64_t(0)); 1696 MatMul->replaceAllUsesWith(Result); 1697 FusedInsts.insert(MatMul); 1698 ToRemove.push_back(MatMul); 1699 } 1700 1701 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1702 /// addition. 1703 /// 1704 /// We can fold a transpose into the operand that is used to extract scalars. 1705 /// This is the first operands with row-major and the second with 1706 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1707 /// operand is transposed. 1708 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1709 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1710 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1711 const unsigned VF = std::max<unsigned>( 1712 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1713 .getFixedValue() / 1714 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1715 1U); 1716 unsigned R = Result.getNumRows(); 1717 unsigned C = Result.getNumColumns(); 1718 unsigned M = A.getNumColumns(); 1719 1720 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1721 assert(A.isColumnMajor() == B.isColumnMajor() && 1722 Result.isColumnMajor() == A.isColumnMajor() && 1723 "operands must agree on matrix layout"); 1724 unsigned NumComputeOps = 0; 1725 1726 Builder.setFastMathFlags(FMF); 1727 1728 if (A.isColumnMajor()) { 1729 // Multiply columns from the first operand with scalars from the second 1730 // operand. Then move along the K axes and accumulate the columns. With 1731 // this the adds can be vectorized without reassociation. 1732 for (unsigned J = 0; J < C; ++J) { 1733 unsigned BlockSize = VF; 1734 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1735 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1736 1737 for (unsigned I = 0; I < R; I += BlockSize) { 1738 // Gradually lower the vectorization factor to cover the remainder. 1739 while (I + BlockSize > R) 1740 BlockSize /= 2; 1741 1742 Value *Sum = IsTiled ? 
              Result.extractVector(I, J, BlockSize, Builder)
                  : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }

  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. Returns the new location if a copy was
  /// made, otherwise the original location.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the 2nd part of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location.
If the condition holds, they might overlap, otherwise they are 1825 // guaranteed to not overlap. 1826 IRBuilder<> Builder(MatMul); 1827 Check0->getTerminator()->eraseFromParent(); 1828 Builder.SetInsertPoint(Check0); 1829 Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout()); 1830 Value *StoreBegin = Builder.CreatePtrToInt( 1831 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1832 Value *StoreEnd = Builder.CreateAdd( 1833 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1834 "store.end", true, true); 1835 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1836 IntPtrTy, "load.begin"); 1837 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1838 Fusion); 1839 1840 // Check if the store begins before the end of the load location. If the 1841 // condition holds, they alias, otherwise they are guaranteed to not 1842 // overlap. 1843 Check1->getTerminator()->eraseFromParent(); 1844 Builder.SetInsertPoint(Check1, Check1->begin()); 1845 Value *LoadEnd = Builder.CreateAdd( 1846 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1847 "load.end", true, true); 1848 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1849 Fusion); 1850 1851 // Copy load operand to new alloca. 1852 Builder.SetInsertPoint(Copy, Copy->begin()); 1853 auto *VT = cast<FixedVectorType>(Load->getType()); 1854 // Use an array type for the alloca, to avoid potentially huge alignment 1855 // requirements for large vector types. 1856 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1857 AllocaInst *Alloca = 1858 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1859 1860 Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), 1861 Load->getAlign(), LoadLoc.Size.getValue()); 1862 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1863 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1864 PHI->addIncoming(Load->getPointerOperand(), Check0); 1865 PHI->addIncoming(Load->getPointerOperand(), Check1); 1866 PHI->addIncoming(Alloca, Copy); 1867 1868 // Adjust DT. 1869 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1870 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1871 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1872 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1873 DT->applyUpdates(DTUpdates); 1874 return PHI; 1875 } 1876 1877 bool isFusionProfitable(CallInst *MatMul) { 1878 if (ForceFusion) 1879 return true; 1880 1881 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1882 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1883 1884 const unsigned R = LShape.NumRows; 1885 const unsigned C = RShape.NumColumns; 1886 const unsigned M = LShape.NumColumns; 1887 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType(); 1888 1889 const unsigned VF = std::max<unsigned>( 1890 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1891 .getFixedValue() / 1892 EltType->getPrimitiveSizeInBits().getFixedValue(), 1893 1U); 1894 1895 // Cost model for tiling 1896 // 1897 // For tiling to be beneficial, we need reuse either along the R or 1898 // the C axis. We vectorize along the R axis so that means at least 1899 // 3 elements. 1900 // TODO: Also consider cost of copying if operands alias. 1901 if (R <= VF && C == 1) 1902 return false; 1903 // Then we need enough elements to exceed the number of vector 1904 // registers we have. 
Note that this is an oversimplification since 1905 // fusing also takes some extra loads which may exceed the number of 1906 // reloads necessary. 1907 unsigned Op0Regs = (R + VF - 1) / VF * M; 1908 unsigned Op1Regs = (M + VF - 1) / VF * C; 1909 return Op0Regs + Op1Regs > 1910 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1911 } 1912 1913 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1914 MatrixTy Res; 1915 auto *ColumType = FixedVectorType::get(EltType, R); 1916 for (unsigned I = 0; I < C; ++I) 1917 Res.addVector(ConstantAggregateZero::get(ColumType)); 1918 return Res; 1919 } 1920 1921 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1922 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1923 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType(); 1924 1925 // Create the main tiling loop nest. 1926 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1927 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1928 Instruction *InsertI = cast<Instruction>(MatMul); 1929 BasicBlock *Start = InsertI->getParent(); 1930 BasicBlock *End = 1931 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1932 IRBuilder<> Builder(MatMul); 1933 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1934 1935 Type *TileVecTy = 1936 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1937 MatrixTy TileResult; 1938 // Insert in the inner loop header. 1939 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); 1940 // Create PHI nodes for the result columns to accumulate across iterations. 1941 SmallVector<PHINode *, 4> ColumnPhis; 1942 for (unsigned I = 0; I < TileSize; I++) { 1943 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1944 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1945 TI.RowLoop.Header->getSingleSuccessor()); 1946 TileResult.addVector(Phi); 1947 ColumnPhis.push_back(Phi); 1948 } 1949 1950 // Insert in the inner loop body, which computes 1951 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1952 Builder.SetInsertPoint(InnerBody->getTerminator()); 1953 // Load tiles of the operands. 1954 MatrixTy A = 1955 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, 1956 {TileSize, TileSize}, EltType, Builder); 1957 MatrixTy B = 1958 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, 1959 {TileSize, TileSize}, EltType, Builder); 1960 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1961 getFastMathFlags(MatMul)); 1962 // Store result after the inner loop is done. 1963 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); 1964 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1965 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1966 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); 1967 1968 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1969 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); 1970 1971 // Force unrolling of a few iterations of the inner loop, to make sure there 1972 // is enough work per iteration. 1973 // FIXME: The unroller should make this decision directly instead, but 1974 // currently the cost-model is not up to the task. 
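    // For example, with the default TileSize of 4 and a 64-column LHS the
    // metadata below requests an unroll count of min(10, 64 / 4) = 10; for a
    // 16-column LHS it would be min(10, 16 / 4) = 4.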
1975 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); 1976 addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header), 1977 "llvm.loop.unroll.count", InnerLoopUnrollCount); 1978 } 1979 1980 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1981 StoreInst *Store, 1982 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1983 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1984 "Tiling only supported for column-major matrixes at the moment!"); 1985 if (!isFusionProfitable(MatMul)) 1986 return; 1987 1988 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1989 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1990 1991 const unsigned R = LShape.NumRows; 1992 const unsigned C = RShape.NumColumns; 1993 const unsigned M = LShape.NumColumns; 1994 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType(); 1995 1996 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1997 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1998 Value *CPtr = Store->getPointerOperand(); 1999 2000 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) 2001 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store); 2002 else { 2003 IRBuilder<> Builder(Store); 2004 for (unsigned J = 0; J < C; J += TileSize) 2005 for (unsigned I = 0; I < R; I += TileSize) { 2006 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 2007 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 2008 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 2009 2010 for (unsigned K = 0; K < M; K += TileSize) { 2011 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 2012 MatrixTy A = 2013 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 2014 LShape, Builder.getInt64(I), Builder.getInt64(K), 2015 {TileR, TileM}, EltType, Builder); 2016 MatrixTy B = 2017 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), 2018 RShape, Builder.getInt64(K), Builder.getInt64(J), 2019 {TileM, TileC}, EltType, Builder); 2020 emitMatrixMultiply(Res, A, B, Builder, true, false, 2021 getFastMathFlags(MatMul)); 2022 } 2023 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 2024 Builder.getInt64(I), Builder.getInt64(J), EltType, 2025 Builder); 2026 } 2027 } 2028 2029 // Mark eliminated instructions as fused and remove them. 2030 FusedInsts.insert(Store); 2031 FusedInsts.insert(MatMul); 2032 eraseFromParentAndRemoveFromShapeMap(Store); 2033 eraseFromParentAndRemoveFromShapeMap(MatMul); 2034 if (LoadOp0->use_empty()) { 2035 FusedInsts.insert(LoadOp0); 2036 eraseFromParentAndRemoveFromShapeMap(LoadOp0); 2037 } 2038 if (LoadOp1 != LoadOp0 && LoadOp1->use_empty()) { 2039 FusedInsts.insert(LoadOp1); 2040 eraseFromParentAndRemoveFromShapeMap(LoadOp1); 2041 } 2042 } 2043 2044 /// Try to lower matrix multiply chains by fusing operations. 2045 /// 2046 /// Call finalizeLowering on lowered instructions. Instructions that are 2047 /// completely eliminated by fusion are added to \p FusedInsts. 2048 void 2049 LowerMatrixMultiplyFused(CallInst *MatMul, 2050 SmallPtrSetImpl<Instruction *> &FusedInsts, 2051 SmallVector<IntrinsicInst *, 16> &LifetimeEnds) { 2052 if (!FuseMatrix || !DT) 2053 return; 2054 2055 assert(AA && LI && "Analyses should be available"); 2056 2057 Value *A = MatMul->getArgOperand(0); 2058 Value *B = MatMul->getArgOperand(1); 2059 2060 // We can fold the transpose into the operand that is used to fetch scalars. 
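    // E.g. in column-major layout the elements of the second operand are read
    // as scalars during the multiply, so for A * transpose(T) we can read the
    // scalars directly from T (with swapped indices) and never materialize the
    // transpose.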
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType =
          cast<FixedVectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert_range(CurrI->operands());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul->getIterator());

      // Deal with lifetime.end calls that might be between Load0/Load1 and the
      // store. To avoid introducing loads to dead objects (i.e. after the
      // lifetime has been terminated by @llvm.lifetime.end), either sink them
      // after the store if in the same block, or remove the lifetime.end
      // marker otherwise. This might pessimize further optimizations, by
      // extending the lifetime of the object until the function returns, but
      // should be conservatively correct.
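      // Illustrative shape of the problem (hypothetical IR): with
      //   %a = load ..., ptr %A
      //   call void @llvm.lifetime.end.p0(i64 ..., ptr %A)
      //   store ..., ptr %C
      // the tiled loads emitted at the store would read %A after its lifetime
      // ended, so the marker below is either sunk past the store or dropped.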
2148 MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0); 2149 MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1); 2150 BasicBlock *StoreParent = Store->getParent(); 2151 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent && 2152 LoadOp1->getParent() == StoreParent; 2153 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) { 2154 IntrinsicInst *End = LifetimeEnds[Idx]; 2155 auto Inc = make_scope_exit([&Idx]() { Idx++; }); 2156 // If the lifetime.end is guaranteed to be before the loads or after the 2157 // store, it won't interfere with fusion. 2158 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1)) 2159 continue; 2160 if (DT->dominates(Store, End)) 2161 continue; 2162 // If all fusable ops are in the same block and the lifetime.end is in a 2163 // different block, it won't interfere with fusion. 2164 if (FusableOpsInSameBlock && End->getParent() != StoreParent) 2165 continue; 2166 2167 // If the loads don't alias the lifetime.end, it won't interfere with 2168 // fusion. 2169 MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr); 2170 if (!EndLoc.Ptr) 2171 continue; 2172 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc)) 2173 continue; 2174 2175 // If both lifetime.end and the store are in the same block, extend the 2176 // lifetime until after the store, so the new lifetime covers the loads 2177 // we introduce later. 2178 if (End->getParent() == StoreParent) { 2179 End->moveAfter(Store); 2180 continue; 2181 } 2182 2183 // Otherwise remove the conflicting lifetime.end marker. 2184 ToRemove.push_back(End); 2185 std::swap(LifetimeEnds[Idx], LifetimeEnds.back()); 2186 LifetimeEnds.pop_back(); 2187 Inc.release(); 2188 } 2189 2190 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 2191 return; 2192 } 2193 } 2194 2195 /// Lowers llvm.matrix.multiply. 2196 MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) { 2197 auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType(); 2198 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 2199 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 2200 2201 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 2202 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 2203 assert(Lhs.getElementType() == Rhs.getElementType() && 2204 "Matrix multiply argument element types do not match."); 2205 2206 const unsigned R = LShape.NumRows; 2207 const unsigned C = RShape.NumColumns; 2208 assert(LShape.NumColumns == RShape.NumRows); 2209 2210 // Initialize the output 2211 MatrixTy Result(R, C, EltType); 2212 assert(Lhs.getElementType() == Result.getElementType() && 2213 "Matrix multiply result element type does not match arguments."); 2214 2215 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, 2216 getFastMathFlags(MatMul)); 2217 return Result; 2218 } 2219 2220 /// Lowers llvm.matrix.transpose. 2221 MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) { 2222 MatrixTy Result; 2223 Value *InputVal = Inst->getArgOperand(0); 2224 FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType()); 2225 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 2226 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 2227 2228 const unsigned NewNumVecs = 2229 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; 2230 const unsigned NewNumElts = 2231 InputMatrix.isColumnMajor() ? 
ArgShape.NumColumns : ArgShape.NumRows; 2232 2233 for (unsigned I = 0; I < NewNumVecs; ++I) { 2234 // Build a single result vector. First initialize it. 2235 Value *ResultVector = PoisonValue::get( 2236 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 2237 // Go through the old elements and insert it into the resulting vector. 2238 for (auto J : enumerate(InputMatrix.vectors())) { 2239 Value *Elt = Builder.CreateExtractElement(J.value(), I); 2240 // Row and column indices are transposed. 2241 ResultVector = 2242 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 2243 } 2244 Result.addVector(ResultVector); 2245 } 2246 2247 // TODO: Improve estimate of operations needed for transposes. Currently we 2248 // just count the insertelement/extractelement instructions, but do not 2249 // account for later simplifications/combines. 2250 return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) 2251 .addNumExposedTransposes(1); 2252 } 2253 2254 /// Lower load instructions. 2255 MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, 2256 IRBuilder<> &Builder) { 2257 return LowerLoad(Inst, Ptr, Inst->getAlign(), 2258 Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, 2259 Builder); 2260 } 2261 2262 MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, 2263 Value *Ptr, IRBuilder<> &Builder) { 2264 return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 2265 Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, 2266 Builder); 2267 } 2268 2269 MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) { 2270 auto BlockIP = Inst->getParent()->getFirstInsertionPt(); 2271 Builder.SetInsertPoint(BlockIP); 2272 MatrixTy PhiM = getMatrix(Inst, SI, Builder); 2273 2274 for (auto [IncomingV, IncomingB] : 2275 llvm::zip_equal(Inst->incoming_values(), Inst->blocks())) { 2276 // getMatrix() may insert some instructions to help with reshaping. The 2277 // safest place for those is at the top of the block after the rest of the 2278 // PHI's. Even better, if we can put it in the incoming block. 2279 Builder.SetInsertPoint(BlockIP); 2280 if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV)) 2281 if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef()) 2282 Builder.SetInsertPoint(*MaybeIP); 2283 2284 MatrixTy OpM = getMatrix(IncomingV, SI, Builder); 2285 2286 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) { 2287 PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI)); 2288 NewPHI->addIncoming(OpM.getVector(VI), IncomingB); 2289 } 2290 } 2291 2292 // finalizeLowering() may also insert instructions in some cases. The safe 2293 // place for those is at the end of the initial block of PHIs. 2294 Builder.SetInsertPoint(BlockIP); 2295 return PhiM; 2296 } 2297 2298 /// Lower binary operators. 
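  ///
  /// The operator is applied per lowered vector; e.g. an fadd of two 4x2
  /// matrices held as two <4 x float> columns each becomes two <4 x float>
  /// fadd instructions, one per column (a sketch, assuming column-major
  /// layout and float elements).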
2299 MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI, 2300 IRBuilder<> &Builder) { 2301 Value *Lhs = Inst->getOperand(0); 2302 Value *Rhs = Inst->getOperand(1); 2303 2304 MatrixTy Result; 2305 MatrixTy A = getMatrix(Lhs, SI, Builder); 2306 MatrixTy B = getMatrix(Rhs, SI, Builder); 2307 assert(A.isColumnMajor() == B.isColumnMajor() && 2308 Result.isColumnMajor() == A.isColumnMajor() && 2309 "operands must agree on matrix layout"); 2310 2311 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2312 2313 for (auto [AV, BV] : llvm::zip_equal(A.vectors(), B.vectors())) 2314 Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), AV, BV)); 2315 2316 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2317 Result.getNumVectors()); 2318 } 2319 2320 /// Lower unary operators. 2321 MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI, 2322 IRBuilder<> &Builder) { 2323 Value *Op = Inst->getOperand(0); 2324 2325 MatrixTy Result; 2326 MatrixTy M = getMatrix(Op, SI, Builder); 2327 2328 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2329 2330 // Helper to perform unary op on vectors. 2331 auto BuildVectorOp = [&Builder, Inst](Value *Op) { 2332 switch (Inst->getOpcode()) { 2333 case Instruction::FNeg: 2334 return Builder.CreateFNeg(Op); 2335 default: 2336 llvm_unreachable("Unsupported unary operator for matrix"); 2337 } 2338 }; 2339 2340 for (auto *Vector : M.vectors()) 2341 Result.addVector(BuildVectorOp(Vector)); 2342 2343 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2344 Result.getNumVectors()); 2345 } 2346 2347 /// Lower cast instructions. 2348 MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape, 2349 IRBuilder<> &Builder) { 2350 Value *Op = Inst->getOperand(0); 2351 2352 MatrixTy Result; 2353 MatrixTy M = getMatrix(Op, Shape, Builder); 2354 2355 Builder.setFastMathFlags(getFastMathFlags(Inst)); 2356 2357 auto *OrigVTy = cast<VectorType>(Inst->getType()); 2358 auto *NewVTy = VectorType::get(OrigVTy->getElementType(), 2359 ElementCount::getFixed(M.getStride())); 2360 2361 for (auto *Vector : M.vectors()) 2362 Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy)); 2363 2364 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2365 Result.getNumVectors()); 2366 } 2367 2368 /// Lower selects. 2369 MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape, 2370 IRBuilder<> &Builder) { 2371 Value *Cond = Inst->getOperand(0); 2372 Value *OpA = Inst->getOperand(1); 2373 Value *OpB = Inst->getOperand(2); 2374 2375 MatrixTy Result; 2376 MatrixTy A = getMatrix(OpA, Shape, Builder); 2377 MatrixTy B = getMatrix(OpB, Shape, Builder); 2378 2379 SmallVector<Value*> CondV; 2380 if (isa<FixedVectorType>(Cond->getType())) { 2381 MatrixTy C = getMatrix(Cond, Shape, Builder); 2382 llvm::copy(C.vectors(), std::back_inserter(CondV)); 2383 } else { 2384 CondV.resize(A.getNumVectors()); 2385 llvm::fill(CondV, Cond); 2386 } 2387 2388 for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors())) 2389 Result.addVector(Builder.CreateSelect(CV, AV, BV)); 2390 2391 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 2392 Result.getNumVectors()); 2393 } 2394 2395 /// Helper to linearize a matrix expression tree into a string. Currently 2396 /// matrix expressions are linarized by starting at an expression leaf and 2397 /// linearizing bottom up. 
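  ///
  /// A linearized expression looks roughly like (illustrative only):
  ///   multiply.2x6.6x2.double(addr %A, addr %B)
  /// with operands printed as "addr"/"stack addr", constants, "scalar", or
  /// nested matrix expressions, as implemented by write() below.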
2398 struct ExprLinearizer { 2399 unsigned LengthToBreak = 100; 2400 std::string Str; 2401 raw_string_ostream Stream; 2402 unsigned LineLength = 0; 2403 const DataLayout &DL; 2404 2405 /// Mapping from instructions to matrixes. It is used to identify 2406 /// matrix instructions. 2407 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2408 2409 /// Mapping from values to the leaves of all expressions that the value is 2410 /// part of. 2411 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 2412 2413 /// Set of matrix expressions in the scope of a given DISubprogram. 2414 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 2415 2416 /// Leaf node of the expression to linearize. 2417 Value *Leaf; 2418 2419 /// Used to keep track of sub-expressions that get reused while linearizing 2420 /// the expression. Re-used sub-expressions are marked as (reused). 2421 SmallPtrSet<Value *, 8> ReusedExprs; 2422 2423 ExprLinearizer(const DataLayout &DL, 2424 const MapVector<Value *, MatrixTy> &Inst2Matrix, 2425 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2426 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2427 Value *Leaf) 2428 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 2429 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 2430 2431 void indent(unsigned N) { 2432 LineLength += N; 2433 for (unsigned i = 0; i < N; i++) 2434 Stream << " "; 2435 } 2436 2437 void lineBreak() { 2438 Stream << "\n"; 2439 LineLength = 0; 2440 } 2441 2442 void maybeIndent(unsigned Indent) { 2443 if (LineLength >= LengthToBreak) 2444 lineBreak(); 2445 2446 if (LineLength == 0) 2447 indent(Indent); 2448 } 2449 2450 void write(StringRef S) { 2451 LineLength += S.size(); 2452 Stream << S; 2453 } 2454 2455 Value *getUnderlyingObjectThroughLoads(Value *V) { 2456 if (Value *Ptr = getPointerOperand(V)) 2457 return getUnderlyingObjectThroughLoads(Ptr); 2458 else if (V->getType()->isPointerTy()) 2459 return getUnderlyingObject(V); 2460 return V; 2461 } 2462 2463 /// Returns true if \p V is a matrix value in the given subprogram. 2464 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 2465 2466 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to 2467 /// \p SS. 2468 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 2469 auto M = Inst2Matrix.find(V); 2470 if (M == Inst2Matrix.end()) 2471 SS << "unknown"; 2472 else { 2473 SS << M->second.getNumRows(); 2474 SS << "x"; 2475 SS << M->second.getNumColumns(); 2476 } 2477 } 2478 2479 /// Write the called function name. Handles calls to llvm.matrix.* 2480 /// specially: we write the name, followed by the dimensions of the input 2481 /// matrixes, followed by the scalar type name. 2482 void writeFnName(CallInst *CI) { 2483 if (!CI->getCalledFunction()) 2484 write("<no called fn>"); 2485 else { 2486 StringRef Name = CI->getCalledFunction()->getName(); 2487 if (!Name.starts_with("llvm.matrix")) { 2488 write(Name); 2489 return; 2490 } 2491 auto *II = cast<IntrinsicInst>(CI); 2492 write(Intrinsic::getBaseName(II->getIntrinsicID()) 2493 .drop_front(StringRef("llvm.matrix.").size())); 2494 write("."); 2495 std::string Tmp; 2496 raw_string_ostream SS(Tmp); 2497 2498 switch (II->getIntrinsicID()) { 2499 case Intrinsic::matrix_multiply: 2500 prettyPrintMatrixType(II->getOperand(0), SS); 2501 SS << "."; 2502 prettyPrintMatrixType(II->getOperand(1), SS); 2503 SS << "." 
<< *II->getType()->getScalarType(); 2504 break; 2505 case Intrinsic::matrix_transpose: 2506 prettyPrintMatrixType(II->getOperand(0), SS); 2507 SS << "." << *II->getType()->getScalarType(); 2508 break; 2509 case Intrinsic::matrix_column_major_load: 2510 prettyPrintMatrixType(II, SS); 2511 SS << "." << *II->getType()->getScalarType(); 2512 break; 2513 case Intrinsic::matrix_column_major_store: 2514 prettyPrintMatrixType(II->getOperand(0), SS); 2515 SS << "." << *II->getOperand(0)->getType()->getScalarType(); 2516 break; 2517 default: 2518 llvm_unreachable("Unhandled case"); 2519 } 2520 write(Tmp); 2521 } 2522 } 2523 2524 unsigned getNumShapeArgs(CallInst *CI) const { 2525 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 2526 switch (II->getIntrinsicID()) { 2527 case Intrinsic::matrix_multiply: 2528 return 3; 2529 case Intrinsic::matrix_transpose: 2530 return 2; 2531 case Intrinsic::matrix_column_major_load: 2532 case Intrinsic::matrix_column_major_store: 2533 return 3; 2534 default: 2535 return 0; 2536 } 2537 } 2538 return 0; 2539 } 2540 2541 /// Special printing for values: for pointers, we print if they refer to an 2542 /// (function) external address or a stack address, for other values we 2543 /// either print the constant or "scalar"/"matrix" for other values. 2544 void write(Value *V) { 2545 V = getUnderlyingObjectThroughLoads(V); 2546 if (V->getType()->isPointerTy()) { 2547 if (isa<AllocaInst>(V)) { 2548 Stream << "stack addr"; 2549 LineLength += StringRef("stack addr").size(); 2550 } else { 2551 Stream << "addr"; 2552 LineLength += StringRef("addr").size(); 2553 } 2554 if (!V->getName().empty()) { 2555 Stream << " %" << V->getName() << ""; 2556 LineLength += V->getName().size() + 2; 2557 } 2558 return; 2559 } 2560 2561 std::string Tmp; 2562 raw_string_ostream TmpStream(Tmp); 2563 2564 if (auto *CI = dyn_cast<ConstantInt>(V)) 2565 TmpStream << CI->getValue(); 2566 else if (isa<Constant>(V)) 2567 TmpStream << "constant"; 2568 else { 2569 if (isMatrix(V)) 2570 TmpStream << "matrix"; 2571 else 2572 TmpStream << "scalar"; 2573 } 2574 Tmp = std::string(StringRef(Tmp).trim()); 2575 LineLength += Tmp.size(); 2576 Stream << Tmp; 2577 } 2578 2579 /// Linearize expression \p Expr starting at an indentation of \p Indent. 2580 /// Expressions that are re-used multiple times are prefixed with (reused) 2581 /// at the re-used root instruction. 2582 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, 2583 bool ParentShared) { 2584 auto *I = cast<Instruction>(Expr); 2585 maybeIndent(Indent); 2586 SmallVector<Value *, 8> Ops; 2587 2588 // Is Expr shared with other expression leaves? 2589 bool ExprShared = false; 2590 2591 // Deal with shared subtrees. Mark them as shared, if required. 
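    // A subtree is shared if the Shared map associates it with more than one
    // leaf; in that case the block below prefixes it with a pointer to the
    // other leaf's remark location before printing the subtree itself.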
2592 if (!ParentShared) { 2593 auto SI = Shared.find(Expr); 2594 assert(SI != Shared.end() && SI->second.count(Leaf)); 2595 2596 for (Value *S : SI->second) { 2597 if (S == Leaf) 2598 continue; 2599 DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); 2600 write("shared with remark at line " + std::to_string(DL.getLine()) + 2601 " column " + std::to_string(DL.getCol()) + " ("); 2602 } 2603 ExprShared = SI->second.size() > 1; 2604 } 2605 2606 bool Reused = !ReusedExprs.insert(Expr).second; 2607 if (Reused && !ParentReused) 2608 write("(reused) "); 2609 2610 if (auto *CI = dyn_cast<CallInst>(I)) { 2611 writeFnName(CI); 2612 2613 Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); 2614 } else if (isa<BitCastInst>(Expr)) { 2615 // Special case bitcasts, which are used to materialize matrixes from 2616 // non-matrix ops. 2617 write("matrix"); 2618 return; 2619 } else { 2620 Ops.append(I->value_op_begin(), I->value_op_end()); 2621 write(I->getOpcodeName()); 2622 } 2623 2624 write("("); 2625 2626 unsigned NumOpsToBreak = 1; 2627 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) 2628 NumOpsToBreak = 2; 2629 2630 for (Value *Op : Ops) { 2631 if (Ops.size() > NumOpsToBreak) 2632 lineBreak(); 2633 2634 maybeIndent(Indent + 1); 2635 if (isMatrix(Op)) 2636 linearizeExpr(Op, Indent + 1, Reused, ExprShared); 2637 else 2638 write(Op); 2639 if (Op != Ops.back()) 2640 write(", "); 2641 } 2642 2643 write(")"); 2644 } 2645 2646 const std::string &getResult() { 2647 return Str; 2648 } 2649 }; 2650 2651 /// Generate remarks for matrix operations in a function. To generate remarks 2652 /// for matrix expressions, the following approach is used: 2653 /// 1. Use the inlined-at debug information to group matrix operations to the 2654 /// DISubprograms they are contained in. 2655 /// 2. Collect leaves of matrix expressions (done in 2656 /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression 2657 // mapping. Leaves are lowered matrix instructions without other matrix 2658 // users (like stores) in the current subprogram. 2659 /// 3. For each leaf, create a remark containing a linearizied version of the 2660 /// matrix expression. The expression is linearized by a recursive 2661 /// bottom-up traversal of the matrix operands, starting at a leaf. Note 2662 /// that multiple leaves can share sub-expressions. Shared subexpressions 2663 /// are explicitly marked as shared(). 2664 struct RemarkGenerator { 2665 const MapVector<Value *, MatrixTy> &Inst2Matrix; 2666 OptimizationRemarkEmitter &ORE; 2667 Function &Func; 2668 const DataLayout &DL; 2669 2670 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, 2671 OptimizationRemarkEmitter &ORE, Function &Func) 2672 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), 2673 DL(Func.getDataLayout()) {} 2674 2675 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are 2676 /// instructions in Inst2Matrix returning void or without any users in 2677 /// \p ExprsInSubprogram. Currently that should only include stores. 
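    ///
    /// For example, the store at the root of a load-multiply-store chain has
    /// no users inside \p ExprsInSubprogram, so it becomes the leaf from which
    /// that expression is linearized.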
2678 SmallVector<Value *, 4> 2679 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { 2680 SmallVector<Value *, 4> Leaves; 2681 for (auto *Expr : ExprsInSubprogram) 2682 if (Expr->getType()->isVoidTy() || 2683 !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { 2684 return ExprsInSubprogram.count(U); 2685 })) 2686 Leaves.push_back(Expr); 2687 return Leaves; 2688 } 2689 2690 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf 2691 /// to all visited expressions in \p Shared. Limit the matrix operations to 2692 /// the ones in \p ExprsInSubprogram. 2693 void collectSharedInfo(Value *Leaf, Value *V, 2694 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2695 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { 2696 2697 if (!ExprsInSubprogram.count(V)) 2698 return; 2699 2700 Shared[V].insert(Leaf); 2701 2702 for (Value *Op : cast<Instruction>(V)->operand_values()) 2703 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); 2704 } 2705 2706 /// Calculate the number of exclusive and shared op counts for expression 2707 /// starting at \p V. Expressions used multiple times are counted once. 2708 /// Limit the matrix operations to the ones in \p ExprsInSubprogram. 2709 std::pair<OpInfoTy, OpInfoTy> 2710 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, 2711 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2712 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { 2713 if (!ExprsInSubprogram.count(Root)) 2714 return {}; 2715 2716 // Already counted this expression. Stop. 2717 if (!ReusedExprs.insert(Root).second) 2718 return {}; 2719 2720 OpInfoTy SharedCount; 2721 OpInfoTy Count; 2722 2723 auto I = Shared.find(Root); 2724 auto CM = Inst2Matrix.find(Root); 2725 if (I->second.size() == 1) 2726 Count = CM->second.getOpInfo(); 2727 else 2728 SharedCount = CM->second.getOpInfo(); 2729 2730 for (Value *Op : cast<Instruction>(Root)->operand_values()) { 2731 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); 2732 Count += C.first; 2733 SharedCount += C.second; 2734 } 2735 return {Count, SharedCount}; 2736 } 2737 2738 void emitRemarks() { 2739 if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) 2740 return; 2741 2742 // Map matrix operations to their containting subprograms, by traversing 2743 // the inlinedAt chain. If the function does not have a DISubprogram, we 2744 // only map them to the containing function. 2745 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; 2746 for (const auto &KV : Inst2Matrix) { 2747 if (Func.getSubprogram()) { 2748 auto *I = cast<Instruction>(KV.first); 2749 DILocation *Context = I->getDebugLoc(); 2750 while (Context) { 2751 Subprog2Exprs[getSubprogram(Context->getScope())].push_back( 2752 KV.first); 2753 Context = DebugLoc(Context).getInlinedAt(); 2754 } 2755 } else { 2756 Subprog2Exprs[nullptr].push_back(KV.first); 2757 } 2758 } 2759 for (auto &KV : Subprog2Exprs) { 2760 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), 2761 KV.second.end()); 2762 auto Leaves = getExpressionLeaves(ExprsInSubprogram); 2763 2764 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; 2765 for (Value *Leaf : Leaves) 2766 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); 2767 2768 // Generate remarks for each leaf. 
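      // The emitted remark looks roughly like (illustrative numbers only):
      //   Lowered with 6 stores, 12 loads, 222 compute ops, 0 exposed
      //   transposes
      //   store(multiply.2x6.6x2.double(addr %A, addr %B), stack addr %C)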
2769 for (auto *L : Leaves) { 2770 2771 DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); 2772 DILocation *Context = cast<Instruction>(L)->getDebugLoc(); 2773 while (Context) { 2774 if (getSubprogram(Context->getScope()) == KV.first) { 2775 Loc = Context; 2776 break; 2777 } 2778 Context = DebugLoc(Context).getInlinedAt(); 2779 } 2780 2781 SmallPtrSet<Value *, 8> ReusedExprs; 2782 OpInfoTy Counts, SharedCounts; 2783 std::tie(Counts, SharedCounts) = 2784 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); 2785 2786 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, 2787 cast<Instruction>(L)->getParent()); 2788 2789 Rem << "Lowered with "; 2790 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " 2791 << ore::NV("NumLoads", Counts.NumLoads) << " loads, " 2792 << ore::NV("NumComputeOps", Counts.NumComputeOps) 2793 << " compute ops, " 2794 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes) 2795 << " exposed transposes"; 2796 2797 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || 2798 SharedCounts.NumComputeOps > 0) { 2799 Rem << ",\nadditionally " 2800 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " 2801 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " 2802 << ore::NV("NumFPOps", SharedCounts.NumComputeOps) 2803 << " compute ops" 2804 << " are shared with other expressions"; 2805 } 2806 2807 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); 2808 ORE.emit(Rem); 2809 } 2810 } 2811 } 2812 2813 std::string 2814 linearize(Value *L, 2815 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 2816 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 2817 const DataLayout &DL) { 2818 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); 2819 Lin.linearizeExpr(L, 0, false, false); 2820 return Lin.getResult(); 2821 } 2822 }; 2823 }; 2824 } // namespace 2825 2826 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, 2827 FunctionAnalysisManager &AM) { 2828 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 2829 2830 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM); 2831 if (LMT.Visit()) { 2832 PreservedAnalyses PA; 2833 if (!Minimal) { 2834 PA.preserve<LoopAnalysis>(); 2835 PA.preserve<DominatorTreeAnalysis>(); 2836 } 2837 return PA; 2838 } 2839 return PreservedAnalyses::all(); 2840 } 2841 2842 void LowerMatrixIntrinsicsPass::printPipeline( 2843 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 2844 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( 2845 OS, MapClassName2PassName); 2846 OS << '<'; 2847 if (Minimal) 2848 OS << "minimal"; 2849 OS << '>'; 2850 } 2851
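// A pass string such as -passes='lower-matrix-intrinsics<minimal>' round-trips
// through printPipeline above; without the option the pass prints as
// 'lower-matrix-intrinsics<>'.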