xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp (revision 5e801ac66d24704442eba426ed13c3effb8a34e7)
1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to transform AMX intrinsics to scalar operations.
/// This pass only does work when -enable-x86-scalar-amx is set, and it skips
/// any function that is neither compiled at -O0 nor marked optnone. With -O0
/// or the optnone attribute, the definitions of the shapes used by the AMX
/// intrinsics sit right next to the intrinsics themselves, so we are not able
/// to find a point which post-dominates all the shape definitions and
/// dominates all AMX intrinsics. To decouple the dependency on the shape, we
/// transform the AMX intrinsics to scalar operations, so that compilation
/// doesn't fail. In the long term, we should improve fast register allocation
/// to allocate AMX registers.
17 //===----------------------------------------------------------------------===//
18 //
19 #include "X86.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/Analysis/DomTreeUpdater.h"
23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27 #include "llvm/CodeGen/ValueTypes.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/IntrinsicsX86.h"
34 #include "llvm/IR/PatternMatch.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Target/TargetMachine.h"
39 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
40 #include "llvm/Transforms/Utils/LoopUtils.h"
41 
42 using namespace llvm;
43 using namespace PatternMatch;
44 
45 #define DEBUG_TYPE "lower-amx-intrinsics"
46 
47 #ifndef NDEBUG
48 static bool isV256I32Ty(Type *Ty) {
49   if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
50     return FVT->getNumElements() == 256 &&
51            FVT->getElementType()->isIntegerTy(32);
52   return false;
53 }
54 #endif
55 
56 static cl::opt<bool>
57     X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
58                     cl::desc("X86: enable AMX scalarizition."));
59 
60 namespace {
/// Scalarizes the AMX tile intrinsics of a single function. Each handled
/// intrinsic is replaced by explicit loops (or, for tilezero, a zero constant)
/// that work on <256 x i32> vectors.
class X86LowerAMXIntrinsics {
  // Function whose intrinsics are lowered by visit().
  Function &Func;

public:
  /// \p LoopI may be null; loop info is only updated when it is provided.
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  /// Collect every handled AMX intrinsic in the function and lower it.
  /// Returns true if at least one lowering reported a change.
  bool visit();

private:
  // Receives the CFG edge updates made while loops are being built.
  DomTreeUpdater &DTU;
  // Optional loop info; every use is guarded by a null check.
  LoopInfo *LI;
  /// Insert a header/body/latch loop between \p Preheader and \p Exit whose
  /// i16 induction variable starts at 0 and is advanced by \p Step until it
  /// equals \p Bound. Returns the body block. \p L is only used when LI is
  /// non-null.
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  /// Build the row/column loop nest for a tile load or tile store. Returns
  /// the loaded <256 x i32> vector for a load and nullptr for a store.
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  /// Build the row/column/inner loop nest for one of the tile dot-product
  /// intrinsics and return the resulting <256 x i32> vector. Only
  /// instantiable for the dot-product intrinsic IDs listed below.
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbsud_internal ||
                              IntrID == Intrinsic::x86_tdpbusd_internal ||
                              IntrID == Intrinsic::x86_tdpbuud_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          Value *>::type
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  /// Replace a tileloadd64/tilestored64 intrinsic call with scalar loops and
  /// erase the call. Always returns true.
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  /// Replace a tile dot-product intrinsic call with scalar loops and erase
  /// the call. Always returns true.
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbsud_internal ||
                              IntrID == Intrinsic::x86_tdpbusd_internal ||
                              IntrID == Intrinsic::x86_tdpbuud_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          bool>::type
  lowerTileDP(Instruction *TileDP);
  /// Replace the bitcast users of a tilezero intrinsic call with a zero
  /// vector and erase the call. Always returns true.
  bool lowerTileZero(Instruction *TileZero);
};
101 } // anonymous namespace
102 
103 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
104                                               BasicBlock *Exit, Value *Bound,
105                                               Value *Step, StringRef Name,
106                                               IRBuilderBase &B, Loop *L) {
107   LLVMContext &Ctx = Preheader->getContext();
108   BasicBlock *Header =
109       BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
110   BasicBlock *Body =
111       BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
112   BasicBlock *Latch =
113       BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
114 
115   Type *I16Ty = Type::getInt16Ty(Ctx);
116   BranchInst::Create(Body, Header);
117   BranchInst::Create(Latch, Body);
118   PHINode *IV =
119       PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
120   IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
121 
122   B.SetInsertPoint(Latch);
123   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
124   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
125   BranchInst::Create(Header, Exit, Cond, Latch);
126   IV->addIncoming(Inc, Latch);
127 
128   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
129   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
130   PreheaderBr->setSuccessor(0, Header);
131   DTU.applyUpdatesPermissive({
132       {DominatorTree::Delete, Preheader, Tmp},
133       {DominatorTree::Insert, Header, Body},
134       {DominatorTree::Insert, Body, Latch},
135       {DominatorTree::Insert, Latch, Header},
136       {DominatorTree::Insert, Latch, Exit},
137       {DominatorTree::Insert, Preheader, Header},
138   });
139   if (LI) {
140     L->addBasicBlockToLoop(Header, *LI);
141     L->addBasicBlockToLoop(Body, *LI);
142     L->addBasicBlockToLoop(Latch, *LI);
143   }
144   return Body;
145 }
146 
/// Emit the two-level (rows, then columns) loop nest that scalarizes a tile
/// load or a tile store between the blocks \p Start and \p End.
///
/// \p Row and \p Col are the i16 loop bounds and \p Stride scales the row
/// index when addressing memory through \p Ptr, which is reinterpreted as a
/// pointer to i32. Vector elements are addressed as row * 16 + column within
/// a <256 x i32> value.
///
/// For a load, every visited element is read from memory and inserted into a
/// vector that starts out as zero; the final insertelement is returned.
/// For a store, \p Tile must be a bitcast instruction whose operand is the
/// <256 x i32> source vector; elements are extracted from it and written to
/// memory, and nullptr is returned.
///
/// When loop info is available the two new loops are registered, nested in
/// the loop containing \p Start if there is one.
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    // Register the column loop inside the row loop, and the row loop inside
    // whatever loop already contains Start (or at top level).
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  // The column loop lives between the row loop's body and its latch.
  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  // The first instruction of each header is the induction variable PHI that
  // createLoop inserted.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  // Memory offset (in i32 elements): row * stride + column.
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
  // Vector index: row * 16 + column.
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %eltptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    // Feed the updated vector back around both loops.
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    // The stored tile is expected to be a bitcast from the vector form; the
    // cast<> below asserts if it is not.
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %elt = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %elt, i32* %eltptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
235 
/// Emit the three-level (rows, columns, inner) loop nest that scalarizes a
/// tile dot-product intrinsic between the blocks \p Start and \p End.
///
/// \p Row, \p Col and \p K are the i16 bounds of the three loops. \p Acc,
/// \p LHS and \p RHS must each be a bitcast instruction whose operand is a
/// <256 x i32> vector (the cast<> calls below assert otherwise). Elements are
/// addressed as row * 16 + column within those vectors.
///
/// For the integer variants, each i32 element of the two inputs is split into
/// four i8 lanes that are sign- or zero-extended according to \p IntrID,
/// multiplied lane-wise, summed, and added to the accumulator element. For
/// the bf16 variant, each i32 element is split into two 16-bit halves that
/// are widened to float by placing them in the upper half of a 32-bit lane,
/// multiplied, and reduced with the accumulator element as the start value.
///
/// \returns a <256 x i32> vector that starts as zero and receives, for every
/// (row, column) visited, the final accumulated element.
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbsud_internal ||
                            IntrID == Intrinsic::x86_tdpbusd_internal ||
                            IntrID == Intrinsic::x86_tdpbuud_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        Value *>::type
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  // Prefix used for the names of the generated blocks and values.
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    // Register the nest inner -> cols -> rows, and hang the row loop off the
    // loop that already contains Start (or make it a top-level loop).
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  // The column loop lives between the row loop's body and its latch.
  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  // The inner (K) loop lives between the column loop's body and its latch.
  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  // The first instruction of each header is the induction variable PHI that
  // createLoop inserted.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  // idxc = row * 16 + col
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  // idxa = row * 16 + inner
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  // idxb = inner * 16 + col
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    // Despite the "SEXT" in their names, these hold either the sign- or the
    // zero-extended lanes, depending on the intrinsic variant.
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      // A sign-extended, B sign-extended.
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      // A sign-extended, B zero-extended.
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      // A zero-extended, B sign-extended.
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      // A zero-extended, B zero-extended.
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer, <4 x
    // i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    // Interleave a zero i16 (mask indices 2 and 3) in front of each input
    // i16 (mask indices 0 and 1), producing <0, x0, 0, x1>.
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = makeArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %NewVecC, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  // Close the loop-carried PHIs now that the values from the latches exist.
  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}
