1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to transform amx intrinsics to scalar operations. 10 /// This pass is always enabled and it skips when it is not -O0 and has no 11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx 12 /// intrinsics is near the amx intrinsics code. We are not able to find a 13 /// point which post-dominate all the shape and dominate all amx intrinsics. 14 /// To decouple the dependency of the shape, we transform amx intrinsics 15 /// to scalar operation, so that compiling doesn't fail. In long term, we 16 /// should improve fast register allocation to allocate amx register. 17 //===----------------------------------------------------------------------===// 18 // 19 #include "X86.h" 20 #include "llvm/ADT/DenseSet.h" 21 #include "llvm/ADT/PostOrderIterator.h" 22 #include "llvm/Analysis/DomTreeUpdater.h" 23 #include "llvm/Analysis/LoopInfo.h" 24 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 25 #include "llvm/Analysis/TargetTransformInfo.h" 26 #include "llvm/CodeGen/Passes.h" 27 #include "llvm/CodeGen/TargetPassConfig.h" 28 #include "llvm/CodeGen/ValueTypes.h" 29 #include "llvm/IR/DataLayout.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/IRBuilder.h" 32 #include "llvm/IR/Instructions.h" 33 #include "llvm/IR/IntrinsicInst.h" 34 #include "llvm/IR/IntrinsicsX86.h" 35 #include "llvm/IR/PatternMatch.h" 36 #include "llvm/InitializePasses.h" 37 #include "llvm/Pass.h" 38 #include "llvm/Support/CommandLine.h" 39 #include "llvm/Target/TargetMachine.h" 40 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 41 #include "llvm/Transforms/Utils/LoopUtils.h" 42 43 using namespace llvm; 44 using namespace PatternMatch; 45 46 #define DEBUG_TYPE "lower-amx-intrinsics" 47 48 #ifndef NDEBUG 49 static bool isV256I32Ty(Type *Ty) { 50 if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) 51 return FVT->getNumElements() == 256 && 52 FVT->getElementType()->isIntegerTy(32); 53 return false; 54 } 55 #endif 56 57 static cl::opt<bool> 58 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, 59 cl::desc("X86: enable AMX scalarizition.")); 60 61 namespace { 62 class X86LowerAMXIntrinsics { 63 Function &Func; 64 65 public: 66 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) 67 : Func(F), DTU(DomTU), LI(LoopI) {} 68 bool visit(); 69 70 private: 71 DomTreeUpdater &DTU; 72 LoopInfo *LI; 73 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound, 74 Value *Step, StringRef Name, IRBuilderBase &B, 75 Loop *L); 76 template <bool IsTileLoad> 77 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, 78 IRBuilderBase &B, Value *Row, Value *Col, 79 Value *Ptr, Value *Stride, Value *Tile); 80 template <Intrinsic::ID IntrID> 81 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 82 IntrID == Intrinsic::x86_tdpbsud_internal || 83 IntrID == Intrinsic::x86_tdpbusd_internal || 84 IntrID == Intrinsic::x86_tdpbuud_internal || 85 IntrID == Intrinsic::x86_tdpbf16ps_internal, 86 Value *> 87 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, 88 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, 89 Value *RHS); 90 template <bool IsTileLoad> 91 bool lowerTileLoadStore(Instruction *TileLoadStore); 92 template <Intrinsic::ID IntrID> 93 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 94 IntrID == Intrinsic::x86_tdpbsud_internal || 95 IntrID == Intrinsic::x86_tdpbusd_internal || 96 IntrID == Intrinsic::x86_tdpbuud_internal || 97 IntrID == Intrinsic::x86_tdpbf16ps_internal, 98 bool> 99 lowerTileDP(Instruction *TileDP); 100 bool lowerTileZero(Instruction *TileZero); 101 }; 102 } // anonymous namespace 103 104 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader, 105 BasicBlock *Exit, Value *Bound, 106 Value *Step, StringRef Name, 107 IRBuilderBase &B, Loop *L) { 108 LLVMContext &Ctx = Preheader->getContext(); 109 BasicBlock *Header = 110 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit); 111 BasicBlock *Body = 112 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit); 113 BasicBlock *Latch = 114 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit); 115 116 Type *I16Ty = Type::getInt16Ty(Ctx); 117 BranchInst::Create(Body, Header); 118 BranchInst::Create(Latch, Body); 119 PHINode *IV = 120 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()); 121 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader); 122 123 B.SetInsertPoint(Latch); 124 Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 125 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 126 BranchInst::Create(Header, Exit, Cond, Latch); 127 IV->addIncoming(Inc, Latch); 128 129 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 130 BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 131 PreheaderBr->setSuccessor(0, Header); 132 DTU.applyUpdatesPermissive({ 133 {DominatorTree::Delete, Preheader, Tmp}, 134 {DominatorTree::Insert, Header, Body}, 135 {DominatorTree::Insert, Body, Latch}, 136 {DominatorTree::Insert, Latch, Header}, 137 {DominatorTree::Insert, Latch, Exit}, 138 {DominatorTree::Insert, Preheader, Header}, 139 }); 140 if (LI) { 141 L->addBasicBlockToLoop(Header, *LI); 142 L->addBasicBlockToLoop(Body, *LI); 143 L->addBasicBlockToLoop(Latch, *LI); 144 } 145 return Body; 146 } 147 148 template <bool IsTileLoad> 149 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( 150 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, 151 Value *Col, Value *Ptr, Value *Stride, Value *Tile) { 152 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore"; 153 Loop *RowLoop = nullptr; 154 Loop *ColLoop = nullptr; 155 if (LI) { 156 RowLoop = LI->AllocateLoop(); 157 ColLoop = LI->AllocateLoop(); 158 RowLoop->addChildLoop(ColLoop); 159 if (Loop *ParentL = LI->getLoopFor(Start)) 160 ParentL->addChildLoop(RowLoop); 161 else 162 LI->addTopLevelLoop(RowLoop); 163 } 164 165 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 166 IntrinName + ".scalarize.rows", B, RowLoop); 167 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 168 169 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 170 IntrinName + ".scalarize.cols", B, ColLoop); 171 172 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 173 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 174 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 175 Value *CurrentRow = &*RowLoopHeader->begin(); 176 Value *CurrentCol = &*ColLoopHeader->begin(); 177 Type *EltTy = B.getInt32Ty(); 178 FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256); 179 180 // Common part for tileload and tilestore 181 // *.scalarize.cols.body: 182 // Calculate %idxmem and %idxvec 183 B.SetInsertPoint(ColBody->getTerminator()); 184 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType()); 185 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType()); 186 Value *Offset = 187 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt); 188 unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace(); 189 Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS)); 190 Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset); 191 Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 192 if (IsTileLoad) { 193 // tileload.scalarize.rows.header: 194 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, 195 // %tileload.scalarize.rows.latch ] 196 B.SetInsertPoint(RowLoopHeader->getTerminator()); 197 Value *VecZero = Constant::getNullValue(V256I32Ty); 198 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row"); 199 VecCPhiRowLoop->addIncoming(VecZero, Start); 200 201 // tileload.scalarize.cols.header: 202 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body 203 // ], [ %ResVec, %tileload.scalarize.cols.latch ] 204 B.SetInsertPoint(ColLoopHeader->getTerminator()); 205 PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi"); 206 VecPhi->addIncoming(VecCPhiRowLoop, RowBody); 207 208 // tileload.scalarize.cols.body: 209 // Calculate %idxmem and %idxvec 210 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem 211 // %elt = load i32, i32* %ptr 212 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec 213 B.SetInsertPoint(ColBody->getTerminator()); 214 Value *Elt = B.CreateLoad(EltTy, EltPtr); 215 Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx); 216 VecPhi->addIncoming(ResVec, ColLoopLatch); 217 VecCPhiRowLoop->addIncoming(ResVec, RowLatch); 218 219 return ResVec; 220 } else { 221 auto *BitCast = cast<BitCastInst>(Tile); 222 Value *Vec = BitCast->getOperand(0); 223 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); 224 // tilestore.scalarize.cols.body: 225 // %mul = mul i16 %row.iv, i16 16 226 // %idx = add i16 %mul, i16 %col.iv 227 // %vec = extractelement <16 x i32> %vec, i16 %idx 228 // store i32 %vec, i32* %ptr 229 B.SetInsertPoint(ColBody->getTerminator()); 230 Value *Elt = B.CreateExtractElement(Vec, Idx); 231 232 B.CreateStore(Elt, EltPtr); 233 return nullptr; 234 } 235 } 236 237 template <Intrinsic::ID IntrID> 238 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 239 IntrID == Intrinsic::x86_tdpbsud_internal || 240 IntrID == Intrinsic::x86_tdpbusd_internal || 241 IntrID == Intrinsic::x86_tdpbuud_internal || 242 IntrID == Intrinsic::x86_tdpbf16ps_internal, 243 Value *> 244 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, 245 IRBuilderBase &B, Value *Row, 246 Value *Col, Value *K, Value *Acc, 247 Value *LHS, Value *RHS) { 248 std::string IntrinName; 249 switch (IntrID) { 250 case Intrinsic::x86_tdpbssd_internal: 251 IntrinName = "tiledpbssd"; 252 break; 253 case Intrinsic::x86_tdpbsud_internal: 254 IntrinName = "tiledpbsud"; 255 break; 256 case Intrinsic::x86_tdpbusd_internal: 257 IntrinName = "tiledpbusd"; 258 break; 259 case Intrinsic::x86_tdpbuud_internal: 260 IntrinName = "tiledpbuud"; 261 break; 262 case Intrinsic::x86_tdpbf16ps_internal: 263 IntrinName = "tiledpbf16ps"; 264 break; 265 } 266 Loop *RowLoop = nullptr; 267 Loop *ColLoop = nullptr; 268 Loop *InnerLoop = nullptr; 269 if (LI) { 270 RowLoop = LI->AllocateLoop(); 271 ColLoop = LI->AllocateLoop(); 272 InnerLoop = LI->AllocateLoop(); 273 ColLoop->addChildLoop(InnerLoop); 274 RowLoop->addChildLoop(ColLoop); 275 if (Loop *ParentL = LI->getLoopFor(Start)) 276 ParentL->addChildLoop(RowLoop); 277 else 278 LI->addTopLevelLoop(RowLoop); 279 } 280 281 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 282 IntrinName + ".scalarize.rows", B, RowLoop); 283 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 284 285 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 286 IntrinName + ".scalarize.cols", B, ColLoop); 287 288 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 289 290 B.SetInsertPoint(ColBody->getTerminator()); 291 BasicBlock *InnerBody = 292 createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), 293 IntrinName + ".scalarize.inner", B, InnerLoop); 294 295 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 296 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 297 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor(); 298 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); 299 Value *CurrentRow = &*RowLoopHeader->begin(); 300 Value *CurrentCol = &*ColLoopHeader->begin(); 301 Value *CurrentInner = &*InnerLoopHeader->begin(); 302 303 FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256); 304 auto *BitCastAcc = cast<BitCastInst>(Acc); 305 Value *VecC = BitCastAcc->getOperand(0); 306 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx"); 307 // TODO else create BitCast from x86amx to v256i32. 308 // Store x86amx to memory, and reload from memory 309 // to vector. However with -O0, it doesn't happen. 310 auto *BitCastLHS = cast<BitCastInst>(LHS); 311 Value *VecA = BitCastLHS->getOperand(0); 312 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); 313 auto *BitCastRHS = cast<BitCastInst>(RHS); 314 Value *VecB = BitCastRHS->getOperand(0); 315 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); 316 317 // tiledpbssd.scalarize.rows.header: 318 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, 319 // %tiledpbssd.scalarize.rows.latch ] 320 321 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ 322 // %NewVecD, %tiledpbssd.scalarize.rows.latch ] 323 B.SetInsertPoint(RowLoopHeader->getTerminator()); 324 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row"); 325 VecCPhiRowLoop->addIncoming(VecC, Start); 326 Value *VecZero = Constant::getNullValue(V256I32Ty); 327 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row"); 328 VecDPhiRowLoop->addIncoming(VecZero, Start); 329 330 // tiledpbssd.scalarize.cols.header: 331 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, 332 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, 333 // %tiledpbssd.scalarize.cols.latch ] 334 335 // %vec.d.phi.col = phi <256 x i32> [ 336 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, 337 // %tiledpbssd.scalarize.cols.latch ] 338 339 // calculate idxc. 340 B.SetInsertPoint(ColLoopHeader->getTerminator()); 341 PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col"); 342 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody); 343 PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col"); 344 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody); 345 Value *IdxC = 346 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 347 348 // tiledpbssd.scalarize.inner.header: 349 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, 350 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, 351 // %tiledpbssd.scalarize.inner.latch ] 352 353 B.SetInsertPoint(InnerLoopHeader->getTerminator()); 354 PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); 355 VecCPhi->addIncoming(VecCPhiColLoop, ColBody); 356 357 B.SetInsertPoint(InnerBody->getTerminator()); 358 Value *IdxA = 359 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); 360 Value *IdxB = 361 B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); 362 Value *NewVecC = nullptr; 363 364 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { 365 // tiledpbssd.scalarize.inner.body: 366 // calculate idxa, idxb 367 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 368 // %elta = extractelement <256 x i32> %veca, i16 %idxa 369 // %eltav4i8 = bitcast i32 %elta to <4 x i8> 370 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 371 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> 372 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> 373 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> 374 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 375 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) 376 // %neweltc = add i32 %elt, %acc 377 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 378 // i16 %idxc 379 FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); 380 FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); 381 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 382 Value *EltA = B.CreateExtractElement(VecA, IdxA); 383 Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); 384 Value *EltB = B.CreateExtractElement(VecB, IdxB); 385 Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); 386 Value *SEXTSubVecB = nullptr; 387 Value *SEXTSubVecA = nullptr; 388 switch (IntrID) { 389 case Intrinsic::x86_tdpbssd_internal: 390 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 391 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 392 break; 393 case Intrinsic::x86_tdpbsud_internal: 394 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 395 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 396 break; 397 case Intrinsic::x86_tdpbusd_internal: 398 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 399 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 400 break; 401 case Intrinsic::x86_tdpbuud_internal: 402 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 403 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 404 break; 405 default: 406 llvm_unreachable("Invalid intrinsic ID!"); 407 } 408 Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); 409 Value *ResElt = B.CreateAdd(EltC, SubVecR); 410 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 411 } else { 412 // tiledpbf16ps.scalarize.inner.body: 413 // calculate idxa, idxb, idxc 414 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 415 // %eltcf32 = bitcast i32 %eltc to float 416 // %elta = extractelement <256 x i32> %veca, i16 %idxa 417 // %eltav2i16 = bitcast i32 %elta to <2 x i16> 418 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 419 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> 420 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 421 // x i32> <i32 2, i32 0, i32 3, i32 1> 422 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> 423 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x 424 // i32> <i32 2, i32 0, i32 3, i32 1> 425 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> 426 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 427 // %acc = call float 428 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) 429 // %neweltc = bitcast float %acc to i32 430 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 431 // i16 %idxc 432 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, 433 // i16 %idxc 434 FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); 435 FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); 436 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 437 Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); 438 Value *EltA = B.CreateExtractElement(VecA, IdxA); 439 Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); 440 Value *EltB = B.CreateExtractElement(VecB, IdxB); 441 Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); 442 Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); 443 int ShuffleMask[4] = {2, 0, 3, 1}; 444 auto ShuffleArray = ArrayRef(ShuffleMask); 445 Value *AV2F32 = B.CreateBitCast( 446 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); 447 Value *BV2F32 = B.CreateBitCast( 448 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); 449 Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); 450 Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); 451 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 452 } 453 454 // tiledpbssd.scalarize.cols.latch: 455 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc 456 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, 457 // i16 %idxc 458 B.SetInsertPoint(ColLoopLatch->getTerminator()); 459 Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC); 460 Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC); 461 462 VecCPhi->addIncoming(NewVecC, InnerLoopLatch); 463 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch); 464 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch); 465 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch); 466 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch); 467 468 return NewVecD; 469 } 470 471 template <Intrinsic::ID IntrID> 472 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 473 IntrID == Intrinsic::x86_tdpbsud_internal || 474 IntrID == Intrinsic::x86_tdpbusd_internal || 475 IntrID == Intrinsic::x86_tdpbuud_internal || 476 IntrID == Intrinsic::x86_tdpbf16ps_internal, 477 bool> 478 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { 479 Value *M, *N, *K, *C, *A, *B; 480 match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), 481 m_Value(C), m_Value(A), m_Value(B))); 482 Instruction *InsertI = TileDP; 483 IRBuilder<> PreBuilder(TileDP); 484 PreBuilder.SetInsertPoint(TileDP); 485 // We visit the loop with (m, n/4, k/4): 486 // %n_dword = lshr i16 %n, 2 487 // %k_dword = lshr i16 %k, 2 488 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 489 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2)); 490 BasicBlock *Start = InsertI->getParent(); 491 BasicBlock *End = 492 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 493 IRBuilder<> Builder(TileDP); 494 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, 495 KDWord, C, A, B); 496 // we cannot assume there always be bitcast after tiledpbssd. So we need to 497 // insert one bitcast as required 498 Builder.SetInsertPoint(End->getFirstNonPHI()); 499 Value *ResAMX = 500 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 501 // Delete TileDP intrinsic and do some clean-up. 502 for (Use &U : llvm::make_early_inc_range(TileDP->uses())) { 503 Instruction *I = cast<Instruction>(U.getUser()); 504 Value *Vec; 505 if (match(I, m_BitCast(m_Value(Vec)))) { 506 I->replaceAllUsesWith(ResVec); 507 I->eraseFromParent(); 508 } 509 } 510 TileDP->replaceAllUsesWith(ResAMX); 511 TileDP->eraseFromParent(); 512 return true; 513 } 514 515 template <bool IsTileLoad> 516 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { 517 Value *M, *N, *Ptr, *Stride, *Tile; 518 if (IsTileLoad) 519 match(TileLoadStore, 520 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( 521 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride))); 522 else 523 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>( 524 m_Value(M), m_Value(N), m_Value(Ptr), 525 m_Value(Stride), m_Value(Tile))); 526 527 Instruction *InsertI = TileLoadStore; 528 IRBuilder<> PreBuilder(TileLoadStore); 529 PreBuilder.SetInsertPoint(TileLoadStore); 530 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 531 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2)); 532 BasicBlock *Start = InsertI->getParent(); 533 BasicBlock *End = 534 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 535 IRBuilder<> Builder(TileLoadStore); 536 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( 537 Start, End, Builder, M, NDWord, Ptr, StrideDWord, 538 IsTileLoad ? nullptr : Tile); 539 if (IsTileLoad) { 540 // we cannot assume there always be bitcast after tileload. So we need to 541 // insert one bitcast as required 542 Builder.SetInsertPoint(End->getFirstNonPHI()); 543 Value *ResAMX = 544 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 545 // Delete tileloadd6 intrinsic and do some clean-up 546 for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) { 547 Instruction *I = cast<Instruction>(U.getUser()); 548 Value *Vec; 549 if (match(I, m_BitCast(m_Value(Vec)))) { 550 I->replaceAllUsesWith(ResVec); 551 I->eraseFromParent(); 552 } 553 } 554 TileLoadStore->replaceAllUsesWith(ResAMX); 555 } 556 TileLoadStore->eraseFromParent(); 557 return true; 558 } 559 560 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { 561 IRBuilder<> Builder(TileZero); 562 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256); 563 Value *VecZero = Constant::getNullValue(V256I32Ty); 564 for (Use &U : llvm::make_early_inc_range(TileZero->uses())) { 565 Instruction *I = cast<Instruction>(U.getUser()); 566 Value *Vec; 567 if (match(I, m_BitCast(m_Value(Vec)))) { 568 I->replaceAllUsesWith(VecZero); 569 I->eraseFromParent(); 570 } 571 } 572 TileZero->eraseFromParent(); 573 return true; 574 } 575 576 bool X86LowerAMXIntrinsics::visit() { 577 bool C = false; 578 SmallVector<IntrinsicInst *, 8> WorkList; 579 for (BasicBlock *BB : depth_first(&Func)) { 580 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { 581 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) { 582 switch (Inst->getIntrinsicID()) { 583 case Intrinsic::x86_tdpbssd_internal: 584 case Intrinsic::x86_tdpbsud_internal: 585 case Intrinsic::x86_tdpbusd_internal: 586 case Intrinsic::x86_tdpbuud_internal: 587 case Intrinsic::x86_tileloadd64_internal: 588 case Intrinsic::x86_tilestored64_internal: 589 case Intrinsic::x86_tilezero_internal: 590 case Intrinsic::x86_tdpbf16ps_internal: 591 WorkList.push_back(Inst); 592 break; 593 default: 594 break; 595 } 596 } 597 } 598 } 599 600 for (auto *Inst : WorkList) { 601 switch (Inst->getIntrinsicID()) { 602 case Intrinsic::x86_tdpbssd_internal: 603 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; 604 break; 605 case Intrinsic::x86_tdpbsud_internal: 606 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C; 607 break; 608 case Intrinsic::x86_tdpbusd_internal: 609 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C; 610 break; 611 case Intrinsic::x86_tdpbuud_internal: 612 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C; 613 break; 614 case Intrinsic::x86_tdpbf16ps_internal: 615 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; 616 break; 617 case Intrinsic::x86_tileloadd64_internal: 618 C = lowerTileLoadStore<true>(Inst) || C; 619 break; 620 case Intrinsic::x86_tilestored64_internal: 621 C = lowerTileLoadStore<false>(Inst) || C; 622 break; 623 case Intrinsic::x86_tilezero_internal: 624 C = lowerTileZero(Inst) || C; 625 break; 626 default: 627 llvm_unreachable("invalid amx intrinsics!"); 628 } 629 } 630 631 return C; 632 } 633 634 namespace { 635 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { 636 public: 637 static char ID; 638 639 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) { 640 initializeX86LowerAMXIntrinsicsLegacyPassPass( 641 *PassRegistry::getPassRegistry()); 642 } 643 644 bool runOnFunction(Function &F) override { 645 if (!X86ScalarizeAMX) 646 return false; 647 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 648 if (!F.hasFnAttribute(Attribute::OptimizeNone) && 649 TM->getOptLevel() != CodeGenOpt::None) 650 return false; 651 652 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); 653 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; 654 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); 655 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; 656 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 657 658 X86LowerAMXIntrinsics LAT(F, DTU, LI); 659 return LAT.visit(); 660 } 661 StringRef getPassName() const override { return "Lower AMX intrinsics"; } 662 663 void getAnalysisUsage(AnalysisUsage &AU) const override { 664 AU.addPreserved<DominatorTreeWrapperPass>(); 665 AU.addPreserved<LoopInfoWrapperPass>(); 666 AU.addRequired<TargetPassConfig>(); 667 } 668 }; 669 } // namespace 670 671 static const char PassName[] = "Lower AMX intrinsics"; 672 char X86LowerAMXIntrinsicsLegacyPass::ID = 0; 673 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 674 false, false) 675 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 676 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 677 false, false) 678 679 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { 680 return new X86LowerAMXIntrinsicsLegacyPass(); 681 } 682