xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
10 /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11 /// provides simple operations on x86_amx. Basic elementwise operations
12 /// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13 /// and only AMX intrinsics can operate on the type, we need to transform
14 /// <256 x i32> load/store instructions into AMX load/store. If the bitcast
15 /// cannot be combined with the load/store, we transform the bitcast into an
16 /// AMX load/store plus a <256 x i32> store/load.
17 ///
18 /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
19 /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
20 /// volatile, because that is necessary for AMX fast register allocation. (In
21 /// fast register allocation, registers are allocated before spill/reload, so
22 /// there is no additional register for AMX to identify the spill step.)
23 /// volatileTileData() handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ...                                          |
27 /// | ...                                                    |
28 /// | "use %td"                                              |
29 /// ----------------------------------------------------------
30 /// will be transformed to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ...                                          |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
34 /// | ...                                                    |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2"                                             |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
46 #include "llvm/Analysis/TargetLibraryInfo.h"
47 #include "llvm/Analysis/TargetTransformInfo.h"
48 #include "llvm/CodeGen/Passes.h"
49 #include "llvm/CodeGen/TargetPassConfig.h"
50 #include "llvm/CodeGen/ValueTypes.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/Function.h"
53 #include "llvm/IR/IRBuilder.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/IR/IntrinsicsX86.h"
57 #include "llvm/IR/PatternMatch.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Pass.h"
60 #include "llvm/Target/TargetMachine.h"
61 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
62 #include "llvm/Transforms/Utils/Local.h"
63 
64 #include <map>
65 
66 using namespace llvm;
67 using namespace PatternMatch;
68 
69 #define DEBUG_TYPE "lower-amx-type"
70 
71 static bool isAMXCast(Instruction *II) {
72   return match(II,
73                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
74          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
75 }
76 
77 static bool isAMXIntrinsic(Value *I) {
78   auto *II = dyn_cast<IntrinsicInst>(I);
79   if (!II)
80     return false;
81   if (isAMXCast(II))
82     return false;
83   // Check if the return type or any parameter is x86_amx. If it is x86_amx,
84   // the intrinsic must be an x86 AMX intrinsic.
85   if (II->getType()->isX86_AMXTy())
86     return true;
87   for (Value *V : II->args()) {
88     if (V->getType()->isX86_AMXTy())
89       return true;
90   }
91 
92   return false;
93 }
94 
95 static bool containsAMXCode(Function &F) {
96   for (BasicBlock &BB : F)
97     for (Instruction &I : BB)
98       if (I.getType()->isX86_AMXTy())
99         return true;
100   return false;
101 }
102 
103 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
104                                            Type *Ty) {
105   Function &F = *BB->getParent();
106   const DataLayout &DL = F.getDataLayout();
107 
108   LLVMContext &Ctx = Builder.getContext();
109   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
110   unsigned AllocaAS = DL.getAllocaAddrSpace();
111   AllocaInst *AllocaRes =
112       new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
113   AllocaRes->setAlignment(AllocaAlignment);
114   return AllocaRes;
115 }
116 
117 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
118   for (Instruction &I : F.getEntryBlock())
119     if (!isa<AllocaInst>(&I))
120       return &I;
121   llvm_unreachable("No terminator in the entry block!");
122 }
123 
124 static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
125   IRBuilder<> Builder(II);
126   Value *Row = nullptr, *Col = nullptr;
127   switch (II->getIntrinsicID()) {
128   default:
129     llvm_unreachable("Expect amx intrinsics");
130   case Intrinsic::x86_tileloadd64_internal:
131   case Intrinsic::x86_tileloaddt164_internal:
132   case Intrinsic::x86_tilestored64_internal: {
133     Row = II->getArgOperand(0);
134     Col = II->getArgOperand(1);
135     break;
136   }
137   // a * b + c
138   // The shape depends on which operand is being queried.
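  // These dot-product style intrinsics take (i16 m, i16 n, i16 k, x86_amx acc,
  // x86_amx a, x86_amx b). Per the cases below, operand 3 (acc) has shape
  // m x n, operand 4 (a) has shape m x k, and operand 5 (b) has shape k/4 x n,
  // which is why Row is derived from the third i16 argument when OpNo is 5.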
139   case Intrinsic::x86_tcmmimfp16ps_internal:
140   case Intrinsic::x86_tcmmrlfp16ps_internal:
141   case Intrinsic::x86_tdpbssd_internal:
142   case Intrinsic::x86_tdpbsud_internal:
143   case Intrinsic::x86_tdpbusd_internal:
144   case Intrinsic::x86_tdpbuud_internal:
145   case Intrinsic::x86_tdpbf16ps_internal:
146   case Intrinsic::x86_tdpfp16ps_internal: {
147     switch (OpNo) {
148     case 3:
149       Row = II->getArgOperand(0);
150       Col = II->getArgOperand(1);
151       break;
152     case 4:
153       Row = II->getArgOperand(0);
154       Col = II->getArgOperand(2);
155       break;
156     case 5:
157       if (isa<ConstantInt>(II->getArgOperand(2)))
158         Row = Builder.getInt16(
159             (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
160       else if (isa<Instruction>(II->getArgOperand(2))) {
161         // When it is neither a constant value nor a function argument, we
162         // create Row after the definition of II->getOperand(2) instead of
163         // before II. For example, II is %118 and we try to get the shape of %117:
164         //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
165         //   i32> %115).
166         //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
167         //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
168         //   %117).
169         // If we created %row = udiv i16 %106, 4 before %118 (a.k.a. II), its
170         // definition would come after its user (the new tileload for %117).
171         // So the best choice is to create %row right after the definition of
172         // %106.
173         Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
174         Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
175         cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
176       } else {
177         // When it is not a constant value but is a function argument, we create
178         // Row in the entry BB.
179         IRBuilder<> NewBuilder(
180             getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
181         Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
182       }
183       Col = II->getArgOperand(1);
184       break;
185     }
186     break;
187   }
188   }
189 
190   return std::make_pair(Row, Col);
191 }
192 
193 static std::pair<Value *, Value *> getShape(PHINode *Phi) {
194   Use &U = *(Phi->use_begin());
195   unsigned OpNo = U.getOperandNo();
196   User *V = U.getUser();
197   // TODO: We don't traverse all users. To keep the algorithm simple, we only
198   // traverse the first user. If we can find the shape, return it; otherwise
199   // return nullptr and the optimization for undef/zero will be
200   // abandoned.
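  // For example (illustrative IR, the value names are made up):
  //   %vec  = phi <256 x i32> [ zeroinitializer, %bb0 ], [ %v, %bb1 ]
  //   %tile = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vec)
  //   %dp   = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
  //                              x86_amx %acc, x86_amx %tile, x86_amx %b)
  // Starting from %vec we follow %vec -> %tile -> %dp and return
  // getShape(%dp, 4), i.e. {%m, %k}.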
201   while (V) {
202     if (isAMXCast(dyn_cast<Instruction>(V))) {
203       if (V->use_empty())
204         break;
205       Use &U = *(V->use_begin());
206       OpNo = U.getOperandNo();
207       V = U.getUser();
208     } else if (isAMXIntrinsic(V)) {
209       return getShape(cast<IntrinsicInst>(V), OpNo);
210     } else if (isa<PHINode>(V)) {
211       if (V->use_empty())
212         break;
213       Use &U = *(V->use_begin());
214       V = U.getUser();
215     } else {
216       break;
217     }
218   }
219 
220   return std::make_pair(nullptr, nullptr);
221 }
222 
223 namespace {
224 class X86LowerAMXType {
225   Function &Func;
226 
227   // In AMX intrinsics we let Shape = {Row, Col}, but the
228   // real column is RealCol = Col / ElementSize. We may use RealCol
229   // as a new Row for other newly created AMX intrinsics.
230   std::map<Value *, Value *> Col2Row;
231 
232 public:
233   X86LowerAMXType(Function &F) : Func(F) {}
234   bool visit();
235   void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
236   void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
237   bool transformBitcast(BitCastInst *Bitcast);
238 };
239 
240 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
241 // %2 = bitcast <256 x i32> %src to x86_amx
242 // -->
243 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
244 // i8* %addr, i64 %stride64)
245 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
246   Value *Row = nullptr, *Col = nullptr;
247   Use &U = *(Bitcast->use_begin());
248   unsigned OpNo = U.getOperandNo();
249   auto *II = cast<IntrinsicInst>(U.getUser());
250   std::tie(Row, Col) = getShape(II, OpNo);
251   IRBuilder<> Builder(Bitcast);
252   // Use the maximum column as the stride.
253   Value *Stride = Builder.getInt64(64);
254   Value *I8Ptr = LD->getOperand(0);
255   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
256 
257   Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
258                                            std::nullopt, Args);
259   Bitcast->replaceAllUsesWith(NewInst);
260 }
261 
262 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
263 //                                                    %stride);
264 // %13 = bitcast x86_amx %src to <256 x i32>
265 // store <256 x i32> %13, <256 x i32>* %addr, align 64
266 // -->
267 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
268 //                                           %stride64, %13)
269 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
270 
271   Value *Tile = Bitcast->getOperand(0);
272   auto *II = cast<IntrinsicInst>(Tile);
273   // Tile is the output of an AMX intrinsic. The first operand of the
274   // intrinsic is the row, the second operand is the column.
275   Value *Row = II->getOperand(0);
276   Value *Col = II->getOperand(1);
277   IRBuilder<> Builder(ST);
278   // Use the maximum column as the stride. It must be the same as the load
279   // stride.
280   Value *Stride = Builder.getInt64(64);
281   Value *I8Ptr = ST->getOperand(1);
282   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
283   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
284                           Args);
285   if (Bitcast->hasOneUse())
286     return;
287   // %13 = bitcast x86_amx %src to <256 x i32>
288   // store <256 x i32> %13, <256 x i32>* %addr, align 64
289   // %add = <256 x i32> %13, <256 x i32> %src2
290   // -->
291   // %13 = bitcast x86_amx %src to <256 x i32>
292   // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
293   //                                           %stride64, %13)
294   // %14 = load <256 x i32>, %addr
295   // %add = <256 x i32> %14, <256 x i32> %src2
296   Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
297   Bitcast->replaceAllUsesWith(Vec);
298 }
299 
300 // Transform the bitcast into <store, load> instructions.
301 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
302   IRBuilder<> Builder(Bitcast);
303   AllocaInst *AllocaAddr;
304   Value *I8Ptr, *Stride;
305   auto *Src = Bitcast->getOperand(0);
306 
307   auto Prepare = [&](Type *MemTy) {
308     AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
309     I8Ptr = AllocaAddr;
310     Stride = Builder.getInt64(64);
311   };
312 
313   if (Bitcast->getType()->isX86_AMXTy()) {
314     // %2 = bitcast <256 x i32> %src to x86_amx
315     // -->
316     // %addr = alloca <256 x i32>, align 64
317     // store <256 x i32> %src, <256 x i32>* %addr, align 64
318     // %addr2 = bitcast <256 x i32>* to i8*
319     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
320     //                                                  i8* %addr2,
321     //                                                  i64 64)
322     Use &U = *(Bitcast->use_begin());
323     unsigned OpNo = U.getOperandNo();
324     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
325     if (!II)
326       return false; // May be bitcast from x86amx to <256 x i32>.
327     Prepare(Bitcast->getOperand(0)->getType());
328     Builder.CreateStore(Src, AllocaAddr);
329     // TODO: we can pick a constant operand for the shape.
330     Value *Row = nullptr, *Col = nullptr;
331     std::tie(Row, Col) = getShape(II, OpNo);
332     std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
333     Value *NewInst = Builder.CreateIntrinsic(
334         Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
335     Bitcast->replaceAllUsesWith(NewInst);
336   } else {
337     // %2 = bitcast x86_amx %src to <256 x i32>
338     // -->
339     // %addr = alloca <256 x i32>, align 64
340     // %addr2 = bitcast <256 x i32>* to i8*
341     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
342     //                                           i8* %addr2, i64 %stride)
343     // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
344     auto *II = dyn_cast<IntrinsicInst>(Src);
345     if (!II)
346       return false; // May be bitcast from <256 x i32> to x86amx.
347     Prepare(Bitcast->getType());
348     Value *Row = II->getOperand(0);
349     Value *Col = II->getOperand(1);
350     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
351     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
352                             Args);
353     Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
354     Bitcast->replaceAllUsesWith(NewInst);
355   }
356 
357   return true;
358 }
359 
360 bool X86LowerAMXType::visit() {
361   SmallVector<Instruction *, 8> DeadInsts;
362   Col2Row.clear();
363 
364   for (BasicBlock *BB : post_order(&Func)) {
365     for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
366       auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
367       if (!Bitcast)
368         continue;
369 
370       Value *Src = Bitcast->getOperand(0);
371       if (Bitcast->getType()->isX86_AMXTy()) {
372         if (Bitcast->user_empty()) {
373           DeadInsts.push_back(Bitcast);
374           continue;
375         }
376         LoadInst *LD = dyn_cast<LoadInst>(Src);
377         if (!LD) {
378           if (transformBitcast(Bitcast))
379             DeadInsts.push_back(Bitcast);
380           continue;
381         }
382         // If the load has multiple users, keep the vector load and add a tile load.
383         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
384         // %2 = bitcast <256 x i32> %src to x86_amx
385         // %add = add <256 x i32> %src, <256 x i32> %src2
386         // -->
387         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
388         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
389         //                                            i8* %addr, i64 %stride64)
390         // %add = add <256 x i32> %src, <256 x i32> %src2
391 
392         // If the load has only one user, the load will be eliminated in DAG ISel.
393         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
394         // %2 = bitcast <256 x i32> %src to x86_amx
395         // -->
396         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
397         //                                            i8* %addr, i64 %stride64)
398         combineLoadBitcast(LD, Bitcast);
399         DeadInsts.push_back(Bitcast);
400         if (LD->hasOneUse())
401           DeadInsts.push_back(LD);
402       } else if (Src->getType()->isX86_AMXTy()) {
403         if (Bitcast->user_empty()) {
404           DeadInsts.push_back(Bitcast);
405           continue;
406         }
407         StoreInst *ST = nullptr;
408         for (Use &U : Bitcast->uses()) {
409           ST = dyn_cast<StoreInst>(U.getUser());
410           if (ST)
411             break;
412         }
413         if (!ST) {
414           if (transformBitcast(Bitcast))
415             DeadInsts.push_back(Bitcast);
416           continue;
417         }
418         // If bitcast (%13) has one use, combine bitcast and store to an AMX store.
419         // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
420         //                                                    %stride);
421         // %13 = bitcast x86_amx %src to <256 x i32>
422         // store <256 x i32> %13, <256 x i32>* %addr, align 64
423         // -->
424         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
425         //                                           %stride64, %13)
426         //
427         // If bitcast (%13) has multiple uses, transform as below.
428         // %13 = bitcast x86_amx %src to <256 x i32>
429         // store <256 x i32> %13, <256 x i32>* %addr, align 64
430         // %add = <256 x i32> %13, <256 x i32> %src2
431         // -->
432         // %13 = bitcast x86_amx %src to <256 x i32>
433         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
434         //                                           %stride64, %13)
435         // %14 = load <256 x i32>, %addr
436         // %add = <256 x i32> %14, <256 x i32> %src2
437         //
438         combineBitcastStore(Bitcast, ST);
439         // Delete user first.
440         DeadInsts.push_back(ST);
441         DeadInsts.push_back(Bitcast);
442       }
443     }
444   }
445 
446   bool C = !DeadInsts.empty();
447 
448   for (auto *Inst : DeadInsts)
449     Inst->eraseFromParent();
450 
451   return C;
452 }
453 } // anonymous namespace
454 
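// Create a <256 x i32> stack slot at the entry of BB's function and return it
// as a pointer; the callers below use it as the backing memory when making
// tile data volatile.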
455 static Value *getAllocaPos(BasicBlock *BB) {
456   Function *F = BB->getParent();
457   IRBuilder<> Builder(&F->getEntryBlock().front());
458   const DataLayout &DL = F->getDataLayout();
459   unsigned AllocaAS = DL.getAllocaAddrSpace();
460   Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
461   AllocaInst *AllocaRes =
462       new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
463   BasicBlock::iterator Iter = AllocaRes->getIterator();
464   ++Iter;
465   Builder.SetInsertPoint(&*Iter);
466   Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
467   return I8Ptr;
468 }
469 
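// Emit a tilestored64 of TileDef to Ptr right after TileDef's definition,
// reusing the def's row/column operands as the shape and 64 bytes as the
// stride.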
470 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
471   assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
472   auto *II = cast<IntrinsicInst>(TileDef);
473   assert(II && "Not tile intrinsic!");
474   Value *Row = II->getOperand(0);
475   Value *Col = II->getOperand(1);
476 
477   BasicBlock *BB = TileDef->getParent();
478   BasicBlock::iterator Iter = TileDef->getIterator();
479   IRBuilder<> Builder(BB, ++Iter);
480   Value *Stride = Builder.getInt64(64);
481   std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
482 
483   Instruction *TileStore = Builder.CreateIntrinsic(
484       Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
485   return TileStore;
486 }
487 
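// Rewrite the use U of an x86_amx value into a tileloadd64 from Ptr emitted
// just before the user. The shape is taken from the defining intrinsic, or,
// when IsPHI is set, from the def of the PHI's first incoming value.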
488 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
489   Value *V = U.get();
490   assert(V->getType()->isX86_AMXTy() && "Not define tile!");
491 
492   // Get tile shape.
493   IntrinsicInst *II = nullptr;
494   if (IsPHI) {
495     Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
496     II = cast<IntrinsicInst>(PhiOp);
497   } else {
498     II = cast<IntrinsicInst>(V);
499   }
500   Value *Row = II->getOperand(0);
501   Value *Col = II->getOperand(1);
502 
503   Instruction *UserI = cast<Instruction>(U.getUser());
504   IRBuilder<> Builder(UserI);
505   Value *Stride = Builder.getInt64(64);
506   std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
507 
508   Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
509                                             std::nullopt, Args);
510   UserI->replaceUsesOfWith(V, TileLoad);
511 }
512 
513 static bool isIncomingOfPHI(Instruction *I) {
514   for (Use &U : I->uses()) {
515     User *V = U.getUser();
516     if (isa<PHINode>(V))
517       return true;
518   }
519   return false;
520 }
521 
522 // Let all AMX tile data become volatile data, shortening the live range
523 // of each tile register before fast register allocation.
524 namespace {
525 class X86VolatileTileData {
526   Function &F;
527 
528 public:
529   X86VolatileTileData(Function &Func) : F(Func) {}
530   Value *updatePhiIncomings(BasicBlock *BB,
531                             SmallVector<Instruction *, 2> &Incomings);
532   void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
533   bool volatileTileData();
534   void volatileTilePHI(PHINode *PHI);
535   void volatileTileNonPHI(Instruction *I);
536 };
537 
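// Tile-store every incoming def of the PHI to a shared stack slot right after
// its definition and rewrite the defs' non-PHI uses to tileloads from that
// slot. The slot pointer is returned so that the PHI def itself can be
// replaced by a load in replacePhiDefWithLoad().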
538 Value *X86VolatileTileData::updatePhiIncomings(
539     BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
540   Value *I8Ptr = getAllocaPos(BB);
541 
542   for (auto *I : Incomings) {
543     User *Store = createTileStore(I, I8Ptr);
544 
545     // All its uses (except the phi) should load from the stored memory.
546     for (Use &U : I->uses()) {
547       User *V = U.getUser();
548       if (isa<PHINode>(V) || V == Store)
549         continue;
550       replaceWithTileLoad(U, I8Ptr);
551     }
552   }
553   return I8Ptr;
554 }
555 
556 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
557                                                 Value *StorePtr) {
558   for (Use &U : PHI->uses())
559     replaceWithTileLoad(U, StorePtr, true);
560   PHI->eraseFromParent();
561 }
562 
563 // Similar to volatileTileNonPHI, this function only handles PHI nodes
564 // and their related AMX intrinsics.
565 // 1) The PHI def should be changed to a tileload.
566 // 2) The PHI incoming values should be tilestored right after their defs.
567 // 3) The memory of these tileloads and tilestores should be the same.
568 // e.g.
569 // ------------------------------------------------------
570 // bb_dom:
571 //   ...
572 //   br i1 %bool.cond, label %if.else, label %if.then
573 //
574 // if.then:
575 //   def %t0 = ...
576 //   ...
577 //   use %t0
578 //   ...
579 //   br label %if.end
580 //
581 // if.else:
582 //   def %t1 = ...
583 //   br label %if.end
584 //
585 // if.end:
586 //   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
587 //   ...
588 //   use %td
589 // ------------------------------------------------------
590 // -->
591 // ------------------------------------------------------
592 // bb_entry:
593 //   %mem = alloca <256 x i32>, align 1024                  *
594 //   ...
595 // bb_dom:
596 //   ...
597 //   br i1 %bool.cond, label %if.else, label %if.then
598 //
599 // if.then:
600 //   def %t0 = ...
601 //   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
602 //   ...
603 //   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
604 //   use %t0`                                               *
605 //   ...
606 //   br label %if.end
607 //
608 // if.else:
609 //   def %t1 = ...
610 //   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
611 //   br label %if.end
612 //
613 // if.end:
614 //   ...
615 //   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
616 //   use %td
617 // ------------------------------------------------------
618 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
619   BasicBlock *BB = PHI->getParent();
620   SmallVector<Instruction *, 2> Incomings;
621 
622   for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
623     Value *Op = PHI->getIncomingValue(I);
624     Instruction *Inst = dyn_cast<Instruction>(Op);
625     assert(Inst && "We shouldn't fold AMX instruction!");
626     Incomings.push_back(Inst);
627   }
628 
629   Value *StorePtr = updatePhiIncomings(BB, Incomings);
630   replacePhiDefWithLoad(PHI, StorePtr);
631 }
632 
633 // Store the defined tile and load it before each use.
634 // None of its users is a PHI.
635 // e.g.
636 // ------------------------------------------------------
637 // def %td = ...
638 // ...
639 // "use %td"
640 // ------------------------------------------------------
641 // -->
642 // ------------------------------------------------------
643 // def %td = ...
644 // call void @llvm.x86.tilestored64.internal(mem, %td)
645 // ...
646 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
647 // "use %td2"
648 // ------------------------------------------------------
649 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
650   BasicBlock *BB = I->getParent();
651   Value *I8Ptr = getAllocaPos(BB);
652   User *Store = createTileStore(I, I8Ptr);
653 
654   // All its uses should load from stored mem.
655   for (Use &U : I->uses()) {
656     User *V = U.getUser();
657     assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
658     if (V != Store)
659       replaceWithTileLoad(U, I8Ptr);
660   }
661 }
662 
663 // Volatile Tile Model:
664 // 1) All uses of tile data come from a tileload just in time.
665 // 2) All defs of tile data are tilestored to memory immediately.
666 // For example:
667 // --------------------------------------------------------------------------
668 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
669 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
670 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
671 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
672 // call void @llvm.x86.tilestored64.internal(... td)                     area
673 // --------------------------------------------------------------------------
674 // 3) No terminators, calls or other AMX instructions in the key AMX area.
675 bool X86VolatileTileData::volatileTileData() {
676   bool Changed = false;
677   for (BasicBlock &BB : F) {
678     SmallVector<Instruction *, 2> PHIInsts;
679     SmallVector<Instruction *, 8> AMXDefInsts;
680 
681     for (Instruction &I : BB) {
682       if (!I.getType()->isX86_AMXTy())
683         continue;
684       if (isa<PHINode>(&I))
685         PHIInsts.push_back(&I);
686       else
687         AMXDefInsts.push_back(&I);
688     }
689 
690     // First we "volatilize" the non-PHI-related AMX intrinsics.
691     for (Instruction *I : AMXDefInsts) {
692       if (isIncomingOfPHI(I))
693         continue;
694       volatileTileNonPHI(I);
695       Changed = true;
696     }
697 
698     for (Instruction *I : PHIInsts) {
699       volatileTilePHI(dyn_cast<PHINode>(I));
700       Changed = true;
701     }
702   }
703   return Changed;
704 }
705 
706 } // anonymous namespace
707 
708 namespace {
709 
710 class X86LowerAMXCast {
711   Function &Func;
712   std::unique_ptr<DominatorTree> DT;
713 
714 public:
715   X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
716   bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
717   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
718   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
719   bool combineAMXcast(TargetLibraryInfo *TLI);
720   bool transformAMXCast(IntrinsicInst *AMXCast);
721   bool transformAllAMXCast();
722   bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
723                               SmallSetVector<Instruction *, 16> &DeadInst);
724 };
725 
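// If I is trivially dead, salvage its debug info, null out its operands
// (queueing any operands that become trivially dead onto WorkList), erase it
// and return true; otherwise return false.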
726 static bool DCEInstruction(Instruction *I,
727                            SmallSetVector<Instruction *, 16> &WorkList,
728                            const TargetLibraryInfo *TLI) {
729   if (isInstructionTriviallyDead(I, TLI)) {
730     salvageDebugInfo(*I);
731     salvageKnowledge(I);
732 
733     // Null out all of the instruction's operands to see if any operand becomes
734     // dead as we go.
735     for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
736       Value *OpV = I->getOperand(i);
737       I->setOperand(i, nullptr);
738 
739       if (!OpV->use_empty() || I == OpV)
740         continue;
741 
742       // If the operand is an instruction that became dead as we nulled out the
743       // operand, and if it is 'trivially' dead, delete it in a future loop
744       // iteration.
745       if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
746         if (isInstructionTriviallyDead(OpI, TLI)) {
747           WorkList.insert(OpI);
748         }
749       }
750     }
751     I->eraseFromParent();
752     return true;
753   }
754   return false;
755 }
756 
757 /// This function handles the following case:
758 ///
759 ///     A  ->  B    amxcast
760 ///     PHI
761 ///     B  ->  A    amxcast
762 ///
763 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
764 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
765 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
766     IntrinsicInst *CI, PHINode *PN,
767     SmallSetVector<Instruction *, 16> &DeadInst) {
768   IRBuilder<> Builder(CI);
769   Value *Src = CI->getOperand(0);
770   Type *SrcTy = Src->getType(); // Type B
771   Type *DestTy = CI->getType(); // Type A
772 
773   SmallVector<PHINode *, 4> PhiWorklist;
774   SmallSetVector<PHINode *, 4> OldPhiNodes;
775 
776   // Find all of the A->B casts and PHI nodes.
777   // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
778   // OldPhiNodes is used to track all known PHI nodes, before adding a new
779   // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
780   PhiWorklist.push_back(PN);
781   OldPhiNodes.insert(PN);
782   while (!PhiWorklist.empty()) {
783     auto *OldPN = PhiWorklist.pop_back_val();
784     for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
785       Value *IncValue = OldPN->getIncomingValue(I);
786       // TODO: Currently we ignore cases where it is a constant. In the future,
787       // we might support constants.
788       if (isa<Constant>(IncValue)) {
789         auto *IncConst = dyn_cast<Constant>(IncValue);
790         if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
791           return false;
792         Value *Row = nullptr, *Col = nullptr;
793         std::tie(Row, Col) = getShape(OldPN);
794         // TODO: If it is not a constant, Row and Col must dominate the tilezero
795         // that we are going to create.
796         if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
797           return false;
798         // Create tilezero at the end of incoming block.
799         auto *Block = OldPN->getIncomingBlock(I);
800         BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
801         Instruction *NewInst = Builder.CreateIntrinsic(
802             Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
803         NewInst->moveBefore(&*Iter);
804         NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
805                                           {IncValue->getType()}, {NewInst});
806         NewInst->moveBefore(&*Iter);
807         // Replace IncValue with the new value.
808         OldPN->setIncomingValue(I, NewInst);
809         IncValue = NewInst;
810       }
811 
812       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
813         if (OldPhiNodes.insert(PNode))
814           PhiWorklist.push_back(PNode);
815         continue;
816       }
817       Instruction *ACI = dyn_cast<Instruction>(IncValue);
818       if (ACI && isAMXCast(ACI)) {
819         // Verify it's an A->B cast.
820         Type *TyA = ACI->getOperand(0)->getType();
821         Type *TyB = ACI->getType();
822         if (TyA != DestTy || TyB != SrcTy)
823           return false;
824         continue;
825       }
826       return false;
827     }
828   }
829 
830   // Check that each user of each old PHI node is something that we can
831   // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
832   for (auto *OldPN : OldPhiNodes) {
833     for (User *V : OldPN->users()) {
834       Instruction *ACI = dyn_cast<Instruction>(V);
835       if (ACI && isAMXCast(ACI)) {
836         // Verify it's a B->A cast.
837         Type *TyB = ACI->getOperand(0)->getType();
838         Type *TyA = ACI->getType();
839         if (TyA != DestTy || TyB != SrcTy)
840           return false;
841       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
842         // As long as the user is another old PHI node, even if we don't
843         // rewrite it, the PHI web we're considering won't have any users
844         // outside itself, so it'll be dead.
845         // example:
846         //   bb.0:
847         //      %0 = amxcast ...
848         //   bb.1:
849         //      %1 = amxcast ...
850         //   bb.2:
851         //      %goodphi = phi %0, %1
852         //      %3 = amxcast %goodphi
853         //   bb.3:
854         //      %goodphi2 = phi %0, %goodphi
855         //      %4 = amxcast %goodphi2
856         // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
857         // outside the phi-web, so the combination stops. When
858         // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
859         // will be done.
860         if (OldPhiNodes.count(PHI) == 0)
861           return false;
862       } else
863         return false;
864     }
865   }
866 
867   // For each old PHI node, create a corresponding new PHI node with type A.
868   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
869   for (auto *OldPN : OldPhiNodes) {
870     Builder.SetInsertPoint(OldPN);
871     PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
872     NewPNodes[OldPN] = NewPN;
873   }
874 
875   // Fill in the operands of new PHI nodes.
876   for (auto *OldPN : OldPhiNodes) {
877     PHINode *NewPN = NewPNodes[OldPN];
878     for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
879       Value *V = OldPN->getOperand(j);
880       Value *NewV = nullptr;
881       Instruction *ACI = dyn_cast<Instruction>(V);
882       // There should not be an AMXCast from a constant.
883       if (ACI && isAMXCast(ACI))
884         NewV = ACI->getOperand(0);
885       else if (auto *PrevPN = dyn_cast<PHINode>(V))
886         NewV = NewPNodes[PrevPN];
887       assert(NewV);
888       NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
889     }
890   }
891 
892   // Traverse all accumulated PHI nodes and process their users,
893   // which are stores and bitcasts. Without this processing,
894   // new PHI nodes could be replicated and could lead to extra
895   // moves generated after DeSSA.
896   // If there is a store with type B, change it to type A.
897 
898   // Replace users of BitCast B->A with NewPHI. These will help
899   // later to get rid of a closure formed by OldPHI nodes.
900   for (auto *OldPN : OldPhiNodes) {
901     PHINode *NewPN = NewPNodes[OldPN];
902     for (User *V : make_early_inc_range(OldPN->users())) {
903       Instruction *ACI = dyn_cast<Instruction>(V);
904       if (ACI && isAMXCast(ACI)) {
905         Type *TyB = ACI->getOperand(0)->getType();
906         Type *TyA = ACI->getType();
907         assert(TyA == DestTy && TyB == SrcTy);
908         (void)TyA;
909         (void)TyB;
910         ACI->replaceAllUsesWith(NewPN);
911         DeadInst.insert(ACI);
912       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
913         // We don't need to push the PHINode into DeadInst since it is an operand
914         // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
915         assert(OldPhiNodes.contains(PHI));
916         (void)PHI;
917       } else
918         llvm_unreachable("all uses should be handled");
919     }
920   }
921   return true;
922 }
923 
924 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
925 // store <256 x i32> %43, <256 x i32>* %p, align 64
926 // -->
927 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
928 //                                           i64 64, x86_amx %42)
929 bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
930   Value *Tile = Cast->getOperand(0);
931   // TODO: If it is a cast intrinsic or a phi node, we can propagate the
932   // shape information through the def-use chain.
933   if (!isAMXIntrinsic(Tile))
934     return false;
935   auto *II = cast<IntrinsicInst>(Tile);
936   // Tile is the output of an AMX intrinsic. The first operand of the
937   // intrinsic is the row, the second operand is the column.
938   Value *Row = II->getOperand(0);
939   Value *Col = II->getOperand(1);
940   IRBuilder<> Builder(ST);
941   // The stride should be equal to Col (measured in bytes).
942   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
943   Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
944   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
945   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
946                           Args);
947   return true;
948 }
949 
950 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
951 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
952 // -->
953 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
954 //                                                   i8* %p, i64 64)
955 bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
956   bool EraseLoad = true;
957   Value *Row = nullptr, *Col = nullptr;
958   Use &U = *(Cast->use_begin());
959   unsigned OpNo = U.getOperandNo();
960   auto *II = cast<IntrinsicInst>(U.getUser());
961   // TODO: If it is a cast intrinsic or a phi node, we can propagate the
962   // shape information through the def-use chain.
963   if (!isAMXIntrinsic(II))
964     return false;
965   std::tie(Row, Col) = getShape(II, OpNo);
966   IRBuilder<> Builder(LD);
967   // The stride should be equal to Col (measured in bytes).
968   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
969   Value *I8Ptr;
970 
971   // To save compile time, we create the dominator tree only when it is really
972   // needed.
973   if (!DT)
974     DT.reset(new DominatorTree(Func));
975   if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
976     // Store the value to the stack and reload it from the stack before the cast.
977     auto *AllocaAddr =
978         createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
979     Builder.SetInsertPoint(&*std::next(LD->getIterator()));
980     Builder.CreateStore(LD, AllocaAddr);
981 
982     Builder.SetInsertPoint(Cast);
983     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
984     EraseLoad = false;
985   } else {
986     I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
987   }
988   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
989 
990   Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
991                                            std::nullopt, Args);
992   Cast->replaceAllUsesWith(NewInst);
993 
994   return EraseLoad;
995 }
996 
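// For each live cast, fold tile2vec casts into the stores that consume them
// and vec2tile casts into the single-use loads that feed them, erasing the
// stores and loads that become dead along the way.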
997 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
998   bool Change = false;
999   for (auto *Cast : Casts) {
1000     auto *II = cast<IntrinsicInst>(Cast);
1001     // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1002     // store <256 x i32> %43, <256 x i32>* %p, align 64
1003     // -->
1004     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1005     //                                           i64 64, x86_amx %42)
1006     if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1007       SmallVector<Instruction *, 2> DeadStores;
1008       for (User *U : Cast->users()) {
1009         StoreInst *Store = dyn_cast<StoreInst>(U);
1010         if (!Store)
1011           continue;
1012         if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1013           DeadStores.push_back(Store);
1014           Change = true;
1015         }
1016       }
1017       for (auto *Store : DeadStores)
1018         Store->eraseFromParent();
1019     } else { // x86_cast_vector_to_tile
1020       SmallVector<Instruction *, 2> DeadLoads;
1021       auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1022       if (!Load || !Load->hasOneUse())
1023         continue;
1024       // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1025       // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1026       // -->
1027       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1028       //                                                   i8* %p, i64 64)
1029       if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1030         // Set the operand to null so that the load instruction can be erased.
1031         Cast->setOperand(0, nullptr);
1032         Load->eraseFromParent();
1033       }
1034     }
1035   }
1036   return Change;
1037 }
1038 
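// Fold vec2tile/tile2vec pairs that cancel out, erase casts that become dead,
// combine the surviving casts with the loads/stores next to them, and finally
// optimize A->B->A casts that go through PHI nodes. Anything left dead is
// cleaned up with a small DCE loop at the end.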
1039 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1040   bool Change = false;
1041   // Collect tile cast instructions.
1042   SmallVector<Instruction *, 8> Vec2TileInsts;
1043   SmallVector<Instruction *, 8> Tile2VecInsts;
1044   SmallVector<Instruction *, 8> PhiCastWorkList;
1045   SmallSetVector<Instruction *, 16> DeadInst;
1046   for (BasicBlock &BB : Func) {
1047     for (Instruction &I : BB) {
1048       Value *Vec;
1049       if (match(&I,
1050                 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1051         Vec2TileInsts.push_back(&I);
1052       else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1053                              m_Value(Vec))))
1054         Tile2VecInsts.push_back(&I);
1055     }
1056   }
1057 
1058   auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1059     for (auto *Inst : Insts) {
1060       for (User *U : Inst->users()) {
1061         IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1062         if (!II || II->getIntrinsicID() != IID)
1063           continue;
1064         // T1 = vec2tile V0
1065         // V2 = tile2vec T1
1066         // V3 = OP V2
1067         // -->
1068         // T1 = vec2tile V0
1069         // V2 = tile2vec T1
1070         // V3 = OP V0
1071         II->replaceAllUsesWith(Inst->getOperand(0));
1072         Change = true;
1073       }
1074     }
1075   };
1076 
1077   Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1078   Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1079 
1080   SmallVector<Instruction *, 8> LiveCasts;
1081   auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1082     for (auto *Inst : Insts) {
1083       if (Inst->use_empty()) {
1084         Inst->eraseFromParent();
1085         Change = true;
1086       } else {
1087         LiveCasts.push_back(Inst);
1088       }
1089     }
1090   };
1091 
1092   EraseInst(Vec2TileInsts);
1093   EraseInst(Tile2VecInsts);
1094   LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1095                        "Vec2Tile and Tile2Vec:\n";
1096              Func.dump());
1097   Change |= combineLdSt(LiveCasts);
1098   EraseInst(LiveCasts);
1099   LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1100                        "AMXCast and load/store:\n";
1101              Func.dump());
1102 
1103   // Handle the A->B->A cast where there is an intervening PHI node.
1104   for (BasicBlock &BB : Func) {
1105     for (Instruction &I : BB) {
1106       if (isAMXCast(&I)) {
1107         if (isa<PHINode>(I.getOperand(0)))
1108           PhiCastWorkList.push_back(&I);
1109       }
1110     }
1111   }
1112   for (auto *I : PhiCastWorkList) {
1113     // We skip dead AMXCasts.
1114     if (DeadInst.contains(I))
1115       continue;
1116     PHINode *PN = cast<PHINode>(I->getOperand(0));
1117     if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1118       DeadInst.insert(PN);
1119       Change = true;
1120     }
1121   }
1122 
1123   // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
1124   // might have no uses. We do some dead code elimination for them.
1125   while (!DeadInst.empty()) {
1126     Instruction *I = DeadInst.pop_back_val();
1127     Change |= DCEInstruction(I, DeadInst, TLI);
1128   }
1129   LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1130                        "optimizeAMXCastFromPhi:\n";
1131              Func.dump());
1132   return Change;
1133 }
1134 
1135 // There might be remaining AMXCasts after combineAMXcast, and they should be
1136 // handled elegantly.
1137 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1138   IRBuilder<> Builder(AMXCast);
1139   AllocaInst *AllocaAddr;
1140   Value *I8Ptr, *Stride;
1141   auto *Src = AMXCast->getOperand(0);
1142 
1143   auto Prepare = [&](Type *MemTy) {
1144     AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1145     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1146     Stride = Builder.getInt64(64);
1147   };
1148 
1149   if (AMXCast->getType()->isX86_AMXTy()) {
1150     // %2 = amxcast <225 x i32> %src to x86_amx
1151     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1152     //                                           i8* %addr3, i64 60, x86_amx %2)
1153     // -->
1154     // %addr = alloca <225 x i32>, align 64
1155     // store <225 x i32> %src, <225 x i32>* %addr, align 64
1156     // %addr2 = bitcast <225 x i32>* %addr to i8*
1157     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1158     //                                                  i8* %addr2,
1159     //                                                  i64 60)
1160     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1161     //                                           i8* %addr3, i64 60, x86_amx %2)
1162     if (AMXCast->use_empty()) {
1163       AMXCast->eraseFromParent();
1164       return true;
1165     }
1166     Use &U = *(AMXCast->use_begin());
1167     unsigned OpNo = U.getOperandNo();
1168     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1169     if (!II)
1170       return false; // May be bitcast from x86amx to <256 x i32>.
1171     Prepare(AMXCast->getOperand(0)->getType());
1172     Builder.CreateStore(Src, AllocaAddr);
1173     // TODO: we can pick a constant operand for the shape.
1174     Value *Row = nullptr, *Col = nullptr;
1175     std::tie(Row, Col) = getShape(II, OpNo);
1176     std::array<Value *, 4> Args = {
1177         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1178     Value *NewInst = Builder.CreateIntrinsic(
1179         Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
1180     AMXCast->replaceAllUsesWith(NewInst);
1181     AMXCast->eraseFromParent();
1182   } else {
1183     // %2 = amxcast x86_amx %src to <225 x i32>
1184     // -->
1185     // %addr = alloca <225 x i32>, align 64
1186     // %addr2 = bitcast <225 x i32>* to i8*
1187     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1188     //                                           i8* %addr2, i64 %stride)
1189     // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1190     auto *II = dyn_cast<IntrinsicInst>(Src);
1191     if (!II)
1192       return false; // May be bitcast from <256 x i32> to x86amx.
1193     Prepare(AMXCast->getType());
1194     Value *Row = II->getOperand(0);
1195     Value *Col = II->getOperand(1);
1196     std::array<Value *, 5> Args = {
1197         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1198     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
1199                             Args);
1200     Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1201     AMXCast->replaceAllUsesWith(NewInst);
1202     AMXCast->eraseFromParent();
1203   }
1204 
1205   return true;
1206 }
1207 
1208 bool X86LowerAMXCast::transformAllAMXCast() {
1209   bool Change = false;
1210   // Collect tile cast instructions.
1211   SmallVector<Instruction *, 8> WorkLists;
1212   for (BasicBlock &BB : Func) {
1213     for (Instruction &I : BB) {
1214       if (isAMXCast(&I))
1215         WorkLists.push_back(&I);
1216     }
1217   }
1218 
1219   for (auto *Inst : WorkLists) {
1220     Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1221   }
1222 
1223   return Change;
1224 }
1225 
1226 } // anonymous namespace
1227 
1228 namespace {
1229 
1230 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1231 public:
1232   static char ID;
1233 
1234   X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
1235     initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
1236   }
1237 
1238   bool runOnFunction(Function &F) override {
1239     // Performance optimization: most code doesn't use AMX, so return early if
1240     // there are no instructions that produce AMX values. This is sufficient, as
1241     // AMX arguments and constants are not allowed -- so any producer of an AMX
1242     // value must be an instruction.
1243     // TODO: find a cheaper way to do this, without looking at all instructions.
1244     if (!containsAMXCode(F))
1245       return false;
1246 
1247     bool C = false;
1248     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1249     TargetLibraryInfo *TLI =
1250         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1251 
1252     X86LowerAMXCast LAC(F);
1253     C |= LAC.combineAMXcast(TLI);
1254     // There might be remaining AMXCasts after combineAMXcast, and they should
1255     // be handled elegantly.
1256     C |= LAC.transformAllAMXCast();
1257 
1258     X86LowerAMXType LAT(F);
1259     C |= LAT.visit();
1260 
1261     // Prepare for fast register allocation at O0.
1262     // TODO: It may be better to check the volatile model of the AMX code, not
1263     // just Attribute::OptimizeNone and CodeGenOptLevel::None.
1264     if (TM->getOptLevel() == CodeGenOptLevel::None) {
1265       // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1266       // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1267       // sure the AMX data is volatile; that is necessary for AMX fast
1268       // register allocation.
1269       if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1270         X86VolatileTileData VTD(F);
1271         C = VTD.volatileTileData() || C;
1272       }
1273     }
1274 
1275     return C;
1276   }
1277 
1278   void getAnalysisUsage(AnalysisUsage &AU) const override {
1279     AU.setPreservesCFG();
1280     AU.addRequired<TargetPassConfig>();
1281     AU.addRequired<TargetLibraryInfoWrapperPass>();
1282   }
1283 };
1284 
1285 } // anonymous namespace
1286 
1287 static const char PassName[] = "Lower AMX type for load/store";
1288 char X86LowerAMXTypeLegacyPass::ID = 0;
1289 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1290                       false)
1291 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1292 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1293 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1294                     false)
1295 
1296 FunctionPass *llvm::createX86LowerAMXTypePass() {
1297   return new X86LowerAMXTypeLegacyPass();
1298 }
1299