1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Shape utility for AMX. 10 /// AMX hardware requires to config the shape of tile data register before use. 11 /// The 2D shape includes row and column. In AMX intrinsics interface the shape 12 /// is passed as 1st and 2nd parameter and they are lowered as the 1st and 2nd 13 /// machine operand of AMX pseudo instructions. ShapeT class is to facilitate 14 /// tile config and register allocator. The row and column are machine operand 15 /// of AMX pseudo instructions. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H 20 #define LLVM_CODEGEN_TILESHAPEINFO_H 21 22 #include "llvm/ADT/DenseMapInfo.h" 23 #include "llvm/CodeGen/MachineInstr.h" 24 #include "llvm/CodeGen/MachineOperand.h" 25 #include "llvm/CodeGen/MachineRegisterInfo.h" 26 #include "llvm/CodeGen/Register.h" 27 #include <utility> 28 29 namespace llvm { 30 31 class ShapeT { 32 public: 33 ShapeT(MachineOperand *Row, MachineOperand *Col, 34 const MachineRegisterInfo *MRI = nullptr) 35 : Row(Row), Col(Col) { 36 if (MRI) 37 deduceImm(MRI); 38 } 39 ShapeT() 40 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 41 ColImm(InvalidImmShape) {} 42 bool operator==(const ShapeT &Shape) { 43 MachineOperand *R = Shape.Row; 44 MachineOperand *C = Shape.Col; 45 if (!R || !C) 46 return false; 47 if (!Row || !Col) 48 return false; 49 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) 50 return true; 51 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) 52 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); 53 return false; 54 } 55 56 bool operator!=(const ShapeT &Shape) { return !(*this == Shape); } 57 58 MachineOperand *getRow() const { return Row; } 59 60 MachineOperand *getCol() const { return Col; } 61 62 int64_t getRowImm() const { return RowImm; } 63 64 int64_t getColImm() const { return ColImm; } 65 66 bool isValid() { return (Row != nullptr) && (Col != nullptr); } 67 68 void deduceImm(const MachineRegisterInfo *MRI) { 69 // All def must be the same value, otherwise it is invalid MIs. 70 // Find the immediate. 71 // TODO copy propagation. 72 auto GetImm = [&](Register Reg) { 73 int64_t Imm = InvalidImmShape; 74 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { 75 const auto *MI = DefMO.getParent(); 76 if (MI->isMoveImmediate()) { 77 Imm = MI->getOperand(1).getImm(); 78 break; 79 } 80 } 81 return Imm; 82 }; 83 RowImm = GetImm(Row->getReg()); 84 ColImm = GetImm(Col->getReg()); 85 } 86 87 private: 88 static constexpr int64_t InvalidImmShape = -1; 89 MachineOperand *Row; 90 MachineOperand *Col; 91 int64_t RowImm; 92 int64_t ColImm; 93 }; 94 95 } // namespace llvm 96 97 #endif 98