1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Shape utility for AMX. 10 /// AMX hardware requires to config the shape of tile data register before use. 11 /// The 2D shape includes row and column. In AMX intrinsics interface the shape 12 /// is passed as 1st and 2nd parameter and they are lowered as the 1st and 2nd 13 /// machine operand of AMX pseudo instructions. ShapeT class is to facilitate 14 /// tile config and register allocator. The row and column are machine operand 15 /// of AMX pseudo instructions. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H 20 #define LLVM_CODEGEN_TILESHAPEINFO_H 21 22 #include "llvm/CodeGen/MachineInstr.h" 23 #include "llvm/CodeGen/MachineOperand.h" 24 #include "llvm/CodeGen/MachineRegisterInfo.h" 25 #include "llvm/CodeGen/Register.h" 26 27 namespace llvm { 28 29 class ShapeT { 30 public: 31 ShapeT(MachineOperand *Row, MachineOperand *Col, 32 const MachineRegisterInfo *MRI = nullptr) 33 : Row(Row), Col(Col) { 34 if (MRI) 35 deduceImm(MRI); 36 } 37 ShapeT() 38 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 39 ColImm(InvalidImmShape) {} 40 bool operator==(const ShapeT &Shape) const { 41 MachineOperand *R = Shape.Row; 42 MachineOperand *C = Shape.Col; 43 if (!R || !C) 44 return false; 45 if (!Row || !Col) 46 return false; 47 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) 48 return true; 49 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) 50 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); 51 return false; 52 } 53 54 bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); } 55 56 MachineOperand *getRow() const { return Row; } 57 58 MachineOperand *getCol() const { return Col; } 59 60 int64_t getRowImm() const { return RowImm; } 61 62 int64_t getColImm() const { return ColImm; } 63 64 bool isValid() { return (Row != nullptr) && (Col != nullptr); } 65 66 void deduceImm(const MachineRegisterInfo *MRI) { 67 // All def must be the same value, otherwise it is invalid MIs. 68 // Find the immediate. 69 // TODO copy propagation. 70 auto GetImm = [&](Register Reg) { 71 int64_t Imm = InvalidImmShape; 72 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { 73 const auto *MI = DefMO.getParent(); 74 if (MI->isMoveImmediate()) { 75 Imm = MI->getOperand(1).getImm(); 76 break; 77 } 78 } 79 return Imm; 80 }; 81 RowImm = GetImm(Row->getReg()); 82 ColImm = GetImm(Col->getReg()); 83 } 84 85 private: 86 static constexpr int64_t InvalidImmShape = -1; 87 MachineOperand *Row; 88 MachineOperand *Col; 89 int64_t RowImm = -1; 90 int64_t ColImm = -1; 91 }; 92 93 } // namespace llvm 94 95 #endif 96