469 
470 template <Intrinsic::ID IntrID>
471 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
472                             IntrID == Intrinsic::x86_tdpbsud_internal ||
473                             IntrID == Intrinsic::x86_tdpbusd_internal ||
474                             IntrID == Intrinsic::x86_tdpbuud_internal ||
475                             IntrID == Intrinsic::x86_tdpbf16ps_internal,
476                         bool>::type
477 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
478   Value *M, *N, *K, *C, *A, *B;
479   match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
480                                     m_Value(C), m_Value(A), m_Value(B)));
481   Instruction *InsertI = TileDP;
482   IRBuilder<> PreBuilder(TileDP);
483   PreBuilder.SetInsertPoint(TileDP);
484   // We visit the loop with (m, n/4, k/4):
485   // %n_dword = lshr i16 %n, 2
486   // %k_dword = lshr i16 %k, 2
487   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
488   Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
489   BasicBlock *Start = InsertI->getParent();
490   BasicBlock *End =
491       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
492   IRBuilder<> Builder(TileDP);
493   Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
494                                             KDWord, C, A, B);
495   // we cannot assume there always be bitcast after tiledpbssd. So we need to
496   // insert one bitcast as required
497   Builder.SetInsertPoint(End->getFirstNonPHI());
498   Value *ResAMX =
499       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
500   // Delete TileDP intrinsic and do some clean-up.
501   for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
502     Instruction *I = cast<Instruction>(U.getUser());
503     Value *Vec;
504     if (match(I, m_BitCast(m_Value(Vec)))) {
505       I->replaceAllUsesWith(ResVec);
506       I->eraseFromParent();
507     }
508   }
509   TileDP->replaceAllUsesWith(ResAMX);
510   TileDP->eraseFromParent();
511   return true;
512 }
513 
514 template <bool IsTileLoad>
515 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
516   Value *M, *N, *Ptr, *Stride, *Tile;
517   if (IsTileLoad)
518     match(TileLoadStore,
519           m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
520               m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
521   else
522     match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
523                              m_Value(M), m_Value(N), m_Value(Ptr),
524                              m_Value(Stride), m_Value(Tile)));
525 
526   Instruction *InsertI = TileLoadStore;
527   IRBuilder<> PreBuilder(TileLoadStore);
528   PreBuilder.SetInsertPoint(TileLoadStore);
529   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
530   Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
531   BasicBlock *Start = InsertI->getParent();
532   BasicBlock *End =
533       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
534   IRBuilder<> Builder(TileLoadStore);
535   Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
536       Start, End, Builder, M, NDWord, Ptr, StrideDWord,
537       IsTileLoad ? nullptr : Tile);
538   if (IsTileLoad) {
539     // we cannot assume there always be bitcast after tileload. So we need to
540     // insert one bitcast as required
541     Builder.SetInsertPoint(End->getFirstNonPHI());
542     Value *ResAMX =
543         Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
544     // Delete tileloadd6 intrinsic and do some clean-up
545     for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
546       Instruction *I = cast<Instruction>(U.getUser());
547       Value *Vec;
548       if (match(I, m_BitCast(m_Value(Vec)))) {
549         I->replaceAllUsesWith(ResVec);
550         I->eraseFromParent();
551       }
552     }
553     TileLoadStore->replaceAllUsesWith(ResAMX);
554   }
555   TileLoadStore->eraseFromParent();
556   return true;
557 }
558 
559 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
560   IRBuilder<> Builder(TileZero);
561   FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
562   Value *VecZero = Constant::getNullValue(V256I32Ty);
563   for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
564     Instruction *I = cast<Instruction>(U.getUser());
565     Value *Vec;
566     if (match(I, m_BitCast(m_Value(Vec)))) {
567       I->replaceAllUsesWith(VecZero);
568       I->eraseFromParent();
569     }
570   }
571   TileZero->eraseFromParent();
572   return true;
573 }
574 
575 bool X86LowerAMXIntrinsics::visit() {
576   bool C = false;
577   SmallVector<IntrinsicInst *, 8> WorkList;
578   for (BasicBlock *BB : depth_first(&Func)) {
579     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
580       if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
581         switch (Inst->getIntrinsicID()) {
582         case Intrinsic::x86_tdpbssd_internal:
583         case Intrinsic::x86_tdpbsud_internal:
584         case Intrinsic::x86_tdpbusd_internal:
585         case Intrinsic::x86_tdpbuud_internal:
586         case Intrinsic::x86_tileloadd64_internal:
587         case Intrinsic::x86_tilestored64_internal:
588         case Intrinsic::x86_tilezero_internal:
589         case Intrinsic::x86_tdpbf16ps_internal:
590           WorkList.push_back(Inst);
591           break;
592         default:
593           break;
594         }
595       }
596     }
597   }
598 
599   for (auto *Inst : WorkList) {
600     switch (Inst->getIntrinsicID()) {
601     case Intrinsic::x86_tdpbssd_internal:
602       C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
603       break;
604     case Intrinsic::x86_tdpbsud_internal:
605       C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
606       break;
607     case Intrinsic::x86_tdpbusd_internal:
608       C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
609       break;
610     case Intrinsic::x86_tdpbuud_internal:
611       C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
612       break;
613     case Intrinsic::x86_tdpbf16ps_internal:
614       C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
615       break;
616     case Intrinsic::x86_tileloadd64_internal:
617       C = lowerTileLoadStore<true>(Inst) || C;
618       break;
619     case Intrinsic::x86_tilestored64_internal:
620       C = lowerTileLoadStore<false>(Inst) || C;
621       break;
622     case Intrinsic::x86_tilezero_internal:
623       C = lowerTileZero(Inst) || C;
624       break;
625     default:
626       llvm_unreachable("invalid amx intrinsics!");
627     }
628   }
629 
630   return C;
631 }
632 
633 namespace {
634 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
635 public:
636   static char ID;
637 
638   X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
639     initializeX86LowerAMXIntrinsicsLegacyPassPass(
640         *PassRegistry::getPassRegistry());
641   }
642 
643   bool runOnFunction(Function &F) override {
644     if (!X86ScalarizeAMX)
645       return false;
646     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
647     if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
648         TM->getOptLevel() != CodeGenOpt::None)
649       return false;
650 
651     auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
652     auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
653     auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
654     auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
655     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
656 
657     X86LowerAMXIntrinsics LAT(F, DTU, LI);
658     return LAT.visit();
659   }
660   StringRef getPassName() const override { return "Lower AMX intrinsics"; }
661 
662   void getAnalysisUsage(AnalysisUsage &AU) const override {
663     AU.addPreserved<DominatorTreeWrapperPass>();
664     AU.addPreserved<LoopInfoWrapperPass>();
665     AU.addRequired<TargetPassConfig>();
666   }
667 };
668 } // namespace
669 
// Human-readable pass name handed to the registration macros below.
static const char PassName[] = "Lower AMX intrinsics";
// Pass identifier; it is passed to the FunctionPass base-class constructor.
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
// Register the legacy pass under DEBUG_TYPE ("lower-amx-intrinsics") and
// declare its dependency on TargetPassConfig.
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)
677 
678 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
679   return new X86LowerAMXIntrinsicsLegacyPass();
680 }
681