1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 //  * Improve fusion:
13 //    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
14 //      transposed.
15 //  * Improve cost-modeling, e.g. choose different numbers of rows/columns
16 //    for tiles, consider cost of copies on alias.
17 //
18 //===----------------------------------------------------------------------===//
19
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Analysis/AliasAnalysis.h"
24 #include "llvm/Analysis/DomTreeUpdater.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
27 #include "llvm/Analysis/TargetTransformInfo.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Analysis/VectorUtils.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/IntrinsicInst.h"
37 #include "llvm/IR/MatrixBuilder.h"
38 #include "llvm/IR/PatternMatch.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Alignment.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Transforms/Scalar.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/LoopUtils.h"
47 #include "llvm/Transforms/Utils/MatrixUtils.h"
48
49 #include <cmath>
50
51 using namespace llvm;
52 using namespace PatternMatch;
53
54 #define DEBUG_TYPE "lower-matrix-intrinsics"
55
56 static cl::opt<bool>
57     FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
58                cl::desc("Enable/disable fusing matrix instructions."));
59 // TODO: Allow and use non-square tiles.
60 static cl::opt<unsigned> TileSize(
61     "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
62     cl::desc(
63         "Tile size for matrix instruction fusion using square-shaped tiles."));
64 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
65                                   cl::Hidden,
66                                   cl::desc("Generate loop nest for tiling."));
67 static cl::opt<bool> ForceFusion(
68     "force-fuse-matrix", cl::init(false), cl::Hidden,
69     cl::desc("Force matrix instruction fusion even if not profitable."));
70 static cl::opt<bool> AllowContractEnabled(
71     "matrix-allow-contract", cl::init(false), cl::Hidden,
72     cl::desc("Allow the use of FMAs if available and profitable. This may "
73              "produce different results, due to less rounding error."));
74
75 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
76
77 static cl::opt<MatrixLayoutTy> MatrixLayout(
78     "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
79     cl::desc("Sets the default matrix layout"),
80     cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
81                           "Use column-major layout"),
82                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
83                           "Use row-major layout")));
84
85 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
86                                             cl::init(false));
87
88 /// Helper function to either return \p Scope, if it is a subprogram, or the
89 /// subprogram attached to a local scope.
90 static DISubprogram *getSubprogram(DIScope *Scope) {
91   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
92     return Subprogram;
93   return cast<DILocalScope>(Scope)->getSubprogram();
94 }
95
96 /// Erase \p V from \p BB and move \p II forward to avoid invalidating
97 /// iterators.
98 static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
99                                    BasicBlock &BB) {
100   auto *Inst = cast<Instruction>(V);
101   // Still used, don't erase.
102   if (!Inst->use_empty())
103     return;
104   if (II != BB.rend() && Inst == &*II)
105     ++II;
106   Inst->eraseFromParent();
107 }
108
109 /// Return true if V is a splat of a value (which is used when multiplying a
110 /// matrix with a scalar).
111 static bool isSplat(Value *V) {
112   if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
113     return SV->isZeroEltSplat();
114   return false;
115 }
116
117 /// Match any mul operation (fp or integer).
118 template <typename LTy, typename RTy>
119 auto m_AnyMul(const LTy &L, const RTy &R) {
120   return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
121 }
122
123 /// Match any add operation (fp or integer).
124 template <typename LTy, typename RTy>
125 auto m_AnyAdd(const LTy &L, const RTy &R) {
126   return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
127 }
128
129 namespace {
130
131 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
132 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
133 // assuming \p Stride elements between the starts of two consecutive vectors.
134 // \p Stride must be >= \p NumElements.
135 // For column-major matrixes, the function computes the address of a column
136 // vector and \p NumElements must be set to the number of elements in a column
137 // (= number of rows of the matrix). For row-major matrixes, the function
138 // computes the address of a row vector and \p NumElements must be set to the
139 // number of elements in a row (= number of columns of the matrix).
140 //
141 // Consider a 4x4 matrix in column-major layout like below
142 //
143 //      0      1      2      3
144 // 0   v_0_0  v_0_1  v_0_2  v_0_3
145 // 1   v_1_0  v_1_1  v_1_2  v_1_3
146 // 2   v_2_0  v_2_1  v_2_2  v_2_3
147 // 3   v_3_0  v_3_1  v_3_2  v_3_3
148
149 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
150 // we need a pointer to the first element of the submatrix as base pointer.
151 // Then we can use computeVectorAddr to compute the addresses for the columns
152 // of the sub-matrix.
153 //
154 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
155 //           -> just returns Base
156 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
157 //           -> returns Base + (1 * 4)
158 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
159 // -> returns Base + (2 * 4) 160 // 161 // The graphic below illustrates the number of elements in a column (marked 162 // with |) and the number of skipped elements (marked with }). 163 // 164 // v_0_0 v_0_1 {v_0_2 {v_0_3 165 // Base Col 1 Col 2 166 // | | | 167 // v_1_0 |v_1_1 |v_1_2 |v_1_3 168 // v_2_0 |v_2_1 |v_2_2 |v_2_3 169 // v_3_0 {v_3_1 {v_3_2 v_3_3 170 // 171 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, 172 unsigned NumElements, Type *EltType, 173 IRBuilder<> &Builder) { 174 175 assert((!isa<ConstantInt>(Stride) || 176 cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && 177 "Stride must be >= the number of elements in the result vector."); 178 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 179 180 // Compute the start of the vector with index VecIdx as VecIdx * Stride. 181 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); 182 183 // Get pointer to the start of the selected vector. Skip GEP creation, 184 // if we select vector 0. 185 if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) 186 VecStart = BasePtr; 187 else 188 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); 189 190 // Cast elementwise vector start pointer to a pointer to a vector 191 // (EltType x NumElements)*. 192 auto *VecType = FixedVectorType::get(EltType, NumElements); 193 Type *VecPtrType = PointerType::get(VecType, AS); 194 return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); 195 } 196 197 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. 198 /// 199 /// Currently, the lowering for each matrix intrinsic is done as follows: 200 /// 1. Propagate the shape information from intrinsics to connected 201 /// instructions. 202 /// 2. Lower instructions with shape information (assuming column-major layout). 203 /// The lowering works similarly using row-major layout. 204 /// 2.1. Get column vectors for each argument. If we already lowered the 205 /// definition of an argument, use the produced column vectors directly. 206 /// If not, split the operand vector containing an embedded matrix into 207 /// a set of column vectors, 208 /// 2.2. Lower the instruction in terms of column major operations, which 209 /// yields a set of column vectors containing result matrix. Note that we 210 /// lower all instructions that have shape information. Besides the 211 /// intrinsics, this includes stores for example. 212 /// 2.3. Update uses of the lowered instruction. If we have shape information 213 /// for a user, there is nothing to do, as we will look up the result 214 /// column matrix when lowering the user. For other uses, we embed the 215 /// result matrix in a flat vector and update the use. 216 /// 2.4. Cache the result column matrix for the instruction we lowered 217 /// 3. After we lowered all instructions in a function, remove the now 218 /// obsolete instructions. 219 /// 220 class LowerMatrixIntrinsics { 221 Function &Func; 222 const DataLayout &DL; 223 const TargetTransformInfo &TTI; 224 AliasAnalysis *AA; 225 DominatorTree *DT; 226 LoopInfo *LI; 227 OptimizationRemarkEmitter *ORE; 228 229 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 230 struct OpInfoTy { 231 /// Number of stores emitted to generate this matrix. 232 unsigned NumStores = 0; 233 /// Number of loads emitted to generate this matrix. 
234 unsigned NumLoads = 0; 235 /// Number of compute operations emitted to generate this matrix. 236 unsigned NumComputeOps = 0; 237 /// Most of the time transposes can be fused with matrix multiplies or can 238 /// be folded away via algebraic simplifications. This is the number of 239 /// transposes that we failed to make "free" via such optimizations. 240 unsigned NumExposedTransposes = 0; 241 242 OpInfoTy &operator+=(const OpInfoTy &RHS) { 243 NumStores += RHS.NumStores; 244 NumLoads += RHS.NumLoads; 245 NumComputeOps += RHS.NumComputeOps; 246 NumExposedTransposes += RHS.NumExposedTransposes; 247 return *this; 248 } 249 }; 250 251 /// Wrapper class representing a matrix as a set of vectors, either in row or 252 /// column major layout. All vectors must have the same vector type. 253 class MatrixTy { 254 SmallVector<Value *, 16> Vectors; 255 256 OpInfoTy OpInfo; 257 258 bool IsColumnMajor = true; 259 260 public: 261 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 262 MatrixTy(ArrayRef<Value *> Vectors) 263 : Vectors(Vectors.begin(), Vectors.end()), 264 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 265 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 266 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 267 268 unsigned D = isColumnMajor() ? NumColumns : NumRows; 269 for (unsigned J = 0; J < D; ++J) 270 addVector(UndefValue::get(FixedVectorType::get( 271 EltTy, isColumnMajor() ? NumRows : NumColumns))); 272 } 273 274 Value *getVector(unsigned i) const { return Vectors[i]; } 275 Value *getColumn(unsigned i) const { 276 assert(isColumnMajor() && "only supported for column-major matrixes"); 277 return Vectors[i]; 278 } 279 Value *getRow(unsigned i) const { 280 assert(!isColumnMajor() && "only supported for row-major matrixes"); 281 return Vectors[i]; 282 } 283 284 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 285 286 Type *getElementType() const { return getVectorTy()->getElementType(); } 287 288 unsigned getNumVectors() const { 289 if (isColumnMajor()) 290 return getNumColumns(); 291 return getNumRows(); 292 } 293 294 unsigned getNumColumns() const { 295 if (isColumnMajor()) 296 return Vectors.size(); 297 else { 298 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 299 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 300 } 301 } 302 unsigned getNumRows() const { 303 if (isColumnMajor()) { 304 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 305 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 306 } else 307 return Vectors.size(); 308 } 309 310 void addVector(Value *V) { Vectors.push_back(V); } 311 VectorType *getColumnTy() { 312 assert(isColumnMajor() && "only supported for column-major matrixes"); 313 return getVectorTy(); 314 } 315 316 VectorType *getVectorTy() const { 317 return cast<VectorType>(Vectors[0]->getType()); 318 } 319 320 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 321 assert(isColumnMajor() && 322 "columns() only supported for column-major matrixes"); 323 return make_range(Vectors.begin(), Vectors.end()); 324 } 325 326 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 327 return make_range(Vectors.begin(), Vectors.end()); 328 } 329 330 /// Embed the vectors of the matrix into a flat vector by concatenating 331 /// them. 332 Value *embedInVector(IRBuilder<> &Builder) const { 333 return Vectors.size() == 1 ? 
Vectors[0] 334 : concatenateVectors(Builder, Vectors); 335 } 336 337 MatrixTy &addNumLoads(unsigned N) { 338 OpInfo.NumLoads += N; 339 return *this; 340 } 341 342 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 343 344 MatrixTy &addNumStores(unsigned N) { 345 OpInfo.NumStores += N; 346 return *this; 347 } 348 349 MatrixTy &addNumExposedTransposes(unsigned N) { 350 OpInfo.NumExposedTransposes += N; 351 return *this; 352 } 353 354 MatrixTy &addNumComputeOps(unsigned N) { 355 OpInfo.NumComputeOps += N; 356 return *this; 357 } 358 359 unsigned getNumStores() const { return OpInfo.NumStores; } 360 unsigned getNumLoads() const { return OpInfo.NumLoads; } 361 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 362 363 const OpInfoTy &getOpInfo() const { return OpInfo; } 364 365 bool isColumnMajor() const { return IsColumnMajor; } 366 367 unsigned getStride() const { 368 if (isColumnMajor()) 369 return getNumRows(); 370 return getNumColumns(); 371 } 372 373 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 374 /// matrix is column-major, the result vector is extracted from a column 375 /// vector, otherwise from a row vector. 376 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 377 IRBuilder<> &Builder) const { 378 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 379 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 380 NumElts && 381 "Extracted vector will contain poison values"); 382 return Builder.CreateShuffleVector( 383 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 384 "block"); 385 } 386 }; 387 388 struct ShapeInfo { 389 unsigned NumRows; 390 unsigned NumColumns; 391 392 bool IsColumnMajor; 393 394 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 395 : NumRows(NumRows), NumColumns(NumColumns), 396 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 397 398 ShapeInfo(Value *NumRows, Value *NumColumns) 399 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 400 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 401 402 bool operator==(const ShapeInfo &other) { 403 return NumRows == other.NumRows && NumColumns == other.NumColumns; 404 } 405 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 406 407 /// Returns true if shape-information is defined, meaning both dimensions 408 /// are != 0. 409 operator bool() const { 410 assert(NumRows == 0 || NumColumns != 0); 411 return NumRows != 0; 412 } 413 414 unsigned getStride() const { 415 if (IsColumnMajor) 416 return NumRows; 417 return NumColumns; 418 } 419 420 unsigned getNumVectors() const { 421 if (IsColumnMajor) 422 return NumColumns; 423 return NumRows; 424 } 425 426 /// Returns the transposed shape. 427 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } 428 }; 429 430 /// Maps instructions to their shape information. The shape information 431 /// describes the shape to be used while lowering. This matches the shape of 432 /// the result value of the instruction, with the only exceptions being store 433 /// instructions and the matrix_column_major_store intrinsics. For those, the 434 /// shape information indicates that those instructions should be lowered 435 /// using shape information as well. A ValueMap is used so that when 436 /// sub-passes like optimizeTransposes performs RAUW the map stays 437 /// up-to-date. 438 ValueMap<Value *, ShapeInfo> ShapeMap; 439 440 /// List of instructions to remove. 
While lowering, we are not replacing all 441 /// users of a lowered instruction, if shape information is available and 442 /// those need to be removed after we finished lowering. 443 SmallVector<Instruction *, 16> ToRemove; 444 445 /// Map from instructions to their produced column matrix. 446 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 447 448 private: 449 static FastMathFlags getFastMathFlags(Instruction *Inst) { 450 FastMathFlags FMF; 451 452 if (isa<FPMathOperator>(*Inst)) 453 FMF = Inst->getFastMathFlags(); 454 455 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 456 457 return FMF; 458 } 459 460 public: 461 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 462 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 463 OptimizationRemarkEmitter *ORE) 464 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 465 LI(LI), ORE(ORE) {} 466 467 unsigned getNumOps(Type *VT) { 468 assert(isa<VectorType>(VT) && "Expected vector type"); 469 return getNumOps(VT->getScalarType(), 470 cast<FixedVectorType>(VT)->getNumElements()); 471 } 472 473 /// Is this the minimal version executed in the backend pipelines. 474 bool isMinimal() const { 475 return !DT; 476 } 477 478 /// Return the estimated number of vector ops required for an operation on 479 /// \p VT * N. 480 unsigned getNumOps(Type *ST, unsigned N) { 481 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / 482 double(TTI.getRegisterBitWidth( 483 TargetTransformInfo::RGK_FixedWidthVector) 484 .getFixedValue())); 485 } 486 487 /// Return the set of vectors that a matrix value is lowered to. 488 /// 489 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 490 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 491 /// into vectors. 492 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 493 IRBuilder<> &Builder) { 494 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 495 assert(VType && "MatrixVal must be a vector type"); 496 assert(cast<FixedVectorType>(VType)->getNumElements() == 497 SI.NumRows * SI.NumColumns && 498 "The vector size must match the number of matrix elements"); 499 500 // Check if we lowered MatrixVal using shape information. In that case, 501 // return the existing matrix, if it matches the requested shape 502 // information. If there is a mis-match, embed the result in a flat 503 // vector and split it later. 504 auto Found = Inst2ColumnMatrix.find(MatrixVal); 505 if (Found != Inst2ColumnMatrix.end()) { 506 MatrixTy &M = Found->second; 507 // Return the found matrix, if its shape matches the requested shape 508 // information 509 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 510 return M; 511 512 MatrixVal = M.embedInVector(Builder); 513 } 514 515 // Otherwise split MatrixVal. 516 SmallVector<Value *, 16> SplitVecs; 517 for (unsigned MaskStart = 0; 518 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 519 MaskStart += SI.getStride()) { 520 Value *V = Builder.CreateShuffleVector( 521 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 522 "split"); 523 SplitVecs.push_back(V); 524 } 525 526 return {SplitVecs}; 527 } 528 529 /// If \p V already has a known shape return false. Otherwise set the shape 530 /// for instructions that support it. 
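  ///
  /// For example, setShapeInfo(V, {4, 2}) records that \p V holds a 4 x 2
  /// matrix, unless \p V already has a recorded shape or is not an instruction
  /// this pass can lower.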
531 bool setShapeInfo(Value *V, ShapeInfo Shape) { 532 assert(Shape && "Shape not set"); 533 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 534 return false; 535 536 auto SIter = ShapeMap.find(V); 537 if (SIter != ShapeMap.end()) { 538 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 539 << SIter->second.NumRows << " " 540 << SIter->second.NumColumns << " for " << *V << "\n"); 541 return false; 542 } 543 544 ShapeMap.insert({V, Shape}); 545 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 546 << " for " << *V << "\n"); 547 return true; 548 } 549 550 bool isUniformShape(Value *V) { 551 Instruction *I = dyn_cast<Instruction>(V); 552 if (!I) 553 return true; 554 555 switch (I->getOpcode()) { 556 case Instruction::FAdd: 557 case Instruction::FSub: 558 case Instruction::FMul: // Scalar multiply. 559 case Instruction::FNeg: 560 case Instruction::Add: 561 case Instruction::Mul: 562 case Instruction::Sub: 563 return true; 564 default: 565 return false; 566 } 567 } 568 569 /// Returns true if shape information can be used for \p V. The supported 570 /// instructions must match the instructions that can be lowered by this pass. 571 bool supportsShapeInfo(Value *V) { 572 Instruction *Inst = dyn_cast<Instruction>(V); 573 if (!Inst) 574 return false; 575 576 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 577 if (II) 578 switch (II->getIntrinsicID()) { 579 case Intrinsic::matrix_multiply: 580 case Intrinsic::matrix_transpose: 581 case Intrinsic::matrix_column_major_load: 582 case Intrinsic::matrix_column_major_store: 583 return true; 584 default: 585 return false; 586 } 587 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 588 } 589 590 /// Propagate the shape information of instructions to their users. 591 /// The work list contains instructions for which we can compute the shape, 592 /// either based on the information provided by matrix intrinsics or known 593 /// shapes of operands. 594 SmallVector<Instruction *, 32> 595 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 596 SmallVector<Instruction *, 32> NewWorkList; 597 // Pop an element for which we guaranteed to have at least one of the 598 // operand shapes. Add the shape for this and then add users to the work 599 // list. 600 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 601 while (!WorkList.empty()) { 602 Instruction *Inst = WorkList.pop_back_val(); 603 604 // New entry, set the value and insert operands 605 bool Propagate = false; 606 607 Value *MatrixA; 608 Value *MatrixB; 609 Value *M; 610 Value *N; 611 Value *K; 612 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 613 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 614 m_Value(N), m_Value(K)))) { 615 Propagate = setShapeInfo(Inst, {M, K}); 616 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 617 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 618 // Flip dimensions. 
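        // The operand is an M x N matrix, so the transposed result has shape
        // {N, M}.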
619 Propagate = setShapeInfo(Inst, {N, M}); 620 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 621 m_Value(MatrixA), m_Value(), m_Value(), 622 m_Value(), m_Value(M), m_Value(N)))) { 623 Propagate = setShapeInfo(Inst, {N, M}); 624 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 625 m_Value(), m_Value(), m_Value(), m_Value(M), 626 m_Value(N)))) { 627 Propagate = setShapeInfo(Inst, {M, N}); 628 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 629 auto OpShape = ShapeMap.find(MatrixA); 630 if (OpShape != ShapeMap.end()) 631 setShapeInfo(Inst, OpShape->second); 632 continue; 633 } else if (isUniformShape(Inst)) { 634 // Find the first operand that has a known shape and use that. 635 for (auto &Op : Inst->operands()) { 636 auto OpShape = ShapeMap.find(Op.get()); 637 if (OpShape != ShapeMap.end()) { 638 Propagate |= setShapeInfo(Inst, OpShape->second); 639 break; 640 } 641 } 642 } 643 644 if (Propagate) { 645 NewWorkList.push_back(Inst); 646 for (auto *User : Inst->users()) 647 if (ShapeMap.count(User) == 0) 648 WorkList.push_back(cast<Instruction>(User)); 649 } 650 } 651 652 return NewWorkList; 653 } 654 655 /// Propagate the shape to operands of instructions with shape information. 656 /// \p Worklist contains the instruction for which we already know the shape. 657 SmallVector<Instruction *, 32> 658 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 659 SmallVector<Instruction *, 32> NewWorkList; 660 661 auto pushInstruction = [](Value *V, 662 SmallVectorImpl<Instruction *> &WorkList) { 663 Instruction *I = dyn_cast<Instruction>(V); 664 if (I) 665 WorkList.push_back(I); 666 }; 667 // Pop an element with known shape. Traverse the operands, if their shape 668 // derives from the result shape and is unknown, add it and add them to the 669 // worklist. 670 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 671 while (!WorkList.empty()) { 672 Value *V = WorkList.pop_back_val(); 673 674 size_t BeforeProcessingV = WorkList.size(); 675 if (!isa<Instruction>(V)) 676 continue; 677 678 Value *MatrixA; 679 Value *MatrixB; 680 Value *M; 681 Value *N; 682 Value *K; 683 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 684 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 685 m_Value(N), m_Value(K)))) { 686 if (setShapeInfo(MatrixA, {M, N})) 687 pushInstruction(MatrixA, WorkList); 688 689 if (setShapeInfo(MatrixB, {N, K})) 690 pushInstruction(MatrixB, WorkList); 691 692 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 693 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 694 // Flip dimensions. 695 if (setShapeInfo(MatrixA, {M, N})) 696 pushInstruction(MatrixA, WorkList); 697 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 698 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 699 m_Value(M), m_Value(N)))) { 700 if (setShapeInfo(MatrixA, {M, N})) { 701 pushInstruction(MatrixA, WorkList); 702 } 703 } else if (isa<LoadInst>(V) || 704 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 705 // Nothing to do, no matrix input. 706 } else if (isa<StoreInst>(V)) { 707 // Nothing to do. We forward-propagated to this so we would just 708 // backward propagate to an instruction with an already known shape. 709 } else if (isUniformShape(V)) { 710 // Propagate to all operands. 
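        // Elementwise ops (add, sub, mul, fneg) produce a result with the same
        // shape as their operands, so each operand inherits the known result
        // shape.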
711         ShapeInfo Shape = ShapeMap[V];
712         for (Use &U : cast<Instruction>(V)->operands()) {
713           if (setShapeInfo(U.get(), Shape))
714             pushInstruction(U.get(), WorkList);
715         }
716       }
717       // After we discovered new shape info for new instructions in the
718       // worklist, we use their users as seeds for the next round of forward
719       // propagation.
720       for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
721         for (User *U : WorkList[I]->users())
722           if (isa<Instruction>(U) && V != U)
723             NewWorkList.push_back(cast<Instruction>(U));
724     }
725     return NewWorkList;
726   }
727
728   /// (Op0 op Op1)^T -> Op0^T op Op1^T
729   /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
730   /// them on both sides of \p Operation.
731   Instruction *distributeTransposes(
732       Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
733       MatrixBuilder &Builder,
734       function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
735           Operation) {
736     Value *T0 = Builder.CreateMatrixTranspose(
737         Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
738     // We are being run after shape prop; add shape info for the newly created
739     // instructions so that we lower them later.
740     setShapeInfo(T0, Shape0.t());
741     Value *T1 = Builder.CreateMatrixTranspose(
742         Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
743     setShapeInfo(T1, Shape1.t());
744     return Operation(T0, Shape0.t(), T1, Shape1.t());
745   }
746
747   void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
748     // We need to remove Old from the ShapeMap, otherwise RAUW will replace it
749     // with New. We should only add New if it supportsShapeInfo, so we insert
750     // it conditionally instead, before erasing Old (which invalidates S).
751     auto S = ShapeMap.find(&Old);
752     if (S != ShapeMap.end()) {
753       if (supportsShapeInfo(New))
754         ShapeMap.insert({New, S->second});
755       ShapeMap.erase(&Old);
756     }
757     Old.replaceAllUsesWith(New);
758   }
759
760   /// Sink a top-level transpose inside matmuls and adds.
761   /// This creates and erases instructions as needed, and returns the newly
762   /// created instruction while updating the iterator to avoid invalidation. If
763   /// this returns nullptr, no new instruction was created.
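  ///
  /// For example, transpose(multiply(A, B)) is rewritten to
  /// multiply(transpose(B), transpose(A)), and transpose(transpose(A)) is
  /// folded away entirely.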
764 Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) { 765 BasicBlock &BB = *I.getParent(); 766 IRBuilder<> IB(&I); 767 MatrixBuilder Builder(IB); 768 769 Value *TA, *TAMA, *TAMB; 770 ConstantInt *R, *K, *C; 771 if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( 772 m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) 773 return nullptr; 774 775 // Transpose of a transpose is a nop 776 Value *TATA; 777 if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 778 updateShapeAndReplaceAllUsesWith(I, TATA); 779 eraseFromParentAndMove(&I, II, BB); 780 eraseFromParentAndMove(TA, II, BB); 781 return nullptr; 782 } 783 784 // k^T -> k 785 if (isSplat(TA)) { 786 updateShapeAndReplaceAllUsesWith(I, TA); 787 eraseFromParentAndMove(&I, II, BB); 788 return nullptr; 789 } 790 791 // (A * B)^t -> B^t * A^t 792 // RxK KxC CxK KxR 793 if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 794 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 795 m_ConstantInt(K), m_ConstantInt(C)))) { 796 auto NewInst = distributeTransposes( 797 TAMB, {K, C}, TAMA, {R, K}, Builder, 798 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 799 return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, 800 Shape0.NumColumns, 801 Shape1.NumColumns, "mmul"); 802 }); 803 updateShapeAndReplaceAllUsesWith(I, NewInst); 804 eraseFromParentAndMove(&I, II, BB); 805 eraseFromParentAndMove(TA, II, BB); 806 return NewInst; 807 } 808 809 // Same as above, but with a mul, which occurs when multiplied 810 // with a scalar. 811 // (A * k)^t -> A^t * k 812 // R x C RxC 813 if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && 814 (isSplat(TAMA) || isSplat(TAMB))) { 815 IRBuilder<> LocalBuilder(&I); 816 // We know that the transposed operand is of shape RxC. 817 // An when multiplied with a scalar, the shape is preserved. 818 auto NewInst = distributeTransposes( 819 TAMA, {R, C}, TAMB, {R, C}, Builder, 820 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 821 bool IsFP = I.getType()->isFPOrFPVectorTy(); 822 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul") 823 : LocalBuilder.CreateMul(T0, T1, "mmul"); 824 auto *Result = cast<Instruction>(Mul); 825 setShapeInfo(Result, Shape0); 826 return Result; 827 }); 828 updateShapeAndReplaceAllUsesWith(I, NewInst); 829 eraseFromParentAndMove(&I, II, BB); 830 eraseFromParentAndMove(TA, II, BB); 831 return NewInst; 832 } 833 834 // (A + B)^t -> A^t + B^t 835 // RxC RxC CxR CxR 836 if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { 837 IRBuilder<> LocalBuilder(&I); 838 auto NewInst = distributeTransposes( 839 TAMA, {R, C}, TAMB, {R, C}, Builder, 840 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { 841 auto *FAdd = 842 cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd")); 843 setShapeInfo(FAdd, Shape0); 844 return FAdd; 845 }); 846 updateShapeAndReplaceAllUsesWith(I, NewInst); 847 eraseFromParentAndMove(&I, II, BB); 848 eraseFromParentAndMove(TA, II, BB); 849 return NewInst; 850 } 851 852 return nullptr; 853 } 854 855 void liftTranspose(Instruction &I) { 856 // Erase dead Instructions after lifting transposes from binops. 
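    // For example, multiply(transpose(A), transpose(B)) becomes
    // transpose(multiply(B, A)); the now-dead transposes are erased by
    // CleanupBinOp below.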
857 auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { 858 if (T.use_empty()) 859 T.eraseFromParent(); 860 if (A->use_empty()) 861 cast<Instruction>(A)->eraseFromParent(); 862 if (A != B && B->use_empty()) 863 cast<Instruction>(B)->eraseFromParent(); 864 }; 865 866 Value *A, *B, *AT, *BT; 867 ConstantInt *R, *K, *C; 868 // A^t * B ^t -> (B * A)^t 869 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 870 m_Value(A), m_Value(B), m_ConstantInt(R), 871 m_ConstantInt(K), m_ConstantInt(C))) && 872 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 873 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 874 IRBuilder<> IB(&I); 875 MatrixBuilder Builder(IB); 876 Value *M = Builder.CreateMatrixMultiply( 877 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 878 setShapeInfo(M, {C, R}); 879 Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), 880 R->getZExtValue()); 881 updateShapeAndReplaceAllUsesWith(I, NewInst); 882 CleanupBinOp(I, A, B); 883 } 884 // A^t + B ^t -> (A + B)^t 885 else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && 886 match(A, m_Intrinsic<Intrinsic::matrix_transpose>( 887 m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && 888 match(B, m_Intrinsic<Intrinsic::matrix_transpose>( 889 m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) { 890 IRBuilder<> Builder(&I); 891 Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); 892 setShapeInfo(Add, {C, R}); 893 MatrixBuilder MBuilder(Builder); 894 Instruction *NewInst = MBuilder.CreateMatrixTranspose( 895 Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t"); 896 updateShapeAndReplaceAllUsesWith(I, NewInst); 897 CleanupBinOp(I, A, B); 898 } 899 } 900 901 /// Try moving transposes in order to fold them away or into multiplies. 902 void optimizeTransposes() { 903 // First sink all transposes inside matmuls and adds, hoping that we end up 904 // with NN, NT or TN variants. 905 for (BasicBlock &BB : reverse(Func)) { 906 for (auto II = BB.rbegin(); II != BB.rend();) { 907 Instruction &I = *II; 908 // We may remove II. By default continue on the next/prev instruction. 909 ++II; 910 if (Instruction *NewInst = sinkTranspose(I, II)) 911 II = std::next(BasicBlock::reverse_iterator(NewInst)); 912 } 913 } 914 915 // If we have a TT matmul or a TT add, lift the transpose. We may be able 916 // to fold into consuming multiply or add. 917 for (BasicBlock &BB : Func) { 918 for (Instruction &I : llvm::make_early_inc_range(BB)) { 919 liftTranspose(I); 920 } 921 } 922 } 923 924 bool Visit() { 925 SmallVector<Instruction *, 32> WorkList; 926 927 // Initially only the shape of matrix intrinsics is known. 928 // Initialize the work list with ops carrying shape information. 929 for (BasicBlock &BB : Func) 930 for (Instruction &Inst : BB) { 931 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 932 if (!II) 933 continue; 934 935 switch (II->getIntrinsicID()) { 936 case Intrinsic::matrix_multiply: 937 case Intrinsic::matrix_transpose: 938 case Intrinsic::matrix_column_major_load: 939 case Intrinsic::matrix_column_major_store: 940 WorkList.push_back(&Inst); 941 break; 942 default: 943 break; 944 } 945 } 946 947 // Avoid unnecessary work if there are no matrix intrinsics in the function. 948 if (WorkList.empty()) 949 return false; 950 951 // Propagate shapes until nothing changes any longer. 
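    // Forward propagation pushes shapes from matrix intrinsics to their users;
    // backward propagation pushes them to operands. Each round seeds the next
    // one, so the loop stops once no new shape information is discovered.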
952 while (!WorkList.empty()) { 953 WorkList = propagateShapeForward(WorkList); 954 WorkList = propagateShapeBackward(WorkList); 955 } 956 957 if (!isMinimal()) { 958 optimizeTransposes(); 959 if (PrintAfterTransposeOpt) { 960 dbgs() << "Dump after matrix transpose optimization:\n"; 961 Func.print(dbgs()); 962 } 963 } 964 965 bool Changed = false; 966 SmallVector<CallInst *, 16> MaybeFusableInsts; 967 SmallVector<Instruction *, 16> MatrixInsts; 968 969 // First, collect all instructions with shape information and candidates for 970 // fusion (currently only matrix multiplies). 971 ReversePostOrderTraversal<Function *> RPOT(&Func); 972 for (auto *BB : RPOT) 973 for (Instruction &I : *BB) { 974 if (ShapeMap.find(&I) == ShapeMap.end()) 975 continue; 976 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 977 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 978 MatrixInsts.push_back(&I); 979 } 980 981 // Second, try to fuse candidates. 982 SmallPtrSet<Instruction *, 16> FusedInsts; 983 for (CallInst *CI : MaybeFusableInsts) 984 LowerMatrixMultiplyFused(CI, FusedInsts); 985 Changed = !FusedInsts.empty(); 986 987 // Third, lower remaining instructions with shape information. 988 for (Instruction *Inst : MatrixInsts) { 989 if (FusedInsts.count(Inst)) 990 continue; 991 992 IRBuilder<> Builder(Inst); 993 994 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 995 Changed |= VisitCallInst(CInst); 996 997 Value *Op1; 998 Value *Op2; 999 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 1000 Changed |= VisitBinaryOperator(BinOp); 1001 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 1002 Changed |= VisitUnaryOperator(UnOp); 1003 if (match(Inst, m_Load(m_Value(Op1)))) 1004 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 1005 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 1006 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 1007 } 1008 1009 if (ORE) { 1010 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 1011 RemarkGen.emitRemarks(); 1012 } 1013 1014 // Delete the instructions backwards, as it has a reduced likelihood of 1015 // having to update as many def-use and use-def chains. 1016 // 1017 // Because we add to ToRemove during fusion we can't guarantee that defs 1018 // are before uses. Change uses to poison temporarily as these should get 1019 // removed as well. 1020 // 1021 // For verification, we keep track of where we changed uses to poison in 1022 // PoisonedInsts and then check that we in fact remove them. 1023 SmallSet<Instruction *, 16> PoisonedInsts; 1024 for (auto *Inst : reverse(ToRemove)) { 1025 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1026 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 1027 PoisonedInsts.insert(Poisoned); 1028 U.set(PoisonValue::get(Inst->getType())); 1029 } 1030 Inst->eraseFromParent(); 1031 PoisonedInsts.erase(Inst); 1032 } 1033 if (!PoisonedInsts.empty()) { 1034 // If we didn't remove all poisoned instructions, it's a hard error. 1035 dbgs() << "Poisoned but present instructions:\n"; 1036 for (auto *I : PoisonedInsts) 1037 dbgs() << *I << "\n"; 1038 llvm_unreachable("Poisoned but instruction not removed"); 1039 } 1040 1041 return Changed; 1042 } 1043 1044 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
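  /// E.g. for a matrix of doubles, \p BasePtr is cast to a double* in the same
  /// address space, so the offsets used by computeVectorAddr can be expressed
  /// in elements rather than bytes.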
1045 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 1046 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 1047 Type *EltPtrType = PointerType::get(EltType, AS); 1048 return Builder.CreatePointerCast(BasePtr, EltPtrType); 1049 } 1050 1051 /// Replace intrinsic calls 1052 bool VisitCallInst(CallInst *Inst) { 1053 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 1054 return false; 1055 1056 switch (Inst->getCalledFunction()->getIntrinsicID()) { 1057 case Intrinsic::matrix_multiply: 1058 LowerMultiply(Inst); 1059 break; 1060 case Intrinsic::matrix_transpose: 1061 LowerTranspose(Inst); 1062 break; 1063 case Intrinsic::matrix_column_major_load: 1064 LowerColumnMajorLoad(Inst); 1065 break; 1066 case Intrinsic::matrix_column_major_store: 1067 LowerColumnMajorStore(Inst); 1068 break; 1069 default: 1070 return false; 1071 } 1072 return true; 1073 } 1074 1075 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 1076 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 1077 /// ConstantInt, reduce the initial alignment based on the byte offset. For 1078 /// non-ConstantInt strides, return the common alignment of the initial 1079 /// alignment and the element size in bytes. 1080 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 1081 MaybeAlign A) const { 1082 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 1083 if (Idx == 0) 1084 return InitialAlign; 1085 1086 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 1087 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 1088 uint64_t StrideInBytes = 1089 ConstStride->getZExtValue() * ElementSizeInBits / 8; 1090 return commonAlignment(InitialAlign, Idx * StrideInBytes); 1091 } 1092 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 1093 } 1094 1095 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 1096 /// vectors. 1097 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 1098 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 1099 auto *VType = cast<VectorType>(Ty); 1100 Type *EltTy = VType->getElementType(); 1101 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 1102 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 1103 MatrixTy Result; 1104 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 1105 Value *GEP = computeVectorAddr( 1106 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 1107 Stride, Shape.getStride(), EltTy, Builder); 1108 Value *Vector = Builder.CreateAlignedLoad( 1109 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 1110 IsVolatile, "col.load"); 1111 1112 Result.addVector(Vector); 1113 } 1114 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 1115 Result.getNumVectors()); 1116 } 1117 1118 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 1119 /// starting at \p MatrixPtr[I][J]. 
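  ///
  /// For example, with a column-major 4 x 4 matrix of floats, loading a 2 x 2
  /// tile at (\p I, \p J) = (1, 1) starts at element offset J * 4 + I = 5 and
  /// loads two 2-element columns that are 4 elements apart.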
1120 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 1121 ShapeInfo MatrixShape, Value *I, Value *J, 1122 ShapeInfo ResultShape, Type *EltTy, 1123 IRBuilder<> &Builder) { 1124 1125 Value *Offset = Builder.CreateAdd( 1126 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1127 1128 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1129 Value *EltPtr = 1130 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1131 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1132 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1133 ResultShape.NumColumns); 1134 Type *TilePtrTy = PointerType::get(TileTy, AS); 1135 Value *TilePtr = 1136 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1137 1138 return loadMatrix(TileTy, TilePtr, Align, 1139 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1140 ResultShape, Builder); 1141 } 1142 1143 /// Lower a load instruction with shape information. 1144 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1145 bool IsVolatile, ShapeInfo Shape) { 1146 IRBuilder<> Builder(Inst); 1147 finalizeLowering(Inst, 1148 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1149 Shape, Builder), 1150 Builder); 1151 } 1152 1153 /// Lowers llvm.matrix.column.major.load. 1154 /// 1155 /// The intrinsic loads a matrix from memory using a stride between columns. 1156 void LowerColumnMajorLoad(CallInst *Inst) { 1157 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1158 "Intrinsic only supports column-major layout!"); 1159 Value *Ptr = Inst->getArgOperand(0); 1160 Value *Stride = Inst->getArgOperand(1); 1161 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1162 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1163 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1164 } 1165 1166 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1167 /// MatrixPtr[I][J]. 1168 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1169 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1170 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1171 Value *Offset = Builder.CreateAdd( 1172 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1173 1174 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1175 Value *EltPtr = 1176 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1177 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1178 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1179 StoreVal.getNumColumns()); 1180 Type *TilePtrTy = PointerType::get(TileTy, AS); 1181 Value *TilePtr = 1182 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1183 1184 storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 1185 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1186 } 1187 1188 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1189 /// vectors. 
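  ///
  /// For example, a column-major 2 x 3 matrix stored with a stride of 2 writes
  /// its three columns at element offsets 0, 2 and 4 from \p Ptr.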
1190 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1191 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1192 IRBuilder<> &Builder) { 1193 auto VType = cast<VectorType>(Ty); 1194 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 1195 for (auto Vec : enumerate(StoreVal.vectors())) { 1196 Value *GEP = computeVectorAddr( 1197 EltPtr, 1198 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1199 Vec.index()), 1200 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1201 Builder.CreateAlignedStore(Vec.value(), GEP, 1202 getAlignForIndex(Vec.index(), Stride, 1203 VType->getElementType(), 1204 MAlign), 1205 IsVolatile); 1206 } 1207 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1208 StoreVal.getNumVectors()); 1209 } 1210 1211 /// Lower a store instruction with shape information. 1212 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 1213 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 1214 IRBuilder<> Builder(Inst); 1215 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1216 finalizeLowering(Inst, 1217 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 1218 IsVolatile, Builder), 1219 Builder); 1220 } 1221 1222 /// Lowers llvm.matrix.column.major.store. 1223 /// 1224 /// The intrinsic store a matrix back memory using a stride between columns. 1225 void LowerColumnMajorStore(CallInst *Inst) { 1226 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1227 "Intrinsic only supports column-major layout!"); 1228 Value *Matrix = Inst->getArgOperand(0); 1229 Value *Ptr = Inst->getArgOperand(1); 1230 Value *Stride = Inst->getArgOperand(2); 1231 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1232 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1233 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1234 } 1235 1236 // Set elements I..I+NumElts-1 to Block 1237 Value *insertVector(Value *Col, unsigned I, Value *Block, 1238 IRBuilder<> &Builder) { 1239 1240 // First, bring Block to the same size as Col 1241 unsigned BlockNumElts = 1242 cast<FixedVectorType>(Block->getType())->getNumElements(); 1243 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1244 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1245 1246 Block = Builder.CreateShuffleVector( 1247 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1248 1249 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1250 // 8, 4, 5, 6 1251 SmallVector<int, 16> Mask; 1252 unsigned i; 1253 for (i = 0; i < I; i++) 1254 Mask.push_back(i); 1255 1256 unsigned VecNumElts = 1257 cast<FixedVectorType>(Col->getType())->getNumElements(); 1258 for (; i < I + BlockNumElts; i++) 1259 Mask.push_back(i - I + VecNumElts); 1260 1261 for (; i < VecNumElts; i++) 1262 Mask.push_back(i); 1263 1264 return Builder.CreateShuffleVector(Col, Block, Mask); 1265 } 1266 1267 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1268 IRBuilder<> &Builder, bool AllowContraction, 1269 unsigned &NumComputeOps) { 1270 NumComputeOps += getNumOps(A->getType()); 1271 if (!Sum) 1272 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1273 1274 if (UseFPOp) { 1275 if (AllowContraction) { 1276 // Use fmuladd for floating point operations and let the backend decide 1277 // if that's profitable. 
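        // E.g. for <4 x float> operands this emits a call like
        // @llvm.fmuladd.v4f32(%a, %b, %sum), where %sum is the running
        // accumulator.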
1278 Function *FMulAdd = Intrinsic::getDeclaration( 1279 Func.getParent(), Intrinsic::fmuladd, A->getType()); 1280 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1281 } 1282 NumComputeOps += getNumOps(A->getType()); 1283 Value *Mul = Builder.CreateFMul(A, B); 1284 return Builder.CreateFAdd(Sum, Mul); 1285 } 1286 1287 NumComputeOps += getNumOps(A->getType()); 1288 Value *Mul = Builder.CreateMul(A, B); 1289 return Builder.CreateAdd(Sum, Mul); 1290 } 1291 1292 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1293 /// users with shape information, there's nothing to do: they will use the 1294 /// cached value when they are lowered. For other users, \p Matrix is 1295 /// flattened and the uses are updated to use it. Also marks \p Inst for 1296 /// deletion. 1297 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1298 IRBuilder<> &Builder) { 1299 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1300 (void)inserted; 1301 assert(inserted.second && "multiple matrix lowering mapping"); 1302 1303 ToRemove.push_back(Inst); 1304 Value *Flattened = nullptr; 1305 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1306 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1307 if (!Flattened) 1308 Flattened = Matrix.embedInVector(Builder); 1309 U.set(Flattened); 1310 } 1311 } 1312 } 1313 1314 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1315 /// addition. 1316 /// 1317 /// We can fold a transpose into the operand that is used to extract scalars. 1318 /// This is the first operands with row-major and the second with 1319 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1320 /// operand is transposed. 1321 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1322 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1323 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1324 const unsigned VF = std::max<unsigned>( 1325 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1326 .getFixedValue() / 1327 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1328 1U); 1329 unsigned R = Result.getNumRows(); 1330 unsigned C = Result.getNumColumns(); 1331 unsigned M = A.getNumColumns(); 1332 1333 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1334 assert(A.isColumnMajor() == B.isColumnMajor() && 1335 Result.isColumnMajor() == A.isColumnMajor() && 1336 "operands must agree on matrix layout"); 1337 unsigned NumComputeOps = 0; 1338 1339 Builder.setFastMathFlags(FMF); 1340 1341 if (A.isColumnMajor()) { 1342 // Multiply columns from the first operand with scalars from the second 1343 // operand. Then move along the K axes and accumulate the columns. With 1344 // this the adds can be vectorized without reassociation. 1345 for (unsigned J = 0; J < C; ++J) { 1346 unsigned BlockSize = VF; 1347 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1348 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1349 1350 for (unsigned I = 0; I < R; I += BlockSize) { 1351 // Gradually lower the vectorization factor to cover the remainder. 1352 while (I + BlockSize > R) 1353 BlockSize /= 2; 1354 1355 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1356 : nullptr; 1357 for (unsigned K = 0; K < M; ++K) { 1358 Value *L = A.extractVector(I, K, BlockSize, Builder); 1359 Value *RH = Builder.CreateExtractElement( 1360 B.getColumn(IsScalarMatrixTransposed ? K : J), 1361 IsScalarMatrixTransposed ? 
J : K); 1362 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1363 Sum = 1364 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1365 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1366 } 1367 Result.setVector(J, 1368 insertVector(Result.getVector(J), I, Sum, Builder)); 1369 } 1370 } 1371 } else { 1372 // Multiply rows from the second operand with scalars from the first 1373 // operand. Then move along the K axes and accumulate the rows. With this 1374 // the adds can be vectorized without reassociation. 1375 for (unsigned I = 0; I < R; ++I) { 1376 unsigned BlockSize = VF; 1377 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1378 for (unsigned J = 0; J < C; J += BlockSize) { 1379 // Gradually lower the vectorization factor to cover the remainder. 1380 while (J + BlockSize > C) 1381 BlockSize /= 2; 1382 1383 Value *Sum = nullptr; 1384 for (unsigned K = 0; K < M; ++K) { 1385 Value *R = B.extractVector(K, J, BlockSize, Builder); 1386 Value *LH = Builder.CreateExtractElement( 1387 A.getVector(IsScalarMatrixTransposed ? K : I), 1388 IsScalarMatrixTransposed ? I : K); 1389 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1390 Sum = 1391 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1392 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1393 } 1394 Result.setVector(I, 1395 insertVector(Result.getVector(I), J, Sum, Builder)); 1396 } 1397 } 1398 } 1399 Result.addNumComputeOps(NumComputeOps); 1400 } 1401 1402 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1403 /// copying it to a new location. This new or otherwise the original location 1404 /// is returned. 1405 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1406 CallInst *MatMul) { 1407 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1408 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1409 1410 // If we can statically determine noalias we're good. 1411 if (AA->isNoAlias(LoadLoc, StoreLoc)) 1412 return Load->getPointerOperand(); 1413 1414 // Create code to check if the memory locations of the Load and Store 1415 // overlap and if they do, copy Load's operand to a new buffer. 1416 1417 // First, create new blocks for 2n part of the check and the copy. 1418 BasicBlock *Check0 = MatMul->getParent(); 1419 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1420 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1421 // as we adjust Check0 and Check1's branches. 1422 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1423 for (BasicBlock *Succ : successors(Check0)) 1424 DTUpdates.push_back({DT->Delete, Check0, Succ}); 1425 1426 BasicBlock *Check1 = 1427 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1428 nullptr, "alias_cont"); 1429 BasicBlock *Copy = 1430 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1431 nullptr, "copy"); 1432 BasicBlock *Fusion = 1433 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1434 nullptr, "no_alias"); 1435 1436 // Check if the loaded memory location begins before the end of the store 1437 // location. If the condition holds, they might overlap, otherwise they are 1438 // guaranteed to not overlap. 
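    // The generated control flow is:
    //   current block -> alias_cont (may overlap) or no_alias
    //   alias_cont    -> copy (overlap) or no_alias
    //   copy          -> no_alias
    // with a phi in no_alias selecting either the original load pointer or the
    // pointer to the copied buffer.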
1439 IRBuilder<> Builder(MatMul); 1440 Check0->getTerminator()->eraseFromParent(); 1441 Builder.SetInsertPoint(Check0); 1442 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1443 Value *StoreBegin = Builder.CreatePtrToInt( 1444 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1445 Value *StoreEnd = Builder.CreateAdd( 1446 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1447 "store.end", true, true); 1448 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1449 IntPtrTy, "load.begin"); 1450 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1451 Fusion); 1452 1453 // Check if the store begins before the end of the load location. If the 1454 // condition holds, they alias, otherwise they are guaranteed to not 1455 // overlap. 1456 Check1->getTerminator()->eraseFromParent(); 1457 Builder.SetInsertPoint(Check1, Check1->begin()); 1458 Value *LoadEnd = Builder.CreateAdd( 1459 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1460 "load.end", true, true); 1461 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1462 Fusion); 1463 1464 // Copy load operand to new alloca. 1465 Builder.SetInsertPoint(Copy, Copy->begin()); 1466 auto *VT = cast<FixedVectorType>(Load->getType()); 1467 // Use an array type for the alloca, to avoid potentially huge alignment 1468 // requirements for large vector types. 1469 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1470 AllocaInst *Alloca = 1471 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1472 Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo()); 1473 1474 Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(), 1475 Load->getAlign(), LoadLoc.Size.getValue()); 1476 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1477 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1478 PHI->addIncoming(Load->getPointerOperand(), Check0); 1479 PHI->addIncoming(Load->getPointerOperand(), Check1); 1480 PHI->addIncoming(BC, Copy); 1481 1482 // Adjust DT. 1483 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1484 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1485 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1486 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1487 DT->applyUpdates(DTUpdates); 1488 return PHI; 1489 } 1490 1491 bool isFusionProfitable(CallInst *MatMul) { 1492 if (ForceFusion) 1493 return true; 1494 1495 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1496 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1497 1498 const unsigned R = LShape.NumRows; 1499 const unsigned C = RShape.NumColumns; 1500 const unsigned M = LShape.NumColumns; 1501 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1502 1503 const unsigned VF = std::max<unsigned>( 1504 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1505 .getFixedValue() / 1506 EltType->getPrimitiveSizeInBits().getFixedValue(), 1507 1U); 1508 1509 // Cost model for tiling 1510 // 1511 // For tiling to be beneficial, we need reuse either along the R or 1512 // the C axis. We vectorize along the R axis so that means at least 1513 // 3 elements. 1514 // TODO: Also consider cost of copying if operands alias. 1515 if (R <= VF && C == 1) 1516 return false; 1517 // Then we need enough elements to exceed the number of vector 1518 // registers we have. 
    // Note that this is an oversimplification since fusing also takes some
    // extra loads which may exceed the number of reloads necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }

  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }

  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure
    // there is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
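    // Worked example (illustrative numbers only): with the default TileSize
    // of 4 and a 64-column LHS, the request below becomes
    // min(10, 64 / 4) = 10, i.e. the K-loop gets llvm.loop.unroll.count = 10.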
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }

  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                      Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
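  ///
  /// For reference, a call reaching this function looks roughly like the
  /// following (illustrative only; names and types are made up):
  ///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
  ///            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)
  /// i.e. a 2x3 times 3x2 multiply producing a 2x2 result, with the trailing
  /// i32 arguments carrying rows(A), columns(A)/rows(B) and columns(B),
  /// matching how LShape and RShape are read below.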
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
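  ///
  /// Illustrative sketch (not part of the original comment): for a shaped
  /// 2x2 fadd in the default column-major layout, both operands are split
  /// into two <2 x float> column vectors and the result is assembled from
  /// two vector fadds, one per column.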
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "."
             << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values
    /// we print the constant, or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
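    ///
    /// Illustrative example (not from the original comment): if a load feeds
    /// two different leaves, its op counts land in the shared bucket (the
    /// second element of the returned pair), while ops reachable from only a
    /// single leaf are counted as exclusive for that leaf.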
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
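        // The emitted remark reads roughly like the following (illustrative
        // numbers only):
        //   Lowered with 2 stores, 4 loads, 16 compute ops, 0 exposed
        //   transposes
        // followed by the linearized expression built below.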
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << "<";
  if (Minimal)
    OS << "minimal";
  OS << ">";
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA =
        getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE like tiling are disabled. This
/// is used to lower matrix intrinsics if the main lowering pass is not run, for
/// example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
                    false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}
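
// Illustrative usage note (assumes the standard opt driver): with the new
// pass manager the lowering is typically requested by pass name, e.g.
//   opt -passes=lower-matrix-intrinsics input.ll -S
// while the legacy wrappers above are constructed via
// createLowerMatrixIntrinsicsPass() / createLowerMatrixIntrinsicsMinimalPass().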