1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Shape utility for AMX. 10 /// AMX hardware requires to config the shape of tile data register before use. 11 /// The 2D shape includes row and column. In AMX intrinsics interface the shape 12 /// is passed as 1st and 2nd parameter and they are lowered as the 1st and 2nd 13 /// machine operand of AMX pseudo instructions. ShapeT class is to facilitate 14 /// tile config and register allocator. The row and column are machine operand 15 /// of AMX pseudo instructions. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H 20 #define LLVM_CODEGEN_TILESHAPEINFO_H 21 22 #include "llvm/ADT/DenseMapInfo.h" 23 #include "llvm/CodeGen/MachineInstr.h" 24 #include "llvm/CodeGen/MachineOperand.h" 25 #include "llvm/CodeGen/MachineRegisterInfo.h" 26 #include "llvm/CodeGen/Register.h" 27 28 namespace llvm { 29 30 class ShapeT { 31 public: 32 ShapeT(MachineOperand *Row, MachineOperand *Col, 33 const MachineRegisterInfo *MRI = nullptr) 34 : Row(Row), Col(Col) { 35 if (MRI) 36 deduceImm(MRI); 37 } 38 ShapeT() 39 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 40 ColImm(InvalidImmShape) {} 41 bool operator==(const ShapeT &Shape) { 42 MachineOperand *R = Shape.Row; 43 MachineOperand *C = Shape.Col; 44 if (!R || !C) 45 return false; 46 if (!Row || !Col) 47 return false; 48 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) 49 return true; 50 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) 51 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); 52 return false; 53 } 54 55 bool operator!=(const ShapeT &Shape) { return !(*this == Shape); } 56 57 MachineOperand *getRow() const { return Row; } 58 59 MachineOperand *getCol() const { return Col; } 60 61 int64_t getRowImm() const { return RowImm; } 62 63 int64_t getColImm() const { return ColImm; } 64 65 bool isValid() { return (Row != nullptr) && (Col != nullptr); } 66 67 void deduceImm(const MachineRegisterInfo *MRI) { 68 // All def must be the same value, otherwise it is invalid MIs. 69 // Find the immediate. 70 // TODO copy propagation. 71 auto GetImm = [&](Register Reg) { 72 int64_t Imm = InvalidImmShape; 73 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { 74 const auto *MI = DefMO.getParent(); 75 if (MI->isMoveImmediate()) { 76 Imm = MI->getOperand(1).getImm(); 77 break; 78 } 79 } 80 return Imm; 81 }; 82 RowImm = GetImm(Row->getReg()); 83 ColImm = GetImm(Col->getReg()); 84 } 85 86 private: 87 static constexpr int64_t InvalidImmShape = -1; 88 MachineOperand *Row; 89 MachineOperand *Col; 90 int64_t RowImm; 91 int64_t ColImm; 92 }; 93 94 } // namespace llvm 95 96 #endif 97