//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//    * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//      transposed.
//  * Improve cost-modeling, e.g. choose different number of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the start of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0  |v_1_1 |v_1_2 |v_1_3
//         v_2_0  |v_2_1 |v_2_2 |v_2_3
//         v_3_0  {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///    The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis &AA;
  DominatorTree &DT;
  LoopInfo &LI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy()
        : Vectors(),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ?
NumRows : NumColumns))); 230 } 231 232 Value *getVector(unsigned i) const { return Vectors[i]; } 233 Value *getColumn(unsigned i) const { 234 assert(isColumnMajor() && "only supported for column-major matrixes"); 235 return Vectors[i]; 236 } 237 Value *getRow(unsigned i) const { 238 assert(!isColumnMajor() && "only supported for row-major matrixes"); 239 return Vectors[i]; 240 } 241 242 void setVector(unsigned i, Value *V) { Vectors[i] = V; } 243 244 Type *getElementType() { return getVectorTy()->getElementType(); } 245 246 unsigned getNumVectors() const { 247 if (isColumnMajor()) 248 return getNumColumns(); 249 return getNumRows(); 250 } 251 252 unsigned getNumColumns() const { 253 if (isColumnMajor()) 254 return Vectors.size(); 255 else { 256 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 257 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 258 } 259 } 260 unsigned getNumRows() const { 261 if (isColumnMajor()) { 262 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); 263 return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); 264 } else 265 return Vectors.size(); 266 } 267 268 void addVector(Value *V) { Vectors.push_back(V); } 269 VectorType *getColumnTy() { 270 assert(isColumnMajor() && "only supported for column-major matrixes"); 271 return getVectorTy(); 272 } 273 274 VectorType *getVectorTy() { 275 return cast<VectorType>(Vectors[0]->getType()); 276 } 277 278 iterator_range<SmallVector<Value *, 8>::iterator> columns() { 279 assert(isColumnMajor() && 280 "columns() only supported for column-major matrixes"); 281 return make_range(Vectors.begin(), Vectors.end()); 282 } 283 284 iterator_range<SmallVector<Value *, 8>::iterator> vectors() { 285 return make_range(Vectors.begin(), Vectors.end()); 286 } 287 288 /// Embed the vectors of the matrix into a flat vector by concatenating 289 /// them. 290 Value *embedInVector(IRBuilder<> &Builder) const { 291 return Vectors.size() == 1 ? Vectors[0] 292 : concatenateVectors(Builder, Vectors); 293 } 294 295 MatrixTy &addNumLoads(unsigned N) { 296 OpInfo.NumLoads += N; 297 return *this; 298 } 299 300 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } 301 302 MatrixTy &addNumStores(unsigned N) { 303 OpInfo.NumStores += N; 304 return *this; 305 } 306 307 MatrixTy &addNumComputeOps(unsigned N) { 308 OpInfo.NumComputeOps += N; 309 return *this; 310 } 311 312 unsigned getNumStores() const { return OpInfo.NumStores; } 313 unsigned getNumLoads() const { return OpInfo.NumLoads; } 314 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } 315 316 const OpInfoTy &getOpInfo() const { return OpInfo; } 317 318 bool isColumnMajor() const { return IsColumnMajor; } 319 320 unsigned getStride() const { 321 if (isColumnMajor()) 322 return getNumRows(); 323 return getNumColumns(); 324 } 325 326 /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the 327 /// matrix is column-major, the result vector is extracted from a column 328 /// vector, otherwise from a row vector. 329 Value *extractVector(unsigned I, unsigned J, unsigned NumElts, 330 IRBuilder<> &Builder) const { 331 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); 332 Value *Undef = UndefValue::get(Vec->getType()); 333 return Builder.CreateShuffleVector( 334 Vec, Undef, createSequentialMask(isColumnMajor() ? 
I : J, NumElts, 0), 335 "block"); 336 } 337 }; 338 339 struct ShapeInfo { 340 unsigned NumRows; 341 unsigned NumColumns; 342 343 bool IsColumnMajor; 344 345 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) 346 : NumRows(NumRows), NumColumns(NumColumns), 347 IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} 348 349 ShapeInfo(Value *NumRows, Value *NumColumns) 350 : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), 351 cast<ConstantInt>(NumColumns)->getZExtValue()) {} 352 353 bool operator==(const ShapeInfo &other) { 354 return NumRows == other.NumRows && NumColumns == other.NumColumns; 355 } 356 bool operator!=(const ShapeInfo &other) { return !(*this == other); } 357 358 /// Returns true if shape-information is defined, meaning both dimensions 359 /// are != 0. 360 operator bool() const { 361 assert(NumRows == 0 || NumColumns != 0); 362 return NumRows != 0; 363 } 364 365 unsigned getStride() const { 366 if (IsColumnMajor) 367 return NumRows; 368 return NumColumns; 369 } 370 371 unsigned getNumVectors() const { 372 if (IsColumnMajor) 373 return NumColumns; 374 return NumRows; 375 } 376 }; 377 378 /// Maps instructions to their shape information. The shape information 379 /// describes the shape to be used while lowering. This matches the shape of 380 /// the result value of the instruction, with the only exceptions being store 381 /// instructions and the matrix_column_major_store intrinsics. For those, the 382 /// shape information indicates that those instructions should be lowered 383 /// using shape information as well. 384 DenseMap<Value *, ShapeInfo> ShapeMap; 385 386 /// List of instructions to remove. While lowering, we are not replacing all 387 /// users of a lowered instruction, if shape information is available and 388 /// those need to be removed after we finished lowering. 389 SmallVector<Instruction *, 16> ToRemove; 390 391 /// Map from instructions to their produced column matrix. 392 MapVector<Value *, MatrixTy> Inst2ColumnMatrix; 393 394 public: 395 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, 396 AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI, 397 OptimizationRemarkEmitter &ORE) 398 : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), 399 LI(LI), ORE(ORE) {} 400 401 unsigned getNumOps(Type *VT) { 402 assert(isa<VectorType>(VT) && "Expected vector type"); 403 return getNumOps(VT->getScalarType(), 404 cast<FixedVectorType>(VT)->getNumElements()); 405 } 406 407 // 408 /// Return the estimated number of vector ops required for an operation on 409 /// \p VT * N. 410 unsigned getNumOps(Type *ST, unsigned N) { 411 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / 412 double(TTI.getRegisterBitWidth(true))); 413 } 414 415 /// Return the set of vectors that a matrix value is lowered to. 416 /// 417 /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise 418 /// split the flat vector \p MatrixVal containing a matrix with shape \p SI 419 /// into vectors. 420 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, 421 IRBuilder<> &Builder) { 422 VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); 423 assert(VType && "MatrixVal must be a vector type"); 424 assert(cast<FixedVectorType>(VType)->getNumElements() == 425 SI.NumRows * SI.NumColumns && 426 "The vector size must match the number of matrix elements"); 427 428 // Check if we lowered MatrixVal using shape information. 
In that case, 429 // return the existing matrix, if it matches the requested shape 430 // information. If there is a mis-match, embed the result in a flat 431 // vector and split it later. 432 auto Found = Inst2ColumnMatrix.find(MatrixVal); 433 if (Found != Inst2ColumnMatrix.end()) { 434 MatrixTy &M = Found->second; 435 // Return the found matrix, if its shape matches the requested shape 436 // information 437 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) 438 return M; 439 440 MatrixVal = M.embedInVector(Builder); 441 } 442 443 // Otherwise split MatrixVal. 444 SmallVector<Value *, 16> SplitVecs; 445 Value *Undef = UndefValue::get(VType); 446 for (unsigned MaskStart = 0; 447 MaskStart < cast<FixedVectorType>(VType)->getNumElements(); 448 MaskStart += SI.getStride()) { 449 Value *V = Builder.CreateShuffleVector( 450 MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0), 451 "split"); 452 SplitVecs.push_back(V); 453 } 454 455 return {SplitVecs}; 456 } 457 458 /// If \p V already has a known shape return false. Otherwise set the shape 459 /// for instructions that support it. 460 bool setShapeInfo(Value *V, ShapeInfo Shape) { 461 assert(Shape && "Shape not set"); 462 if (isa<UndefValue>(V) || !supportsShapeInfo(V)) 463 return false; 464 465 auto SIter = ShapeMap.find(V); 466 if (SIter != ShapeMap.end()) { 467 LLVM_DEBUG(dbgs() << " not overriding existing shape: " 468 << SIter->second.NumRows << " " 469 << SIter->second.NumColumns << " for " << *V << "\n"); 470 return false; 471 } 472 473 ShapeMap.insert({V, Shape}); 474 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns 475 << " for " << *V << "\n"); 476 return true; 477 } 478 479 bool isUniformShape(Value *V) { 480 Instruction *I = dyn_cast<Instruction>(V); 481 if (!I) 482 return true; 483 484 switch (I->getOpcode()) { 485 case Instruction::FAdd: 486 case Instruction::FSub: 487 case Instruction::FMul: // Scalar multiply. 488 case Instruction::Add: 489 case Instruction::Mul: 490 case Instruction::Sub: 491 return true; 492 default: 493 return false; 494 } 495 } 496 497 /// Returns true if shape information can be used for \p V. The supported 498 /// instructions must match the instructions that can be lowered by this pass. 499 bool supportsShapeInfo(Value *V) { 500 Instruction *Inst = dyn_cast<Instruction>(V); 501 if (!Inst) 502 return false; 503 504 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); 505 if (II) 506 switch (II->getIntrinsicID()) { 507 case Intrinsic::matrix_multiply: 508 case Intrinsic::matrix_transpose: 509 case Intrinsic::matrix_column_major_load: 510 case Intrinsic::matrix_column_major_store: 511 return true; 512 default: 513 return false; 514 } 515 return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); 516 } 517 518 /// Propagate the shape information of instructions to their users. 519 /// The work list contains instructions for which we can compute the shape, 520 /// either based on the information provided by matrix intrinsics or known 521 /// shapes of operands. 522 SmallVector<Instruction *, 32> 523 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) { 524 SmallVector<Instruction *, 32> NewWorkList; 525 // Pop an element for which we guaranteed to have at least one of the 526 // operand shapes. Add the shape for this and then add users to the work 527 // list. 
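    // As an illustrative sketch (hypothetical IR, not the emitted code), a
    // 2x2 multiply seeds its own shape, which then flows forward to users
    // that carry no explicit shape arguments:
    //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
    //            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
    //   %d = fadd <4 x double> %c, %e   ; receives shape 2 x 2 via
    //                                   ; isUniformShape below.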
528 LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); 529 while (!WorkList.empty()) { 530 Instruction *Inst = WorkList.back(); 531 WorkList.pop_back(); 532 533 // New entry, set the value and insert operands 534 bool Propagate = false; 535 536 Value *MatrixA; 537 Value *MatrixB; 538 Value *M; 539 Value *N; 540 Value *K; 541 if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( 542 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 543 m_Value(N), m_Value(K)))) { 544 Propagate = setShapeInfo(Inst, {M, K}); 545 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( 546 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 547 // Flip dimensions. 548 Propagate = setShapeInfo(Inst, {N, M}); 549 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( 550 m_Value(MatrixA), m_Value(), m_Value(), 551 m_Value(), m_Value(M), m_Value(N)))) { 552 Propagate = setShapeInfo(Inst, {N, M}); 553 } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( 554 m_Value(), m_Value(), m_Value(), m_Value(M), 555 m_Value(N)))) { 556 Propagate = setShapeInfo(Inst, {M, N}); 557 } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { 558 auto OpShape = ShapeMap.find(MatrixA); 559 if (OpShape != ShapeMap.end()) 560 setShapeInfo(Inst, OpShape->second); 561 continue; 562 } else if (isUniformShape(Inst)) { 563 // Find the first operand that has a known shape and use that. 564 for (auto &Op : Inst->operands()) { 565 auto OpShape = ShapeMap.find(Op.get()); 566 if (OpShape != ShapeMap.end()) { 567 Propagate |= setShapeInfo(Inst, OpShape->second); 568 break; 569 } 570 } 571 } 572 573 if (Propagate) { 574 NewWorkList.push_back(Inst); 575 for (auto *User : Inst->users()) 576 if (ShapeMap.count(User) == 0) 577 WorkList.push_back(cast<Instruction>(User)); 578 } 579 } 580 581 return NewWorkList; 582 } 583 584 /// Propagate the shape to operands of instructions with shape information. 585 /// \p Worklist contains the instruction for which we already know the shape. 586 SmallVector<Instruction *, 32> 587 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) { 588 SmallVector<Instruction *, 32> NewWorkList; 589 590 auto pushInstruction = [](Value *V, 591 SmallVectorImpl<Instruction *> &WorkList) { 592 Instruction *I = dyn_cast<Instruction>(V); 593 if (I) 594 WorkList.push_back(I); 595 }; 596 // Pop an element with known shape. Traverse the operands, if their shape 597 // derives from the result shape and is unknown, add it and add them to the 598 // worklist. 599 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); 600 while (!WorkList.empty()) { 601 Value *V = WorkList.back(); 602 WorkList.pop_back(); 603 604 size_t BeforeProcessingV = WorkList.size(); 605 if (!isa<Instruction>(V)) 606 continue; 607 608 Value *MatrixA; 609 Value *MatrixB; 610 Value *M; 611 Value *N; 612 Value *K; 613 if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>( 614 m_Value(MatrixA), m_Value(MatrixB), m_Value(M), 615 m_Value(N), m_Value(K)))) { 616 if (setShapeInfo(MatrixA, {M, N})) 617 pushInstruction(MatrixA, WorkList); 618 619 if (setShapeInfo(MatrixB, {N, K})) 620 pushInstruction(MatrixB, WorkList); 621 622 } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>( 623 m_Value(MatrixA), m_Value(M), m_Value(N)))) { 624 // Flip dimensions. 
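        // Note: llvm.matrix.transpose(%A, M, N) takes an M x N operand and
        // produces an N x M result, so backward propagation assigns the
        // operand the shape {M, N} here.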
625 if (setShapeInfo(MatrixA, {M, N})) 626 pushInstruction(MatrixA, WorkList); 627 } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( 628 m_Value(MatrixA), m_Value(), m_Value(), m_Value(), 629 m_Value(M), m_Value(N)))) { 630 if (setShapeInfo(MatrixA, {M, N})) { 631 pushInstruction(MatrixA, WorkList); 632 } 633 } else if (isa<LoadInst>(V) || 634 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { 635 // Nothing to do, no matrix input. 636 } else if (isa<StoreInst>(V)) { 637 // Nothing to do. We forward-propagated to this so we would just 638 // backward propagate to an instruction with an already known shape. 639 } else if (isUniformShape(V)) { 640 // Propagate to all operands. 641 ShapeInfo Shape = ShapeMap[V]; 642 for (Use &U : cast<Instruction>(V)->operands()) { 643 if (setShapeInfo(U.get(), Shape)) 644 pushInstruction(U.get(), WorkList); 645 } 646 } 647 // After we discovered new shape info for new instructions in the 648 // worklist, we use their users as seeds for the next round of forward 649 // propagation. 650 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++) 651 for (User *U : WorkList[I]->users()) 652 if (isa<Instruction>(U) && V != U) 653 NewWorkList.push_back(cast<Instruction>(U)); 654 } 655 return NewWorkList; 656 } 657 658 bool Visit() { 659 if (EnableShapePropagation) { 660 SmallVector<Instruction *, 32> WorkList; 661 662 // Initially only the shape of matrix intrinsics is known. 663 // Initialize the work list with ops carrying shape information. 664 for (BasicBlock &BB : Func) 665 for (Instruction &Inst : BB) { 666 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); 667 if (!II) 668 continue; 669 670 switch (II->getIntrinsicID()) { 671 case Intrinsic::matrix_multiply: 672 case Intrinsic::matrix_transpose: 673 case Intrinsic::matrix_column_major_load: 674 case Intrinsic::matrix_column_major_store: 675 WorkList.push_back(&Inst); 676 break; 677 default: 678 break; 679 } 680 } 681 // Propagate shapes until nothing changes any longer. 682 while (!WorkList.empty()) { 683 WorkList = propagateShapeForward(WorkList); 684 WorkList = propagateShapeBackward(WorkList); 685 } 686 } 687 688 bool Changed = false; 689 SmallVector<CallInst *, 16> MaybeFusableInsts; 690 SmallVector<Instruction *, 16> MatrixInsts; 691 692 // First, collect all instructions with shape information and candidates for 693 // fusion (currently only matrix multiplies). 694 ReversePostOrderTraversal<Function *> RPOT(&Func); 695 for (auto *BB : RPOT) 696 for (Instruction &I : *BB) { 697 if (ShapeMap.find(&I) == ShapeMap.end()) 698 continue; 699 if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) 700 MaybeFusableInsts.push_back(cast<CallInst>(&I)); 701 MatrixInsts.push_back(&I); 702 } 703 704 // Second, try to fuse candidates. 705 SmallPtrSet<Instruction *, 16> FusedInsts; 706 for (CallInst *CI : MaybeFusableInsts) 707 LowerMatrixMultiplyFused(CI, FusedInsts); 708 Changed = !FusedInsts.empty(); 709 710 // Third, lower remaining instructions with shape information. 
711 for (Instruction *Inst : MatrixInsts) { 712 if (FusedInsts.count(Inst)) 713 continue; 714 715 IRBuilder<> Builder(Inst); 716 717 if (CallInst *CInst = dyn_cast<CallInst>(Inst)) 718 Changed |= VisitCallInst(CInst); 719 720 Value *Op1; 721 Value *Op2; 722 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) 723 Changed |= VisitBinaryOperator(BinOp); 724 if (match(Inst, m_Load(m_Value(Op1)))) 725 Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); 726 else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) 727 Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); 728 } 729 730 RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); 731 RemarkGen.emitRemarks(); 732 733 for (Instruction *Inst : reverse(ToRemove)) 734 Inst->eraseFromParent(); 735 736 return Changed; 737 } 738 739 /// Turns \p BasePtr into an elementwise pointer to \p EltType. 740 Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { 741 unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); 742 Type *EltPtrType = PointerType::get(EltType, AS); 743 return Builder.CreatePointerCast(BasePtr, EltPtrType); 744 } 745 746 /// Replace intrinsic calls 747 bool VisitCallInst(CallInst *Inst) { 748 if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) 749 return false; 750 751 switch (Inst->getCalledFunction()->getIntrinsicID()) { 752 case Intrinsic::matrix_multiply: 753 LowerMultiply(Inst); 754 break; 755 case Intrinsic::matrix_transpose: 756 LowerTranspose(Inst); 757 break; 758 case Intrinsic::matrix_column_major_load: 759 LowerColumnMajorLoad(Inst); 760 break; 761 case Intrinsic::matrix_column_major_store: 762 LowerColumnMajorStore(Inst); 763 break; 764 default: 765 return false; 766 } 767 return true; 768 } 769 770 /// Compute the alignment for a column/row \p Idx with \p Stride between them. 771 /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a 772 /// ConstantInt, reduce the initial alignment based on the byte offset. For 773 /// non-ConstantInt strides, return the common alignment of the initial 774 /// alignment and the element size in bytes. 775 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, 776 MaybeAlign A) const { 777 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); 778 if (Idx == 0) 779 return InitialAlign; 780 781 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); 782 if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { 783 uint64_t StrideInBytes = 784 ConstStride->getZExtValue() * ElementSizeInBits / 8; 785 return commonAlignment(InitialAlign, Idx * StrideInBytes); 786 } 787 return commonAlignment(InitialAlign, ElementSizeInBits / 8); 788 } 789 790 /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between 791 /// vectors. 
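  /// As an illustrative sketch (values hypothetical): loading a column-major
  /// 3 x 2 matrix of doubles with a stride of 5 emits two <3 x double> loads,
  /// one at \p Ptr and one at \p Ptr + 5 elements, each annotated with the
  /// alignment computed by getAlignForIndex.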
792 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, 793 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { 794 auto VType = cast<VectorType>(Ty); 795 Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); 796 MatrixTy Result; 797 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { 798 Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, 799 Shape.getStride(), VType->getElementType(), 800 Builder); 801 Value *Vector = Builder.CreateAlignedLoad( 802 GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign), 803 IsVolatile, "col.load"); 804 805 Result.addVector(Vector); 806 } 807 return Result.addNumLoads(getNumOps(Result.getVectorTy()) * 808 Result.getNumVectors()); 809 } 810 811 /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, 812 /// starting at \p MatrixPtr[I][J]. 813 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, 814 ShapeInfo MatrixShape, Value *I, Value *J, 815 ShapeInfo ResultShape, Type *EltTy, 816 IRBuilder<> &Builder) { 817 818 Value *Offset = Builder.CreateAdd( 819 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); 820 821 unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); 822 Value *EltPtr = 823 Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); 824 Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); 825 auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * 826 ResultShape.NumColumns); 827 Type *TilePtrTy = PointerType::get(TileTy, AS); 828 Value *TilePtr = 829 Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); 830 831 return loadMatrix(TileTy, TilePtr, Align, 832 Builder.getInt64(MatrixShape.getStride()), IsVolatile, 833 ResultShape, Builder); 834 } 835 836 /// Lower a load instruction with shape information. 837 void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, 838 bool IsVolatile, ShapeInfo Shape) { 839 IRBuilder<> Builder(Inst); 840 finalizeLowering(Inst, 841 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, 842 Shape, Builder), 843 Builder); 844 } 845 846 /// Lowers llvm.matrix.column.major.load. 847 /// 848 /// The intrinsic loads a matrix from memory using a stride between columns. 849 void LowerColumnMajorLoad(CallInst *Inst) { 850 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 851 "Intrinsic only supports column-major layout!"); 852 Value *Ptr = Inst->getArgOperand(0); 853 Value *Stride = Inst->getArgOperand(1); 854 LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, 855 cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), 856 {Inst->getArgOperand(3), Inst->getArgOperand(4)}); 857 } 858 859 /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p 860 /// MatrixPtr[I][J]. 
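  /// As a worked example (numbers hypothetical): for a column-major 8 x 8
  /// matrix, the tile element at (I = 4, J = 2) starts at linear element
  /// offset J * Stride + I = 2 * 8 + 4 = 20, which is the offset computed
  /// below before recursing into the strided storeMatrix overload.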
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                     Stride, StoreVal.getStride(),
                                     VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
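  ///
  /// As an illustrative sketch (hypothetical IR, type mangling abbreviated):
  ///   call void @llvm.matrix.column.major.store(<4 x double> %m, double* %p,
  ///                                             i64 8, i1 false, i32 2, i32 2)
  /// is lowered via LowerStore with Matrix = %m, Ptr = %p, Stride = 8,
  /// IsVolatile = false and a 2 x 2 shape.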
916 void LowerColumnMajorStore(CallInst *Inst) { 917 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 918 "Intrinsic only supports column-major layout!"); 919 Value *Matrix = Inst->getArgOperand(0); 920 Value *Ptr = Inst->getArgOperand(1); 921 Value *Stride = Inst->getArgOperand(2); 922 LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, 923 cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), 924 {Inst->getArgOperand(4), Inst->getArgOperand(5)}); 925 } 926 927 // Set elements I..I+NumElts-1 to Block 928 Value *insertVector(Value *Col, unsigned I, Value *Block, 929 IRBuilder<> &Builder) { 930 931 // First, bring Block to the same size as Col 932 unsigned BlockNumElts = 933 cast<FixedVectorType>(Block->getType())->getNumElements(); 934 unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); 935 assert(NumElts >= BlockNumElts && "Too few elements for current block"); 936 937 Value *Undef = UndefValue::get(Block->getType()); 938 Block = Builder.CreateShuffleVector( 939 Block, Undef, 940 createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); 941 942 // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, 943 // 8, 4, 5, 6 944 SmallVector<int, 16> Mask; 945 unsigned i; 946 for (i = 0; i < I; i++) 947 Mask.push_back(i); 948 949 unsigned VecNumElts = 950 cast<FixedVectorType>(Col->getType())->getNumElements(); 951 for (; i < I + BlockNumElts; i++) 952 Mask.push_back(i - I + VecNumElts); 953 954 for (; i < VecNumElts; i++) 955 Mask.push_back(i); 956 957 return Builder.CreateShuffleVector(Col, Block, Mask); 958 } 959 960 Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, 961 IRBuilder<> &Builder, bool AllowContraction, 962 unsigned &NumComputeOps) { 963 NumComputeOps += getNumOps(A->getType()); 964 if (!Sum) 965 return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); 966 967 if (UseFPOp) { 968 if (AllowContraction) { 969 // Use fmuladd for floating point operations and let the backend decide 970 // if that's profitable. 971 Function *FMulAdd = Intrinsic::getDeclaration( 972 Func.getParent(), Intrinsic::fmuladd, A->getType()); 973 return Builder.CreateCall(FMulAdd, {A, B, Sum}); 974 } 975 NumComputeOps += getNumOps(A->getType()); 976 Value *Mul = Builder.CreateFMul(A, B); 977 return Builder.CreateFAdd(Sum, Mul); 978 } 979 980 NumComputeOps += getNumOps(A->getType()); 981 Value *Mul = Builder.CreateMul(A, B); 982 return Builder.CreateAdd(Sum, Mul); 983 } 984 985 /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For 986 /// users with shape information, there's nothing to do: the will use the 987 /// cached value when they are lowered. For other users, \p Matrix is 988 /// flattened and the uses are updated to use it. Also marks \p Inst for 989 /// deletion. 990 void finalizeLowering(Instruction *Inst, MatrixTy Matrix, 991 IRBuilder<> &Builder) { 992 Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); 993 994 ToRemove.push_back(Inst); 995 Value *Flattened = nullptr; 996 for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { 997 Use &U = *I++; 998 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) { 999 if (!Flattened) 1000 Flattened = Matrix.embedInVector(Builder); 1001 U.set(Flattened); 1002 } 1003 } 1004 } 1005 1006 /// Compute \p Result += \p A * \p B for input matrices with left-associating 1007 /// addition. 
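  ///
  /// For the column-major case this roughly corresponds to the following
  /// pseudocode sketch (not the emitted IR):
  ///   for each result column J:
  ///     for each row block I of up to VF elements (block size BS):
  ///       Sum = existing tile contents if isTiled, otherwise zero
  ///       for K = 0 .. A.columns - 1:
  ///         Sum += A[I..I+BS-1, K] * splat(B[K, J])
  ///       insert Sum into Result column J at row I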
1008 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, 1009 const MatrixTy &B, bool AllowContraction, 1010 IRBuilder<> &Builder, bool isTiled) { 1011 const unsigned VF = std::max<unsigned>( 1012 TTI.getRegisterBitWidth(true) / 1013 Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), 1014 1U); 1015 unsigned R = Result.getNumRows(); 1016 unsigned C = Result.getNumColumns(); 1017 unsigned M = A.getNumColumns(); 1018 1019 bool IsFP = Result.getElementType()->isFloatingPointTy(); 1020 assert(A.isColumnMajor() == B.isColumnMajor() && 1021 Result.isColumnMajor() == A.isColumnMajor() && 1022 "operands must agree on matrix layout"); 1023 unsigned NumComputeOps = 0; 1024 if (A.isColumnMajor()) { 1025 // Multiply columns from the first operand with scalars from the second 1026 // operand. Then move along the K axes and accumulate the columns. With 1027 // this the adds can be vectorized without reassociation. 1028 for (unsigned J = 0; J < C; ++J) { 1029 unsigned BlockSize = VF; 1030 // If Result is zero, we don't need to accumulate in the K==0 iteration. 1031 bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); 1032 1033 for (unsigned I = 0; I < R; I += BlockSize) { 1034 // Gradually lower the vectorization factor to cover the remainder. 1035 while (I + BlockSize > R) 1036 BlockSize /= 2; 1037 1038 Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) 1039 : nullptr; 1040 for (unsigned K = 0; K < M; ++K) { 1041 Value *L = A.extractVector(I, K, BlockSize, Builder); 1042 Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); 1043 Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); 1044 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, 1045 Result.getElementType()->isFloatingPointTy(), 1046 Builder, AllowContraction, NumComputeOps); 1047 } 1048 Result.setVector(J, 1049 insertVector(Result.getVector(J), I, Sum, Builder)); 1050 } 1051 } 1052 } else { 1053 // Multiply rows from the second operand with scalars from the first 1054 // operand. Then move along the K axes and accumulate the rows. With this 1055 // the adds can be vectorized without reassociation. 1056 for (unsigned I = 0; I < R; ++I) { 1057 unsigned BlockSize = VF; 1058 bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); 1059 for (unsigned J = 0; J < C; J += BlockSize) { 1060 // Gradually lower the vectorization factor to cover the remainder. 1061 while (J + BlockSize > C) 1062 BlockSize /= 2; 1063 1064 Value *Sum = nullptr; 1065 for (unsigned K = 0; K < M; ++K) { 1066 Value *R = B.extractVector(K, J, BlockSize, Builder); 1067 Value *LH = Builder.CreateExtractElement(A.getVector(I), K); 1068 Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); 1069 Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, 1070 IsFP, Builder, AllowContraction, NumComputeOps); 1071 } 1072 Result.setVector(I, 1073 insertVector(Result.getVector(I), J, Sum, Builder)); 1074 } 1075 } 1076 } 1077 Result.addNumComputeOps(NumComputeOps); 1078 } 1079 1080 /// Ensure that the memory in \p Load does not alias \p Store by potentially 1081 /// copying it to a new location. This new or otherwise the original location 1082 /// is returned. 
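  ///
  /// The emitted runtime check roughly has the following CFG shape (block
  /// names as created below; the original block containing the multiply acts
  /// as check0):
  ///
  ///   check0 ---> alias_cont ---> copy
  ///      |             |           |
  ///      +-------------+----> no_alias
  ///
  /// The PHI in no_alias selects the original pointer on the check edges and
  /// the freshly copied buffer on the edge from copy.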
1083 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, 1084 CallInst *MatMul) { 1085 MemoryLocation StoreLoc = MemoryLocation::get(Store); 1086 MemoryLocation LoadLoc = MemoryLocation::get(Load); 1087 1088 AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); 1089 1090 // If we can statically determine noalias we're good. 1091 if (!LdAliased) 1092 return Load->getPointerOperand(); 1093 1094 // Create code to check if the memory locations of the Load and Store 1095 // overlap and if they do, copy Load's operand to a new buffer. 1096 1097 // First, create new blocks for 2n part of the check and the copy. 1098 BasicBlock *Check0 = MatMul->getParent(); 1099 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a 1100 // DT. Manually collect dominator tree updates, to avoid unnecessary work, 1101 // as we adjust Check0 and Check1's branches. 1102 SmallVector<DominatorTree::UpdateType, 4> DTUpdates; 1103 for (BasicBlock *Succ : successors(Check0)) 1104 DTUpdates.push_back({DT.Delete, Check0, Succ}); 1105 1106 BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, 1107 nullptr, "alias_cont"); 1108 BasicBlock *Copy = 1109 SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); 1110 BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, 1111 nullptr, "no_alias"); 1112 1113 // Check if the loaded memory location begins before the end of the store 1114 // location. If the condition holds, they might overlap, otherwise they are 1115 // guaranteed to not overlap. 1116 IRBuilder<> Builder(MatMul); 1117 Check0->getTerminator()->eraseFromParent(); 1118 Builder.SetInsertPoint(Check0); 1119 Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); 1120 Value *StoreBegin = Builder.CreatePtrToInt( 1121 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); 1122 Value *StoreEnd = Builder.CreateAdd( 1123 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), 1124 "store.end", true, true); 1125 Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), 1126 IntPtrTy, "load.begin"); 1127 Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, 1128 Fusion); 1129 1130 // Check if the store begins before the end of the load location. If the 1131 // condition holds, they alias, otherwise they are guaranteed to not 1132 // overlap. 1133 Check1->getTerminator()->eraseFromParent(); 1134 Builder.SetInsertPoint(Check1, Check1->begin()); 1135 Value *LoadEnd = Builder.CreateAdd( 1136 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), 1137 "load.end", true, true); 1138 Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, 1139 Fusion); 1140 1141 // Copy load operand to new alloca. 1142 Builder.SetInsertPoint(Copy, Copy->begin()); 1143 AllocaInst *NewLd = 1144 Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); 1145 Builder.CreateMemCpy(NewLd, NewLd->getAlign(), 1146 Load->getPointerOperand(), Load->getAlign(), 1147 LoadLoc.Size.getValue()); 1148 Builder.SetInsertPoint(Fusion, Fusion->begin()); 1149 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); 1150 PHI->addIncoming(Load->getPointerOperand(), Check0); 1151 PHI->addIncoming(Load->getPointerOperand(), Check1); 1152 PHI->addIncoming(NewLd, Copy); 1153 1154 // Adjust DT. 
1155 DTUpdates.push_back({DT.Insert, Check0, Check1}); 1156 DTUpdates.push_back({DT.Insert, Check0, Fusion}); 1157 DTUpdates.push_back({DT.Insert, Check1, Copy}); 1158 DTUpdates.push_back({DT.Insert, Check1, Fusion}); 1159 DT.applyUpdates(DTUpdates); 1160 return PHI; 1161 } 1162 1163 bool isFusionProfitable(CallInst *MatMul) { 1164 if (ForceFusion) 1165 return true; 1166 1167 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1168 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1169 1170 const unsigned R = LShape.NumRows; 1171 const unsigned C = RShape.NumColumns; 1172 const unsigned M = LShape.NumColumns; 1173 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1174 1175 const unsigned VF = 1176 std::max<unsigned>(TTI.getRegisterBitWidth(true) / 1177 EltType->getPrimitiveSizeInBits().getFixedSize(), 1178 1U); 1179 1180 // Cost model for tiling 1181 // 1182 // For tiling to be beneficial, we need reuse either along the R or 1183 // the C axis. We vectorize along the R axis so that means at least 1184 // 3 elements. 1185 // TODO: Also consider cost of copying if operands alias. 1186 if (R <= VF && C == 1) 1187 return false; 1188 // Then we need enough elements to exceed the number of vector 1189 // registers we have. Note that this is an oversimplification since 1190 // fusing also takes some extra loads which may exceed the number of 1191 // reloads necessary. 1192 unsigned Op0Regs = (R + VF - 1) / VF * M; 1193 unsigned Op1Regs = (M + VF - 1) / VF * C; 1194 return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); 1195 } 1196 1197 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { 1198 MatrixTy Res; 1199 auto *ColumType = FixedVectorType::get(EltType, R); 1200 for (unsigned I = 0; I < C; ++I) 1201 Res.addVector(ConstantAggregateZero::get(ColumType)); 1202 return Res; 1203 } 1204 1205 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, 1206 StoreInst *Store, 1207 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1208 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && 1209 "Tiling only supported for column-major matrixes at the moment!"); 1210 if (!isFusionProfitable(MatMul)) 1211 return; 1212 1213 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1214 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1215 1216 const unsigned R = LShape.NumRows; 1217 const unsigned C = RShape.NumColumns; 1218 const unsigned M = LShape.NumColumns; 1219 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1220 1221 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); 1222 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); 1223 Value *CPtr = Store->getPointerOperand(); 1224 1225 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1226 MatMul->hasAllowContract()); 1227 IRBuilder<> Builder(Store); 1228 for (unsigned J = 0; J < C; J += TileSize) 1229 for (unsigned I = 0; I < R; I += TileSize) { 1230 const unsigned TileR = std::min(R - I, unsigned(TileSize)); 1231 const unsigned TileC = std::min(C - J, unsigned(TileSize)); 1232 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); 1233 1234 for (unsigned K = 0; K < M; K += TileSize) { 1235 const unsigned TileM = std::min(M - K, unsigned(TileSize)); 1236 MatrixTy A = 1237 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), 1238 LShape, Builder.getInt64(I), Builder.getInt64(K), 1239 {TileR, TileM}, EltType, Builder); 1240 MatrixTy B = 1241 loadMatrix(BPtr, 
LoadOp1->getAlign(), LoadOp1->isVolatile(), 1242 RShape, Builder.getInt64(K), Builder.getInt64(J), 1243 {TileM, TileC}, EltType, Builder); 1244 emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); 1245 } 1246 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, 1247 Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); 1248 } 1249 1250 // Mark eliminated instructions as fused and remove them. 1251 FusedInsts.insert(Store); 1252 FusedInsts.insert(MatMul); 1253 Store->eraseFromParent(); 1254 MatMul->eraseFromParent(); 1255 if (LoadOp0->hasNUses(0)) { 1256 FusedInsts.insert(LoadOp0); 1257 LoadOp0->eraseFromParent(); 1258 } 1259 if (LoadOp1->hasNUses(0)) { 1260 FusedInsts.insert(LoadOp1); 1261 LoadOp1->eraseFromParent(); 1262 } 1263 } 1264 1265 /// Try to lower matrix multiply chains by fusing operations. 1266 /// 1267 /// Currently we only lower {ld, ld} -> matmul -> st chains. 1268 // 1269 /// No need to return a MatrixTy object for the result of the operation, since 1270 /// the single store user will be lowered as part of this. Instructions that 1271 /// are completely eliminated by fusion are added to \p FusedInsts. 1272 void LowerMatrixMultiplyFused(CallInst *MatMul, 1273 SmallPtrSetImpl<Instruction *> &FusedInsts) { 1274 if (!FuseMatrix || !MatMul->hasOneUse() || 1275 MatrixLayout != MatrixLayoutTy::ColumnMajor) 1276 return; 1277 1278 auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); 1279 auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); 1280 auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); 1281 if (LoadOp0 && LoadOp1 && Store) { 1282 // The store address must dominate the MatMul instruction, otherwise 1283 // we create invalid IR. 1284 // FIXME: See if we can hoist the store address computation. 1285 auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); 1286 if (AddrI && (!DT.dominates(AddrI, MatMul))) 1287 return; 1288 1289 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); 1290 return; 1291 } 1292 } 1293 1294 /// Lowers llvm.matrix.multiply. 1295 void LowerMultiply(CallInst *MatMul) { 1296 IRBuilder<> Builder(MatMul); 1297 auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); 1298 ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); 1299 ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); 1300 1301 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); 1302 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); 1303 1304 const unsigned R = LShape.NumRows; 1305 const unsigned C = RShape.NumColumns; 1306 assert(LShape.NumColumns == RShape.NumRows); 1307 1308 // Initialize the output 1309 MatrixTy Result(R, C, EltType); 1310 1311 bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && 1312 MatMul->hasAllowContract()); 1313 emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); 1314 finalizeLowering(MatMul, Result, Builder); 1315 } 1316 1317 /// Lowers llvm.matrix.transpose. 1318 void LowerTranspose(CallInst *Inst) { 1319 MatrixTy Result; 1320 IRBuilder<> Builder(Inst); 1321 Value *InputVal = Inst->getArgOperand(0); 1322 VectorType *VectorTy = cast<VectorType>(InputVal->getType()); 1323 ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); 1324 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); 1325 1326 const unsigned NewNumVecs = 1327 InputMatrix.isColumnMajor() ? 
ArgShape.NumRows : ArgShape.NumColumns; 1328 const unsigned NewNumElts = 1329 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; 1330 1331 for (unsigned I = 0; I < NewNumVecs; ++I) { 1332 // Build a single result vector. First initialize it. 1333 Value *ResultVector = UndefValue::get( 1334 FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); 1335 // Go through the old elements and insert it into the resulting vector. 1336 for (auto J : enumerate(InputMatrix.vectors())) { 1337 Value *Elt = Builder.CreateExtractElement(J.value(), I); 1338 // Row and column indices are transposed. 1339 ResultVector = 1340 Builder.CreateInsertElement(ResultVector, Elt, J.index()); 1341 } 1342 Result.addVector(ResultVector); 1343 } 1344 1345 // TODO: Improve estimate of operations needed for transposes. Currently we 1346 // just count the insertelement/extractelement instructions, but do not 1347 // account for later simplifications/combines. 1348 finalizeLowering( 1349 Inst, 1350 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), 1351 Builder); 1352 } 1353 1354 /// Lower load instructions, if shape information is available. 1355 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { 1356 auto I = ShapeMap.find(Inst); 1357 if (I == ShapeMap.end()) 1358 return false; 1359 1360 LowerLoad(Inst, Ptr, Inst->getAlign(), 1361 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1362 I->second); 1363 return true; 1364 } 1365 1366 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, 1367 IRBuilder<> &Builder) { 1368 auto I = ShapeMap.find(StoredVal); 1369 if (I == ShapeMap.end()) 1370 return false; 1371 1372 LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), 1373 Builder.getInt64(I->second.getStride()), Inst->isVolatile(), 1374 I->second); 1375 return true; 1376 } 1377 1378 /// Lower binary operators, if shape information is available. 1379 bool VisitBinaryOperator(BinaryOperator *Inst) { 1380 auto I = ShapeMap.find(Inst); 1381 if (I == ShapeMap.end()) 1382 return false; 1383 1384 Value *Lhs = Inst->getOperand(0); 1385 Value *Rhs = Inst->getOperand(1); 1386 1387 IRBuilder<> Builder(Inst); 1388 ShapeInfo &Shape = I->second; 1389 1390 MatrixTy Result; 1391 MatrixTy A = getMatrix(Lhs, Shape, Builder); 1392 MatrixTy B = getMatrix(Rhs, Shape, Builder); 1393 assert(A.isColumnMajor() == B.isColumnMajor() && 1394 Result.isColumnMajor() == A.isColumnMajor() && 1395 "operands must agree on matrix layout"); 1396 1397 // Helper to perform binary op on vectors. 1398 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { 1399 switch (Inst->getOpcode()) { 1400 case Instruction::Add: 1401 return Builder.CreateAdd(LHS, RHS); 1402 case Instruction::Mul: 1403 return Builder.CreateMul(LHS, RHS); 1404 case Instruction::Sub: 1405 return Builder.CreateSub(LHS, RHS); 1406 case Instruction::FAdd: 1407 return Builder.CreateFAdd(LHS, RHS); 1408 case Instruction::FMul: 1409 return Builder.CreateFMul(LHS, RHS); 1410 case Instruction::FSub: 1411 return Builder.CreateFSub(LHS, RHS); 1412 default: 1413 llvm_unreachable("Unsupported binary operator for matrix"); 1414 } 1415 }; 1416 1417 for (unsigned I = 0; I < Shape.getNumVectors(); ++I) 1418 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); 1419 1420 finalizeLowering(Inst, 1421 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * 1422 Result.getNumVectors()), 1423 Builder); 1424 return true; 1425 } 1426 1427 /// Helper to linearize a matrix expression tree into a string. 
Currently 1428 /// matrix expressions are linarized by starting at an expression leaf and 1429 /// linearizing bottom up. 1430 struct ExprLinearizer { 1431 unsigned LengthToBreak = 100; 1432 std::string Str; 1433 raw_string_ostream Stream; 1434 unsigned LineLength = 0; 1435 const DataLayout &DL; 1436 1437 /// Mapping from instructions to matrixes. It is used to identify 1438 /// matrix instructions. 1439 const MapVector<Value *, MatrixTy> &Inst2Matrix; 1440 1441 /// Mapping from values to the leaves of all expressions that the value is 1442 /// part of. 1443 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; 1444 1445 /// Set of matrix expressions in the scope of a given DISubprogram. 1446 const SmallSetVector<Value *, 32> &ExprsInSubprogram; 1447 1448 /// Leaf node of the expression to linearize. 1449 Value *Leaf; 1450 1451 /// Used to keep track of sub-expressions that get reused while linearizing 1452 /// the expression. Re-used sub-expressions are marked as (reused). 1453 SmallPtrSet<Value *, 8> ReusedExprs; 1454 1455 ExprLinearizer(const DataLayout &DL, 1456 const MapVector<Value *, MatrixTy> &Inst2Matrix, 1457 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, 1458 const SmallSetVector<Value *, 32> &ExprsInSubprogram, 1459 Value *Leaf) 1460 : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), 1461 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} 1462 1463 void indent(unsigned N) { 1464 LineLength += N; 1465 for (unsigned i = 0; i < N; i++) 1466 Stream << " "; 1467 } 1468 1469 void lineBreak() { 1470 Stream << "\n"; 1471 LineLength = 0; 1472 } 1473 1474 void maybeIndent(unsigned Indent) { 1475 if (LineLength >= LengthToBreak) 1476 lineBreak(); 1477 1478 if (LineLength == 0) 1479 indent(Indent); 1480 } 1481 1482 void write(StringRef S) { 1483 LineLength += S.size(); 1484 Stream << S; 1485 } 1486 1487 Value *getUnderlyingObjectThroughLoads(Value *V) { 1488 if (Value *Ptr = getPointerOperand(V)) 1489 return getUnderlyingObjectThroughLoads(Ptr); 1490 else if (V->getType()->isPointerTy()) 1491 return GetUnderlyingObject(V, DL); 1492 return V; 1493 } 1494 1495 /// Returns true if \p V is a matrix value in the given subprogram. 1496 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } 1497 1498 /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to 1499 /// \p SS. 1500 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { 1501 auto M = Inst2Matrix.find(V); 1502 if (M == Inst2Matrix.end()) 1503 SS << "unknown"; 1504 else { 1505 SS << M->second.getNumRows(); 1506 SS << "x"; 1507 SS << M->second.getNumColumns(); 1508 } 1509 } 1510 1511 /// Write the called function name. Handles calls to llvm.matrix.* 1512 /// specially: we write the name, followed by the dimensions of the input 1513 /// matrixes, followed by the scalar type name. 
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to a (function-)external address or a stack address; for other values
    /// we print either the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
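      // For illustration only (the line/column numbers are made up): a subtree
      // that also feeds another expression leaf is annotated in the output as
      //   shared with remark at line 20 column 5 (
      // while a node visited a second time within the same expression is
      // prefixed with "(reused) ".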
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
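    /// For example (illustrative), a lowered llvm.matrix.column.major.store or
    /// a store of a flattened matrix value is a leaf, while a transpose whose
    /// result feeds a multiply in the same subprogram is not.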
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V, starting at leaf \p Leaf, and add
    /// \p Leaf to all visited expressions in \p Shared. Limit the matrix
    /// operations to the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
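        //
        // For illustration only (counts, names and source locations are made
        // up), a remark emitted for a store leaf of a transposed 2x6 double
        // matrix may read roughly:
        //
        //   Lowered with 2 stores, 6 loads, 24 compute ops
        //   store(
        //    transpose.2x6.double(load(addr %A)),
        //    addr %B)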
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
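
// For illustration only, assuming the pass names implied by the registrations
// above and an input file name of in.ll: the lowering can typically be
// exercised directly via opt, e.g.
//   opt -passes=lower-matrix-intrinsics -S in.ll     (new pass manager)
//   opt -lower-matrix-intrinsics -S in.ll            (legacy pass manager)
// and the remarks generated above can usually be requested with
// -pass-remarks=lower-matrix-intrinsics.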