1 //===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Lower matrix intrinsics to vector operations. 10 // 11 // TODO: 12 // * Improve fusion: 13 // * Support more cases, e.g. multiply-add, multiply-sub, operands/results 14 // transposed. 15 // * Improve cost-modeling, e.g. choose different number of rows/columns 16 // columns for tiles, consider cost of copies on alias. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" 21 #include "llvm/ADT/PostOrderIterator.h" 22 #include "llvm/ADT/SmallVector.h" 23 #include "llvm/Analysis/AliasAnalysis.h" 24 #include "llvm/Analysis/DomTreeUpdater.h" 25 #include "llvm/Analysis/LoopInfo.h" 26 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 27 #include "llvm/Analysis/TargetTransformInfo.h" 28 #include "llvm/Analysis/ValueTracking.h" 29 #include "llvm/Analysis/VectorUtils.h" 30 #include "llvm/IR/CFG.h" 31 #include "llvm/IR/DataLayout.h" 32 #include "llvm/IR/DebugInfoMetadata.h" 33 #include "llvm/IR/Function.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/Instructions.h" 36 #include "llvm/IR/IntrinsicInst.h" 37 #include "llvm/IR/MatrixBuilder.h" 38 #include "llvm/IR/PatternMatch.h" 39 #include "llvm/InitializePasses.h" 40 #include "llvm/Pass.h" 41 #include "llvm/Support/Alignment.h" 42 #include "llvm/Support/CommandLine.h" 43 #include "llvm/Support/Debug.h" 44 #include "llvm/Transforms/Scalar.h" 45 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 46 #include "llvm/Transforms/Utils/LoopUtils.h" 47 #include "llvm/Transforms/Utils/MatrixUtils.h" 48 49 using namespace llvm; 50 using namespace PatternMatch; 51 52 #define DEBUG_TYPE "lower-matrix-intrinsics" 53 54 static cl::opt<bool> 55 FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, 56 cl::desc("Enable/disable fusing matrix instructions.")); 57 // TODO: Allow and use non-square tiles. 58 static cl::opt<unsigned> TileSize( 59 "fuse-matrix-tile-size", cl::init(4), cl::Hidden, 60 cl::desc( 61 "Tile size for matrix instruction fusion using square-shaped tiles.")); 62 static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), 63 cl::Hidden, 64 cl::desc("Generate loop nest for tiling.")); 65 static cl::opt<bool> ForceFusion( 66 "force-fuse-matrix", cl::init(false), cl::Hidden, 67 cl::desc("Force matrix instruction fusion even if not profitable.")); 68 static cl::opt<bool> AllowContractEnabled( 69 "matrix-allow-contract", cl::init(false), cl::Hidden, 70 cl::desc("Allow the use of FMAs if available and profitable. This may " 71 "result in different results, due to less rounding error.")); 72 73 enum class MatrixLayoutTy { ColumnMajor, RowMajor }; 74 75 static cl::opt<MatrixLayoutTy> MatrixLayout( 76 "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), 77 cl::desc("Sets the default matrix layout"), 78 cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", 79 "Use column-major layout"), 80 clEnumValN(MatrixLayoutTy::RowMajor, "row-major", 81 "Use row-major layout"))); 82 83 /// Helper function to either return Scope, if it is a subprogram or the 84 /// attached subprogram for a local scope. 
85 static DISubprogram *getSubprogram(DIScope *Scope) {
86   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
87     return Subprogram;
88   return cast<DILocalScope>(Scope)->getSubprogram();
89 }
90
91 namespace {
92
93 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
94 // the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
95 // assuming \p Stride elements between the starts of two consecutive vectors.
96 // \p Stride must be >= \p NumElements.
97 // For column-major matrixes, the function computes the address of a column
98 // vector and \p NumElements must be set to the number of elements in a column
99 // (= number of rows of the matrix). For row-major matrixes, the function
100 // computes the address of a row vector and \p NumElements must be set to the
101 // number of elements in a row (= number of columns of the matrix).
102 //
103 // Consider a 4x4 matrix in column-major layout like below
104 //
105 //      0       1      2      3
106 // 0   v_0_0  v_0_1  v_0_2  v_0_3
107 // 1   v_1_0  v_1_1  v_1_2  v_1_3
108 // 2   v_2_0  v_2_1  v_2_2  v_2_3
109 // 3   v_3_0  v_3_1  v_3_2  v_3_3
110
111 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
112 // we need a pointer to the first element of the submatrix as base pointer.
113 // Then we can use computeVectorAddr to compute the addresses for the columns
114 // of the sub-matrix.
115 //
116 // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
117 //           -> just returns Base
118 // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
119 //           -> returns Base + (1 * 4)
120 // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
121 //           -> returns Base + (2 * 4)
122 //
123 // The graphic below illustrates the number of elements in a column (marked
124 // with |) and the number of skipped elements (marked with {).
125 //
126 //         v_0_0  v_0_1 {v_0_2 {v_0_3
127 //                Base   Col 1  Col 2
128 //                  |     |      |
129 //         v_1_0 |v_1_1 |v_1_2 |v_1_3
130 //         v_2_0 |v_2_1 |v_2_2 |v_2_3
131 //         v_3_0 {v_3_1 {v_3_2  v_3_3
132 //
133 Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
134                          unsigned NumElements, Type *EltType,
135                          IRBuilder<> &Builder) {
136
137   assert((!isa<ConstantInt>(Stride) ||
138           cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
139          "Stride must be >= the number of elements in the result vector.");
140   unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
141
142   // Compute the start of the vector with index VecIdx as VecIdx * Stride.
143   Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
144
145   // Get pointer to the start of the selected vector. Skip GEP creation,
146   // if we select vector 0.
147   if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
148     VecStart = BasePtr;
149   else
150     VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
151
152   // Cast elementwise vector start pointer to a pointer to a vector
153   // (EltType x NumElements)*.
154   auto *VecType = FixedVectorType::get(EltType, NumElements);
155   Type *VecPtrType = PointerType::get(VecType, AS);
156   return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
157 }
158
159 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
160 ///
161 /// Currently, the lowering for each matrix intrinsic is done as follows:
162 /// 1. Propagate the shape information from intrinsics to connected
163 ///    instructions.
164 /// 2. 
Lower instructions with shape information (assuming column-major layout). 165 /// The lowering works similarly using row-major layout. 166 /// 2.1. Get column vectors for each argument. If we already lowered the 167 /// definition of an argument, use the produced column vectors directly. 168 /// If not, split the operand vector containing an embedded matrix into 169 /// a set of column vectors, 170 /// 2.2. Lower the instruction in terms of column major operations, which 171 /// yields a set of column vectors containing result matrix. Note that we 172 /// lower all instructions that have shape information. Besides the 173 /// intrinsics, this includes stores for example. 174 /// 2.3. Update uses of the lowered instruction. If we have shape information 175 /// for a user, there is nothing to do, as we will look up the result 176 /// column matrix when lowering the user. For other uses, we embed the 177 /// result matrix in a flat vector and update the use. 178 /// 2.4. Cache the result column matrix for the instruction we lowered 179 /// 3. After we lowered all instructions in a function, remove the now 180 /// obsolete instructions. 181 /// 182 class LowerMatrixIntrinsics { 183 Function &Func; 184 const DataLayout &DL; 185 const TargetTransformInfo &TTI; 186 AliasAnalysis *AA; 187 DominatorTree *DT; 188 LoopInfo *LI; 189 OptimizationRemarkEmitter *ORE; 190 191 /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. 192 struct OpInfoTy { 193 /// Number of stores emitted to generate this matrix. 194 unsigned NumStores = 0; 195 /// Number of loads emitted to generate this matrix. 196 unsigned NumLoads = 0; 197 /// Number of compute operations emitted to generate this matrix. 198 unsigned NumComputeOps = 0; 199 /// Most of the time transposes can be fused with matrix multiplies or can 200 /// be folded away via algebraic simplifications. This is the number of 201 /// transposes that we failed to make "free" via such optimizations. 202 unsigned NumExposedTransposes = 0; 203 204 OpInfoTy &operator+=(const OpInfoTy &RHS) { 205 NumStores += RHS.NumStores; 206 NumLoads += RHS.NumLoads; 207 NumComputeOps += RHS.NumComputeOps; 208 NumExposedTransposes += RHS.NumExposedTransposes; 209 return *this; 210 } 211 }; 212 213 /// Wrapper class representing a matrix as a set of vectors, either in row or 214 /// column major layout. All vectors must have the same vector type. 215 class MatrixTy { 216 SmallVector<Value *, 16> Vectors; 217 218 OpInfoTy OpInfo; 219 220 bool IsColumnMajor = true; 221 222 public: 223 MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 224 MatrixTy(ArrayRef<Value *> Vectors) 225 : Vectors(Vectors.begin(), Vectors.end()), 226 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 227 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) 228 : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { 229 230 unsigned D = isColumnMajor() ? NumColumns : NumRows; 231 for (unsigned J = 0; J < D; ++J) 232 addVector(UndefValue::get(FixedVectorType::get( 233 EltTy, isColumnMajor() ? 
NumRows : NumColumns))); 234 } 235 236 Value *getVector(unsigned i) const { return Vectors[i]; } 237 Value *getColumn(unsigned i) const { 238 assert(isColumnMajor() && "only supported for column-major matrixes"); 239 return Vectors[i]; 240 } 241 Value *getRow(unsigned i) const { 242 assert(!isColumnMajor() && "only supported for row-major matrixes"); 243 return Vectors[i]; 244 } 245 246 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 247 248 Type *getElementType() const { return getVectorTy()->getElementType(); } 249 250 unsigned getNumVectors() const { 251 if (isColumnMajor()) 252 return getNumColumns(); 253 return getNumRows(); 254 } 255 256 unsigned getNumColumns() const { 257 if (isColumnMajor()) 258 return Vectors.size(); 259 else { 260 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 261 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 262 } 263 } 264 unsigned getNumRows() const { 265 if (isColumnMajor()) { 266 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 267 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 268 } else 269 return Vectors.size(); 270 } 271 272 void addVector(Value *V) { Vectors.push_back(V); } 273 VectorType *getColumnTy() { 274 assert(isColumnMajor() && "only supported for column-major matrixes"); 275 return getVectorTy(); 276 } 277 278 VectorType *getVectorTy() const { 279 return cast<VectorType>(Vectors[0]->getType()); 280 } 281 282 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 283 assert(isColumnMajor() && 284 "columns() only supported for column-major matrixes"); 285 return make_range(Vectors.begin(), Vectors.end()); 286 } 287 288 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 289 return make_range(Vectors.begin(), Vectors.end()); 290 } 291 292 /// Embed the vectors of the matrix into a flat vector by concatenating 293 /// them. 294 Value *embedInVector(IRBuilder<> &Builder) const { 295 return Vectors.size() == 1 ? Vectors[0] 296 : concatenateVectors(Builder, Vectors); 297 } 298 299 MatrixTy &addNumLoads(unsigned N) { 300 OpInfo.NumLoads += N; 301 return *this; 302 } 303 304 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 305 306 MatrixTy &addNumStores(unsigned N) { 307 OpInfo.NumStores += N; 308 return *this; 309 } 310 311 MatrixTy &addNumExposedTransposes(unsigned N) { 312 OpInfo.NumExposedTransposes += N; 313 return *this; 314 } 315 316 MatrixTy &addNumComputeOps(unsigned N) { 317 OpInfo.NumComputeOps += N; 318 return *this; 319 } 320 321 unsigned getNumStores() const { return OpInfo.NumStores; } 322 unsigned getNumLoads() const { return OpInfo.NumLoads; } 323 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 324 325 const OpInfoTy &getOpInfo() const { return OpInfo; } 326 327 bool isColumnMajor() const { return IsColumnMajor; } 328 329 unsigned getStride() const { 330 if (isColumnMajor()) 331 return getNumRows(); 332 return getNumColumns(); 333 } 334 335 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 336 /// matrix is column-major, the result vector is extracted from a column 337 /// vector, otherwise from a row vector. 338 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 339 IRBuilder<> &Builder) const { 340 Value *Vec = isColumnMajor() ? 
getColumn(J) : getRow(I); 341 assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >= 342 NumElts && 343 "Extracted vector will contain poison values"); 344 return Builder.CreateShuffleVector( 345 Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), 346 "block"); 347 } 348 }; 349 350 struct ShapeInfo { 351 unsigned NumRows; 352 unsigned NumColumns; 353 354 bool IsColumnMajor; 355 356 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 357 : NumRows(NumRows), NumColumns(NumColumns), 358 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 359 360 ShapeInfo(Value *NumRows, Value *NumColumns) 361 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 362 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 363 364 bool operator==(const ShapeInfo &other) { 365 return NumRows == other.NumRows && NumColumns == other.NumColumns; 366 } 367 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 368 369 /// Returns true if shape-information is defined, meaning both dimensions 370 /// are != 0. 371 operator bool() const { 372 assert(NumRows == 0 || NumColumns != 0); 373 return NumRows != 0; 374 } 375 376 unsigned getStride() const { 377 if (IsColumnMajor) 378 return NumRows; 379 return NumColumns; 380 } 381 382 unsigned getNumVectors() const { 383 if (IsColumnMajor) 384 return NumColumns; 385 return NumRows; 386 } 387 }; 388 389 /// Maps instructions to their shape information. The shape information 390 /// describes the shape to be used while lowering. This matches the shape of 391 /// the result value of the instruction, with the only exceptions being store 392 /// instructions and the matrix_column_major_store intrinsics. For those, the 393 /// shape information indicates that those instructions should be lowered 394 /// using shape information as well. A ValueMap is used so that when 395 /// sub-passes like optimizeTransposes performs RAUW the map stays 396 /// up-to-date. 397 ValueMap<Value *, ShapeInfo> ShapeMap; 398 399 /// List of instructions to remove. While lowering, we are not replacing all 400 /// users of a lowered instruction, if shape information is available and 401 /// those need to be removed after we finished lowering. 402 SmallVector<Instruction *, 16> ToRemove; 403 404 /// Map from instructions to their produced column matrix. 405 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 406 407 private: 408 static FastMathFlags getFastMathFlags(Instruction *Inst) { 409 FastMathFlags FMF; 410 411 if (isa<FPMathOperator>(*Inst)) 412 FMF = Inst->getFastMathFlags(); 413 414 FMF.setAllowContract(AllowContractEnabled || FMF.allowContract()); 415 416 return FMF; 417 } 418 419 public: 420 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 421 AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, 422 OptimizationRemarkEmitter *ORE) 423 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 424 LI(LI), ORE(ORE) {} 425 426 unsigned getNumOps(Type *VT) { 427 assert(isa<VectorType>(VT) && "Expected vector type"); 428 return getNumOps(VT->getScalarType(), 429 cast<FixedVectorType>(VT)->getNumElements()); 430 } 431 432 /// Is this the minimal version executed in the backend pipelines. 433 bool isMinimal() const { 434 return !DT; 435 } 436 437 /// Return the estimated number of vector ops required for an operation on 438 /// \p VT * N. 
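  /// For example (illustrative numbers; the actual register width is whatever
  /// TTI.getRegisterBitWidth reports), with 128-bit vector registers an
  /// operation on 8 floats (256 bits) counts as ceil(256 / 128) = 2 vector
  /// ops.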
439 unsigned getNumOps(Type *ST, unsigned N) { 440 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 441 double(TTI.getRegisterBitWidth( 442 TargetTransformInfo::RGK_FixedWidthVector) 443 .getFixedSize())); 444 } 445 446 /// Return the set of vectors that a matrix value is lowered to. 447 /// 448 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 449 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 450 /// into vectors. 451 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 452 IRBuilder<> &Builder) { 453 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 454 assert(VType && "MatrixVal must be a vector type"); 455 assert(cast<FixedVectorType>(VType)->getNumElements() == 456 SI.NumRows * SI.NumColumns && 457 "The vector size must match the number of matrix elements"); 458 459 // Check if we lowered MatrixVal using shape information. In that case, 460 // return the existing matrix, if it matches the requested shape 461 // information. If there is a mis-match, embed the result in a flat 462 // vector and split it later. 463 auto Found = Inst2ColumnMatrix.find(MatrixVal); 464 if (Found != Inst2ColumnMatrix.end()) { 465 MatrixTy &M = Found->second; 466 // Return the found matrix, if its shape matches the requested shape 467 // information 468 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 469 return M; 470 471 MatrixVal = M.embedInVector(Builder); 472 } 473 474 // Otherwise split MatrixVal. 475 SmallVector<Value *, 16> SplitVecs; 476 for (unsigned MaskStart = 0; 477 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 478 MaskStart += SI.getStride()) { 479 Value *V = Builder.CreateShuffleVector( 480 MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), 481 "split"); 482 SplitVecs.push_back(V); 483 } 484 485 return {SplitVecs}; 486 } 487 488 /// If \p V already has a known shape return false. Otherwise set the shape 489 /// for instructions that support it. 490 bool setShapeInfo(Value *V, ShapeInfo Shape) { 491 assert(Shape && "Shape not set"); 492 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 493 return false; 494 495 auto SIter = ShapeMap.find(V); 496 if (SIter != ShapeMap.end()) { 497 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 498 << SIter->second.NumRows << " " 499 << SIter->second.NumColumns << " for " << *V << "\n"); 500 return false; 501 } 502 503 ShapeMap.insert({V, Shape}); 504 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 505 << " for " << *V << "\n"); 506 return true; 507 } 508 509 bool isUniformShape(Value *V) { 510 Instruction *I = dyn_cast<Instruction>(V); 511 if (!I) 512 return true; 513 514 switch (I->getOpcode()) { 515 case Instruction::FAdd: 516 case Instruction::FSub: 517 case Instruction::FMul: // Scalar multiply. 518 case Instruction::FNeg: 519 case Instruction::Add: 520 case Instruction::Mul: 521 case Instruction::Sub: 522 return true; 523 default: 524 return false; 525 } 526 } 527 528 /// Returns true if shape information can be used for \p V. The supported 529 /// instructions must match the instructions that can be lowered by this pass. 
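  /// At the moment this covers the matrix intrinsics themselves, plain loads
  /// and stores of embedded matrixes, and the element-wise operations accepted
  /// by isUniformShape; for any other user the result matrix has to be
  /// flattened back into a plain vector (see finalizeLowering).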
530 bool supportsShapeInfo(Value *V) { 531 Instruction *Inst = dyn_cast<Instruction>(V); 532 if (!Inst) 533 return false; 534 535 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 536 if (II) 537 switch (II->getIntrinsicID()) { 538 case Intrinsic::matrix_multiply: 539 case Intrinsic::matrix_transpose: 540 case Intrinsic::matrix_column_major_load: 541 case Intrinsic::matrix_column_major_store: 542 return true; 543 default: 544 return false; 545 } 546 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 547 } 548 549 /// Propagate the shape information of instructions to their users. 550 /// The work list contains instructions for which we can compute the shape, 551 /// either based on the information provided by matrix intrinsics or known 552 /// shapes of operands. 553 SmallVector<Instruction *, 32> 554 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 555 SmallVector<Instruction *, 32> NewWorkList; 556 // Pop an element for which we guaranteed to have at least one of the 557 // operand shapes. Add the shape for this and then add users to the work 558 // list. 559 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 560 while (!WorkList.empty()) { 561 Instruction *Inst = WorkList.pop_back_val(); 562 563 // New entry, set the value and insert operands 564 bool Propagate = false; 565 566 Value *MatrixA; 567 Value *MatrixB; 568 Value *M; 569 Value *N; 570 Value *K; 571 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 572 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 573 m_Value(N), m_Value(K)))) { 574 Propagate = setShapeInfo(Inst, {M, K}); 575 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 576 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 577 // Flip dimensions. 578 Propagate = setShapeInfo(Inst, {N, M}); 579 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 580 m_Value(MatrixA), m_Value(), m_Value(), 581 m_Value(), m_Value(M), m_Value(N)))) { 582 Propagate = setShapeInfo(Inst, {N, M}); 583 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 584 m_Value(), m_Value(), m_Value(), m_Value(M), 585 m_Value(N)))) { 586 Propagate = setShapeInfo(Inst, {M, N}); 587 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 588 auto OpShape = ShapeMap.find(MatrixA); 589 if (OpShape != ShapeMap.end()) 590 setShapeInfo(Inst, OpShape->second); 591 continue; 592 } else if (isUniformShape(Inst)) { 593 // Find the first operand that has a known shape and use that. 594 for (auto &Op : Inst->operands()) { 595 auto OpShape = ShapeMap.find(Op.get()); 596 if (OpShape != ShapeMap.end()) { 597 Propagate |= setShapeInfo(Inst, OpShape->second); 598 break; 599 } 600 } 601 } 602 603 if (Propagate) { 604 NewWorkList.push_back(Inst); 605 for (auto *User : Inst->users()) 606 if (ShapeMap.count(User) == 0) 607 WorkList.push_back(cast<Instruction>(User)); 608 } 609 } 610 611 return NewWorkList; 612 } 613 614 /// Propagate the shape to operands of instructions with shape information. 615 /// \p Worklist contains the instruction for which we already know the shape. 616 SmallVector<Instruction *, 32> 617 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 618 SmallVector<Instruction *, 32> NewWorkList; 619 620 auto pushInstruction = [](Value *V, 621 SmallVectorImpl<Instruction *> &WorkList) { 622 Instruction *I = dyn_cast<Instruction>(V); 623 if (I) 624 WorkList.push_back(I); 625 }; 626 // Pop an element with known shape. 
Traverse the operands, if their shape 627 // derives from the result shape and is unknown, add it and add them to the 628 // worklist. 629 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 630 while (!WorkList.empty()) { 631 Value *V = WorkList.pop_back_val(); 632 633 size_t BeforeProcessingV = WorkList.size(); 634 if (!isa<Instruction>(V)) 635 continue; 636 637 Value *MatrixA; 638 Value *MatrixB; 639 Value *M; 640 Value *N; 641 Value *K; 642 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 643 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 644 m_Value(N), m_Value(K)))) { 645 if (setShapeInfo(MatrixA, {M, N})) 646 pushInstruction(MatrixA, WorkList); 647 648 if (setShapeInfo(MatrixB, {N, K})) 649 pushInstruction(MatrixB, WorkList); 650 651 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 652 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 653 // Flip dimensions. 654 if (setShapeInfo(MatrixA, {M, N})) 655 pushInstruction(MatrixA, WorkList); 656 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 657 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 658 m_Value(M), m_Value(N)))) { 659 if (setShapeInfo(MatrixA, {M, N})) { 660 pushInstruction(MatrixA, WorkList); 661 } 662 } else if (isa<LoadInst>(V) || 663 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 664 // Nothing to do, no matrix input. 665 } else if (isa<StoreInst>(V)) { 666 // Nothing to do. We forward-propagated to this so we would just 667 // backward propagate to an instruction with an already known shape. 668 } else if (isUniformShape(V)) { 669 // Propagate to all operands. 670 ShapeInfo Shape = ShapeMap[V]; 671 for (Use &U : cast<Instruction>(V)->operands()) { 672 if (setShapeInfo(U.get(), Shape)) 673 pushInstruction(U.get(), WorkList); 674 } 675 } 676 // After we discovered new shape info for new instructions in the 677 // worklist, we use their users as seeds for the next round of forward 678 // propagation. 679 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 680 for (User *U : WorkList[I]->users()) 681 if (isa<Instruction>(U) && V != U) 682 NewWorkList.push_back(cast<Instruction>(U)); 683 } 684 return NewWorkList; 685 } 686 687 /// Try moving transposes in order to fold them away or into multiplies. 688 void optimizeTransposes() { 689 auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { 690 // We need to remove Old from the ShapeMap otherwise RAUW will replace it 691 // with New. We should only add New it it supportsShapeInfo so we insert 692 // it conditionally instead. 693 auto S = ShapeMap.find(&Old); 694 if (S != ShapeMap.end()) { 695 ShapeMap.erase(S); 696 if (supportsShapeInfo(New)) 697 ShapeMap.insert({New, S->second}); 698 } 699 Old.replaceAllUsesWith(New); 700 }; 701 702 // First sink all transposes inside matmuls, hoping that we end up with NN, 703 // NT or TN variants. 704 for (BasicBlock &BB : reverse(Func)) { 705 for (auto II = BB.rbegin(); II != BB.rend();) { 706 Instruction &I = *II; 707 // We may remove II. By default continue on the next/prev instruction. 708 ++II; 709 // If we were to erase II, move again. 710 auto EraseFromParent = [&II, &BB](Value *V) { 711 auto *Inst = cast<Instruction>(V); 712 if (Inst->use_empty()) { 713 if (II != BB.rend() && Inst == &*II) { 714 ++II; 715 } 716 Inst->eraseFromParent(); 717 } 718 }; 719 720 // If we're creating a new instruction, continue from there. 
721 Instruction *NewInst = nullptr; 722 723 IRBuilder<> IB(&I); 724 MatrixBuilder Builder(IB); 725 726 Value *TA, *TAMA, *TAMB; 727 ConstantInt *R, *K, *C; 728 if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { 729 730 // Transpose of a transpose is a nop 731 Value *TATA; 732 if (match(TA, 733 m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { 734 ReplaceAllUsesWith(I, TATA); 735 EraseFromParent(&I); 736 EraseFromParent(TA); 737 } 738 739 // (A * B)^t -> B^t * A^t 740 // RxK KxC CxK KxR 741 else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( 742 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), 743 m_ConstantInt(K), m_ConstantInt(C)))) { 744 Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), 745 C->getZExtValue(), 746 TAMB->getName() + "_t"); 747 // We are being run after shape prop, add shape for newly created 748 // instructions so that we lower them later. 749 setShapeInfo(T0, {C, K}); 750 Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), 751 K->getZExtValue(), 752 TAMA->getName() + "_t"); 753 setShapeInfo(T1, {K, R}); 754 NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), 755 K->getZExtValue(), 756 R->getZExtValue(), "mmul"); 757 ReplaceAllUsesWith(I, NewInst); 758 EraseFromParent(&I); 759 EraseFromParent(TA); 760 } 761 } 762 763 // If we replaced I with a new instruction, continue from there. 764 if (NewInst) 765 II = std::next(BasicBlock::reverse_iterator(NewInst)); 766 } 767 } 768 769 // If we have a TT matmul, lift the transpose. We may be able to fold into 770 // consuming multiply. 771 for (BasicBlock &BB : Func) { 772 for (Instruction &I : llvm::make_early_inc_range(BB)) { 773 Value *A, *B, *AT, *BT; 774 ConstantInt *R, *K, *C; 775 // A^t * B ^t -> (B * A)^t 776 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( 777 m_Value(A), m_Value(B), m_ConstantInt(R), 778 m_ConstantInt(K), m_ConstantInt(C))) && 779 match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && 780 match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { 781 IRBuilder<> IB(&I); 782 MatrixBuilder Builder(IB); 783 Value *M = Builder.CreateMatrixMultiply( 784 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); 785 setShapeInfo(M, {C, R}); 786 Instruction *NewInst = Builder.CreateMatrixTranspose( 787 M, C->getZExtValue(), R->getZExtValue()); 788 ReplaceAllUsesWith(I, NewInst); 789 if (I.use_empty()) 790 I.eraseFromParent(); 791 if (A->use_empty()) 792 cast<Instruction>(A)->eraseFromParent(); 793 if (A != B && B->use_empty()) 794 cast<Instruction>(B)->eraseFromParent(); 795 } 796 } 797 } 798 } 799 800 bool Visit() { 801 SmallVector<Instruction *, 32> WorkList; 802 803 // Initially only the shape of matrix intrinsics is known. 804 // Initialize the work list with ops carrying shape information. 805 for (BasicBlock &BB : Func) 806 for (Instruction &Inst : BB) { 807 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 808 if (!II) 809 continue; 810 811 switch (II->getIntrinsicID()) { 812 case Intrinsic::matrix_multiply: 813 case Intrinsic::matrix_transpose: 814 case Intrinsic::matrix_column_major_load: 815 case Intrinsic::matrix_column_major_store: 816 WorkList.push_back(&Inst); 817 break; 818 default: 819 break; 820 } 821 } 822 823 // Avoid unnecessary work if there are no matrix intrinsics in the function. 824 if (WorkList.empty()) 825 return false; 826 827 // Propagate shapes until nothing changes any longer. 
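    // Illustrative example (hypothetical IR names): given
    //   %t = call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %a,
    //                                                       i32 2, i32 2)
    //   %s = fadd <4 x double> %t, %t
    // the forward pass records a 2 x 2 shape for %t and then for its user %s,
    // and the backward pass records a 2 x 2 shape for the operand %a as well.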
828 while (!WorkList.empty()) { 829 WorkList = propagateShapeForward(WorkList); 830 WorkList = propagateShapeBackward(WorkList); 831 } 832 833 if (!isMinimal()) { 834 optimizeTransposes(); 835 LLVM_DEBUG({ 836 dbgs() << "Dump after matrix transpose optimization:\n"; 837 Func.dump(); 838 }); 839 } 840 841 bool Changed = false; 842 SmallVector<CallInst *, 16> MaybeFusableInsts; 843 SmallVector<Instruction *, 16> MatrixInsts; 844 845 // First, collect all instructions with shape information and candidates for 846 // fusion (currently only matrix multiplies). 847 ReversePostOrderTraversal<Function *> RPOT(&Func); 848 for (auto *BB : RPOT) 849 for (Instruction &I : *BB) { 850 if (ShapeMap.find(&I) == ShapeMap.end()) 851 continue; 852 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 853 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 854 MatrixInsts.push_back(&I); 855 } 856 857 // Second, try to fuse candidates. 858 SmallPtrSet<Instruction *, 16> FusedInsts; 859 for (CallInst *CI : MaybeFusableInsts) 860 LowerMatrixMultiplyFused(CI, FusedInsts); 861 Changed = !FusedInsts.empty(); 862 863 // Third, lower remaining instructions with shape information. 864 for (Instruction *Inst : MatrixInsts) { 865 if (FusedInsts.count(Inst)) 866 continue; 867 868 IRBuilder<> Builder(Inst); 869 870 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 871 Changed |= VisitCallInst(CInst); 872 873 Value *Op1; 874 Value *Op2; 875 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 876 Changed |= VisitBinaryOperator(BinOp); 877 if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) 878 Changed |= VisitUnaryOperator(UnOp); 879 if (match(Inst, m_Load(m_Value(Op1)))) 880 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 881 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 882 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 883 } 884 885 if (ORE) { 886 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); 887 RemarkGen.emitRemarks(); 888 } 889 890 // Delete the instructions backwards, as it has a reduced likelihood of 891 // having to update as many def-use and use-def chains. 892 // 893 // Because we add to ToRemove during fusion we can't guarantee that defs 894 // are before uses. Change uses to poison temporarily as these should get 895 // removed as well. 896 // 897 // For verification, we keep track of where we changed uses to poison in 898 // PoisonedInsts and then check that we in fact remove them. 899 SmallSet<Instruction *, 16> PoisonedInsts; 900 for (auto *Inst : reverse(ToRemove)) { 901 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 902 if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) 903 PoisonedInsts.insert(Poisoned); 904 U.set(PoisonValue::get(Inst->getType())); 905 } 906 Inst->eraseFromParent(); 907 PoisonedInsts.erase(Inst); 908 } 909 if (!PoisonedInsts.empty()) { 910 // If we didn't remove all poisoned instructions, it's a hard error. 911 dbgs() << "Poisoned but present instructions:\n"; 912 for (auto *I : PoisonedInsts) 913 dbgs() << *I << "\n"; 914 llvm_unreachable("Poisoned but instruction not removed"); 915 } 916 917 return Changed; 918 } 919 920 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
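  /// For example, a pointer to a <16 x double> matrix value becomes a plain
  /// double pointer in the same address space.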
921 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 922 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 923 Type *EltPtrType = PointerType::get(EltType, AS); 924 return Builder.CreatePointerCast(BasePtr, EltPtrType); 925 } 926 927 /// Replace intrinsic calls 928 bool VisitCallInst(CallInst *Inst) { 929 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 930 return false; 931 932 switch (Inst->getCalledFunction()->getIntrinsicID()) { 933 case Intrinsic::matrix_multiply: 934 LowerMultiply(Inst); 935 break; 936 case Intrinsic::matrix_transpose: 937 LowerTranspose(Inst); 938 break; 939 case Intrinsic::matrix_column_major_load: 940 LowerColumnMajorLoad(Inst); 941 break; 942 case Intrinsic::matrix_column_major_store: 943 LowerColumnMajorStore(Inst); 944 break; 945 default: 946 return false; 947 } 948 return true; 949 } 950 951 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 952 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 953 /// ConstantInt, reduce the initial alignment based on the byte offset. For 954 /// non-ConstantInt strides, return the common alignment of the initial 955 /// alignment and the element size in bytes. 956 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 957 MaybeAlign A) const { 958 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 959 if (Idx == 0) 960 return InitialAlign; 961 962 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 963 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 964 uint64_t StrideInBytes = 965 ConstStride->getZExtValue() * ElementSizeInBits / 8; 966 return commonAlignment(InitialAlign, Idx * StrideInBytes); 967 } 968 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 969 } 970 971 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 972 /// vectors. 973 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 974 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 975 auto *VType = cast<VectorType>(Ty); 976 Type *EltTy = VType->getElementType(); 977 Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); 978 Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); 979 MatrixTy Result; 980 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 981 Value *GEP = computeVectorAddr( 982 EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), 983 Stride, Shape.getStride(), EltTy, Builder); 984 Value *Vector = Builder.CreateAlignedLoad( 985 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), 986 IsVolatile, "col.load"); 987 988 Result.addVector(Vector); 989 } 990 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 991 Result.getNumVectors()); 992 } 993 994 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 995 /// starting at \p MatrixPtr[I][J]. 
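  /// For example (illustrative), loading a 2 x 2 tile at I = 1, J = 1 from a
  /// column-major 4 x 4 matrix computes the element offset J * 4 + I = 5,
  /// casts \p MatrixPtr + 5 to a pointer to the 4-element tile and then loads
  /// the two tile columns with a stride of 4 elements (element offsets 5..6
  /// and 9..10 from \p MatrixPtr).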
996 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 997 ShapeInfo MatrixShape, Value *I, Value *J, 998 ShapeInfo ResultShape, Type *EltTy, 999 IRBuilder<> &Builder) { 1000 1001 Value *Offset = Builder.CreateAdd( 1002 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1003 1004 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1005 Value *EltPtr = 1006 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1007 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1008 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 1009 ResultShape.NumColumns); 1010 Type *TilePtrTy = PointerType::get(TileTy, AS); 1011 Value *TilePtr = 1012 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1013 1014 return loadMatrix(TileTy, TilePtr, Align, 1015 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 1016 ResultShape, Builder); 1017 } 1018 1019 /// Lower a load instruction with shape information. 1020 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 1021 bool IsVolatile, ShapeInfo Shape) { 1022 IRBuilder<> Builder(Inst); 1023 finalizeLowering(Inst, 1024 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 1025 Shape, Builder), 1026 Builder); 1027 } 1028 1029 /// Lowers llvm.matrix.column.major.load. 1030 /// 1031 /// The intrinsic loads a matrix from memory using a stride between columns. 1032 void LowerColumnMajorLoad(CallInst *Inst) { 1033 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1034 "Intrinsic only supports column-major layout!"); 1035 Value *Ptr = Inst->getArgOperand(0); 1036 Value *Stride = Inst->getArgOperand(1); 1037 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 1038 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 1039 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 1040 } 1041 1042 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 1043 /// MatrixPtr[I][J]. 1044 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, 1045 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, 1046 Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { 1047 Value *Offset = Builder.CreateAdd( 1048 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 1049 1050 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 1051 Value *EltPtr = 1052 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 1053 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 1054 auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * 1055 StoreVal.getNumColumns()); 1056 Type *TilePtrTy = PointerType::get(TileTy, AS); 1057 Value *TilePtr = 1058 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 1059 1060 storeMatrix(TileTy, StoreVal, TilePtr, MAlign, 1061 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); 1062 } 1063 1064 /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between 1065 /// vectors. 
1066 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, 1067 MaybeAlign MAlign, Value *Stride, bool IsVolatile, 1068 IRBuilder<> &Builder) { 1069 auto VType = cast<VectorType>(Ty); 1070 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 1071 for (auto Vec : enumerate(StoreVal.vectors())) { 1072 Value *GEP = computeVectorAddr( 1073 EltPtr, 1074 Builder.getIntN(Stride->getType()->getScalarSizeInBits(), 1075 Vec.index()), 1076 Stride, StoreVal.getStride(), VType->getElementType(), Builder); 1077 Builder.CreateAlignedStore(Vec.value(), GEP, 1078 getAlignForIndex(Vec.index(), Stride, 1079 VType->getElementType(), 1080 MAlign), 1081 IsVolatile); 1082 } 1083 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * 1084 StoreVal.getNumVectors()); 1085 } 1086 1087 /// Lower a store instruction with shape information. 1088 void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, 1089 Value *Stride, bool IsVolatile, ShapeInfo Shape) { 1090 IRBuilder<> Builder(Inst); 1091 auto StoreVal = getMatrix(Matrix, Shape, Builder); 1092 finalizeLowering(Inst, 1093 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, 1094 IsVolatile, Builder), 1095 Builder); 1096 } 1097 1098 /// Lowers llvm.matrix.column.major.store. 1099 /// 1100 /// The intrinsic store a matrix back memory using a stride between columns. 1101 void LowerColumnMajorStore(CallInst *Inst) { 1102 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1103 "Intrinsic only supports column-major layout!"); 1104 Value *Matrix = Inst->getArgOperand(0); 1105 Value *Ptr = Inst->getArgOperand(1); 1106 Value *Stride = Inst->getArgOperand(2); 1107 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 1108 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 1109 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 1110 } 1111 1112 // Set elements I..I+NumElts-1 to Block 1113 Value *insertVector(Value *Col, unsigned I, Value *Block, 1114 IRBuilder<> &Builder) { 1115 1116 // First, bring Block to the same size as Col 1117 unsigned BlockNumElts = 1118 cast<FixedVectorType>(Block->getType())->getNumElements(); 1119 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 1120 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 1121 1122 Block = Builder.CreateShuffleVector( 1123 Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 1124 1125 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 1126 // 8, 4, 5, 6 1127 SmallVector<int, 16> Mask; 1128 unsigned i; 1129 for (i = 0; i < I; i++) 1130 Mask.push_back(i); 1131 1132 unsigned VecNumElts = 1133 cast<FixedVectorType>(Col->getType())->getNumElements(); 1134 for (; i < I + BlockNumElts; i++) 1135 Mask.push_back(i - I + VecNumElts); 1136 1137 for (; i < VecNumElts; i++) 1138 Mask.push_back(i); 1139 1140 return Builder.CreateShuffleVector(Col, Block, Mask); 1141 } 1142 1143 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 1144 IRBuilder<> &Builder, bool AllowContraction, 1145 unsigned &NumComputeOps) { 1146 NumComputeOps += getNumOps(A->getType()); 1147 if (!Sum) 1148 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 1149 1150 if (UseFPOp) { 1151 if (AllowContraction) { 1152 // Use fmuladd for floating point operations and let the backend decide 1153 // if that's profitable. 
1154 Function *FMulAdd = Intrinsic::getDeclaration( 1155 Func.getParent(), Intrinsic::fmuladd, A->getType()); 1156 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 1157 } 1158 NumComputeOps += getNumOps(A->getType()); 1159 Value *Mul = Builder.CreateFMul(A, B); 1160 return Builder.CreateFAdd(Sum, Mul); 1161 } 1162 1163 NumComputeOps += getNumOps(A->getType()); 1164 Value *Mul = Builder.CreateMul(A, B); 1165 return Builder.CreateAdd(Sum, Mul); 1166 } 1167 1168 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 1169 /// users with shape information, there's nothing to do: they will use the 1170 /// cached value when they are lowered. For other users, \p Matrix is 1171 /// flattened and the uses are updated to use it. Also marks \p Inst for 1172 /// deletion. 1173 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 1174 IRBuilder<> &Builder) { 1175 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 1176 (void)inserted; 1177 assert(inserted.second && "multiple matrix lowering mapping"); 1178 1179 ToRemove.push_back(Inst); 1180 Value *Flattened = nullptr; 1181 for (Use &U : llvm::make_early_inc_range(Inst->uses())) { 1182 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 1183 if (!Flattened) 1184 Flattened = Matrix.embedInVector(Builder); 1185 U.set(Flattened); 1186 } 1187 } 1188 } 1189 1190 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1191 /// addition. 1192 /// 1193 /// We can fold a transpose into the operand that is used to extract scalars. 1194 /// This is the first operands with row-major and the second with 1195 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate 1196 /// operand is transposed. 1197 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1198 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled, 1199 bool IsScalarMatrixTransposed, FastMathFlags FMF) { 1200 const unsigned VF = std::max<unsigned>( 1201 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1202 .getFixedSize() / 1203 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1204 1U); 1205 unsigned R = Result.getNumRows(); 1206 unsigned C = Result.getNumColumns(); 1207 unsigned M = A.getNumColumns(); 1208 1209 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1210 assert(A.isColumnMajor() == B.isColumnMajor() && 1211 Result.isColumnMajor() == A.isColumnMajor() && 1212 "operands must agree on matrix layout"); 1213 unsigned NumComputeOps = 0; 1214 1215 Builder.setFastMathFlags(FMF); 1216 1217 if (A.isColumnMajor()) { 1218 // Multiply columns from the first operand with scalars from the second 1219 // operand. Then move along the K axes and accumulate the columns. With 1220 // this the adds can be vectorized without reassociation. 1221 for (unsigned J = 0; J < C; ++J) { 1222 unsigned BlockSize = VF; 1223 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1224 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1225 1226 for (unsigned I = 0; I < R; I += BlockSize) { 1227 // Gradually lower the vectorization factor to cover the remainder. 1228 while (I + BlockSize > R) 1229 BlockSize /= 2; 1230 1231 Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) 1232 : nullptr; 1233 for (unsigned K = 0; K < M; ++K) { 1234 Value *L = A.extractVector(I, K, BlockSize, Builder); 1235 Value *RH = Builder.CreateExtractElement( 1236 B.getColumn(IsScalarMatrixTransposed ? K : J), 1237 IsScalarMatrixTransposed ? 
J : K); 1238 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1239 Sum = 1240 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1241 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1242 } 1243 Result.setVector(J, 1244 insertVector(Result.getVector(J), I, Sum, Builder)); 1245 } 1246 } 1247 } else { 1248 // Multiply rows from the second operand with scalars from the first 1249 // operand. Then move along the K axes and accumulate the rows. With this 1250 // the adds can be vectorized without reassociation. 1251 for (unsigned I = 0; I < R; ++I) { 1252 unsigned BlockSize = VF; 1253 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1254 for (unsigned J = 0; J < C; J += BlockSize) { 1255 // Gradually lower the vectorization factor to cover the remainder. 1256 while (J + BlockSize > C) 1257 BlockSize /= 2; 1258 1259 Value *Sum = nullptr; 1260 for (unsigned K = 0; K < M; ++K) { 1261 Value *R = B.extractVector(K, J, BlockSize, Builder); 1262 Value *LH = Builder.CreateExtractElement( 1263 A.getVector(IsScalarMatrixTransposed ? K : I), 1264 IsScalarMatrixTransposed ? I : K); 1265 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1266 Sum = 1267 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1268 IsFP, Builder, FMF.allowContract(), NumComputeOps); 1269 } 1270 Result.setVector(I, 1271 insertVector(Result.getVector(I), J, Sum, Builder)); 1272 } 1273 } 1274 } 1275 Result.addNumComputeOps(NumComputeOps); 1276 } 1277 1278 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1279 /// copying it to a new location. This new or otherwise the original location 1280 /// is returned. 1281 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1282 CallInst *MatMul) { 1283 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1284 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1285 1286 // If we can statically determine noalias we're good. 1287 if (AA->isNoAlias(LoadLoc, StoreLoc)) 1288 return Load->getPointerOperand(); 1289 1290 // Create code to check if the memory locations of the Load and Store 1291 // overlap and if they do, copy Load's operand to a new buffer. 1292 1293 // First, create new blocks for 2n part of the check and the copy. 1294 BasicBlock *Check0 = MatMul->getParent(); 1295 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1296 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1297 // as we adjust Check0 and Check1's branches. 1298 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1299 for (BasicBlock *Succ : successors(Check0)) 1300 DTUpdates.push_back({DT->Delete, Check0, Succ}); 1301 1302 BasicBlock *Check1 = 1303 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1304 nullptr, "alias_cont"); 1305 BasicBlock *Copy = 1306 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1307 nullptr, "copy"); 1308 BasicBlock *Fusion = 1309 SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, 1310 nullptr, "no_alias"); 1311 1312 // Check if the loaded memory location begins before the end of the store 1313 // location. If the condition holds, they might overlap, otherwise they are 1314 // guaranteed to not overlap. 
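    // The emitted checks and the copy form the following control flow
    // (sketch; block names as created above):
    //
    //   Check0:  br (LoadBegin < StoreEnd), Check1, Fusion
    //   Check1:  br (StoreBegin < LoadEnd), Copy, Fusion
    //   Copy:    copy the load operand into a fresh alloca via memcpy
    //   Fusion:  phi of the original load pointer (from Check0 or Check1)
    //            or of the copy (from Copy)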
1315 IRBuilder<> Builder(MatMul); 1316 Check0->getTerminator()->eraseFromParent(); 1317 Builder.SetInsertPoint(Check0); 1318 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1319 Value *StoreBegin = Builder.CreatePtrToInt( 1320 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1321 Value *StoreEnd = Builder.CreateAdd( 1322 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1323 "store.end", true, true); 1324 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1325 IntPtrTy, "load.begin"); 1326 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1327 Fusion); 1328 1329 // Check if the store begins before the end of the load location. If the 1330 // condition holds, they alias, otherwise they are guaranteed to not 1331 // overlap. 1332 Check1->getTerminator()->eraseFromParent(); 1333 Builder.SetInsertPoint(Check1, Check1->begin()); 1334 Value *LoadEnd = Builder.CreateAdd( 1335 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1336 "load.end", true, true); 1337 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1338 Fusion); 1339 1340 // Copy load operand to new alloca. 1341 Builder.SetInsertPoint(Copy, Copy->begin()); 1342 auto *VT = cast<FixedVectorType>(Load->getType()); 1343 // Use an array type for the alloca, to avoid potentially huge alignment 1344 // requirements for large vector types. 1345 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); 1346 AllocaInst *Alloca = 1347 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); 1348 Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo()); 1349 1350 Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(), 1351 Load->getAlign(), LoadLoc.Size.getValue()); 1352 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1353 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1354 PHI->addIncoming(Load->getPointerOperand(), Check0); 1355 PHI->addIncoming(Load->getPointerOperand(), Check1); 1356 PHI->addIncoming(BC, Copy); 1357 1358 // Adjust DT. 1359 DTUpdates.push_back({DT->Insert, Check0, Check1}); 1360 DTUpdates.push_back({DT->Insert, Check0, Fusion}); 1361 DTUpdates.push_back({DT->Insert, Check1, Copy}); 1362 DTUpdates.push_back({DT->Insert, Check1, Fusion}); 1363 DT->applyUpdates(DTUpdates); 1364 return PHI; 1365 } 1366 1367 bool isFusionProfitable(CallInst *MatMul) { 1368 if (ForceFusion) 1369 return true; 1370 1371 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1372 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1373 1374 const unsigned R = LShape.NumRows; 1375 const unsigned C = RShape.NumColumns; 1376 const unsigned M = LShape.NumColumns; 1377 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1378 1379 const unsigned VF = std::max<unsigned>( 1380 TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 1381 .getFixedSize() / 1382 EltType->getPrimitiveSizeInBits().getFixedSize(), 1383 1U); 1384 1385 // Cost model for tiling 1386 // 1387 // For tiling to be beneficial, we need reuse either along the R or 1388 // the C axis. We vectorize along the R axis so that means at least 1389 // 3 elements. 1390 // TODO: Also consider cost of copying if operands alias. 1391 if (R <= VF && C == 1) 1392 return false; 1393 // Then we need enough elements to exceed the number of vector 1394 // registers we have. 
Note that this is an oversimplification since 1395 // fusing also takes some extra loads which may exceed the number of 1396 // reloads necessary. 1397 unsigned Op0Regs = (R + VF - 1) / VF * M; 1398 unsigned Op1Regs = (M + VF - 1) / VF * C; 1399 return Op0Regs + Op1Regs > 1400 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)); 1401 } 1402 1403 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1404 MatrixTy Res; 1405 auto *ColumType = FixedVectorType::get(EltType, R); 1406 for (unsigned I = 0; I < C; ++I) 1407 Res.addVector(ConstantAggregateZero::get(ColumType)); 1408 return Res; 1409 } 1410 1411 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, 1412 Value *RPtr, ShapeInfo RShape, StoreInst *Store) { 1413 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1414 1415 // Create the main tiling loop nest. 1416 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); 1417 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 1418 Instruction *InsertI = cast<Instruction>(MatMul); 1419 BasicBlock *Start = InsertI->getParent(); 1420 BasicBlock *End = 1421 SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); 1422 IRBuilder<> Builder(MatMul); 1423 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); 1424 1425 Type *TileVecTy = 1426 FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); 1427 MatrixTy TileResult; 1428 // Insert in the inner loop header. 1429 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); 1430 // Create PHI nodes for the result columns to accumulate across iterations. 1431 SmallVector<PHINode *, 4> ColumnPhis; 1432 for (unsigned I = 0; I < TileSize; I++) { 1433 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); 1434 Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), 1435 TI.RowLoop.Header->getSingleSuccessor()); 1436 TileResult.addVector(Phi); 1437 ColumnPhis.push_back(Phi); 1438 } 1439 1440 // Insert in the inner loop body, which computes 1441 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) 1442 Builder.SetInsertPoint(InnerBody->getTerminator()); 1443 // Load tiles of the operands. 1444 MatrixTy A = 1445 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, 1446 {TileSize, TileSize}, EltType, Builder); 1447 MatrixTy B = 1448 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, 1449 {TileSize, TileSize}, EltType, Builder); 1450 emitMatrixMultiply(TileResult, A, B, Builder, true, false, 1451 getFastMathFlags(MatMul)); 1452 // Store result after the inner loop is done. 1453 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); 1454 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), 1455 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, 1456 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); 1457 1458 for (unsigned I = 0; I < TileResult.getNumVectors(); I++) 1459 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); 1460 1461 // Force unrolling of a few iterations of the inner loop, to make sure there 1462 // is enough work per iteration. 1463 // FIXME: The unroller should make this decision directly instead, but 1464 // currently the cost-model is not up to the task. 
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }

  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                      Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
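  ///
  /// As an illustration (not verbatim LLVM IR), a 2x2 * 2x2 multiply of
  /// doubles such as
  ///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
  ///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
  /// is expanded into extracts of the operand columns combined with fmul/fadd
  /// (or fmuladd, when contraction is allowed) operations on those vectors.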
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  /// Lower store instructions, if shape information is available.
  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
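  ///
  /// The operands are split into their column (or row) vectors and the
  /// operator is applied vector by vector, e.g. an fadd of two 4x4 float
  /// matrixes in column-major layout becomes four fadds on <4 x float>
  /// column vectors.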
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print if they refer to an
    /// (function) external address or a stack address; for other values we
    /// either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
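    ///
    /// Ops reachable from a single leaf only contribute to the exclusive
    /// count, while ops of sub-expressions shared between multiple leaves
    /// contribute to the shared count.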
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
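        // The emitted remark message has roughly the form (counts are
        // illustrative): "Lowered with 4 stores, 8 loads, 16 compute ops,
        // 0 exposed transposes", followed by shared counts (if any) and the
        // linearized matrix expression.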
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << "<";
  if (Minimal)
    OS << "minimal";
  OS << ">";
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}

namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE like tiling are disabled. This
/// is used to lower matrix intrinsics if the main lowering pass is not run, for
/// example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    return LMT.Visit();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] =
    "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}