1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Shape utility for AMX. 10 /// AMX hardware requires to config the shape of tile data register before use. 11 /// The 2D shape includes row and column. In AMX intrinsics interface the shape 12 /// is passed as 1st and 2nd parameter and they are lowered as the 1st and 2nd 13 /// machine operand of AMX pseudo instructions. ShapeT class is to facilitate 14 /// tile config and register allocator. The row and column are machine operand 15 /// of AMX pseudo instructions. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H 20 #define LLVM_CODEGEN_TILESHAPEINFO_H 21 22 #include "llvm/CodeGen/MachineInstr.h" 23 #include "llvm/CodeGen/MachineOperand.h" 24 #include "llvm/CodeGen/MachineRegisterInfo.h" 25 #include "llvm/CodeGen/Register.h" 26 27 namespace llvm { 28 29 class ShapeT { 30 public: 31 ShapeT(MachineOperand *Row, MachineOperand *Col, 32 const MachineRegisterInfo *MRI = nullptr) Row(Row)33 : Row(Row), Col(Col) { 34 if (MRI) 35 deduceImm(MRI); 36 } 37 // When ShapeT has multiple shapes, we only use Shapes (never use Row and Col) 38 // and ImmShapes. Due to the most case is only one shape (just simply use 39 // Shape.Row or Shape.Col), so here we don't merge Row and Col into vector 40 // Shapes to keep the speed and code simplicity. 41 // TODO: The upper solution is a temporary way to minimize current tile 42 // register allocation code changes. It can not handle both Reg shape and 43 // Imm shape for different shapes (e.g. shape 1 is reg shape while shape 2 44 // is imm shape). Refine me when we have more multi-tile shape instructions! 45 ShapeT(ArrayRef<MachineOperand *> ShapesOperands, 46 const MachineRegisterInfo *MRI = nullptr) Row(nullptr)47 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 48 ColImm(InvalidImmShape) { 49 assert(ShapesOperands.size() % 2 == 0 && "Miss row or col!"); 50 51 llvm::append_range(Shapes, ShapesOperands); 52 53 if (MRI) 54 deduceImm(MRI); 55 } ShapeT()56 ShapeT() 57 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 58 ColImm(InvalidImmShape) {} 59 // TODO: We need to extern cmp operator for multi-shapes if 60 // we have requirement in the future. 61 bool operator==(const ShapeT &Shape) const { 62 MachineOperand *R = Shape.Row; 63 MachineOperand *C = Shape.Col; 64 if (!R || !C) 65 return false; 66 if (!Row || !Col) 67 return false; 68 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) 69 return true; 70 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) 71 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); 72 return false; 73 } 74 75 bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); } 76 77 MachineOperand *getRow(unsigned I = 0) const { 78 if (Shapes.empty()) 79 return Row; 80 assert(Shapes.size() / 2 >= I && "Get invalid row from id!"); 81 return Shapes[I * 2]; 82 } 83 84 MachineOperand *getCol(unsigned I = 0) const { 85 if (Shapes.empty()) 86 return Col; 87 assert(Shapes.size() / 2 >= I && "Get invalid col from id!"); 88 return Shapes[I * 2 + 1]; 89 } 90 91 int64_t getRowImm(unsigned I = 0) const { 92 if (ImmShapes.empty()) 93 return RowImm; 94 assert(ImmShapes.size() / 2 >= I && "Get invalid imm row from id!"); 95 return ImmShapes[I * 2]; 96 } 97 98 int64_t getColImm(unsigned I = 0) const { 99 if (ImmShapes.empty()) 100 return ColImm; 101 assert(ImmShapes.size() / 2 >= I && "Get invalid imm col from id!"); 102 return ImmShapes[I * 2 + 1]; 103 } 104 getShapeNum()105 unsigned getShapeNum() { 106 if (Shapes.empty()) 107 return isValid() ? 1 : 0; 108 else 109 return Shapes.size() / 2; 110 } 111 isValid()112 bool isValid() { return (Row != nullptr) && (Col != nullptr); } 113 deduceImm(const MachineRegisterInfo * MRI)114 void deduceImm(const MachineRegisterInfo *MRI) { 115 // All def must be the same value, otherwise it is invalid MIs. 116 // Find the immediate. 117 // TODO copy propagation. 118 auto GetImm = [&](Register Reg) { 119 int64_t Imm = InvalidImmShape; 120 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { 121 const auto *MI = DefMO.getParent(); 122 if (MI->isMoveImmediate()) { 123 assert(MI->getNumOperands() == 2 && 124 "Unsupported number of operands in instruction for setting " 125 "row/column."); 126 if (MI->getOperand(1).isImm()) { 127 Imm = MI->getOperand(1).getImm(); 128 } else { 129 assert(MI->getOperand(1).isImplicit() && 130 "Operand 1 is assumed to be implicit."); 131 Imm = 0; 132 } 133 break; 134 } 135 } 136 return Imm; 137 }; 138 if (Shapes.empty()) { // Single Shape 139 RowImm = GetImm(Row->getReg()); 140 ColImm = GetImm(Col->getReg()); 141 // The number of rows of 2nd destination buffer is assigned by the one of 142 // 1st destination buffer. If the column size is equal to zero, the row 143 // size should be reset to zero too. 144 if (ColImm == 0) 145 Row = Col; 146 } else { // Multiple Shapes 147 for (auto *Shape : Shapes) { 148 int64_t ImmShape = GetImm(Shape->getReg()); 149 ImmShapes.push_back(ImmShape); 150 } 151 } 152 } 153 154 private: 155 static constexpr int64_t InvalidImmShape = -1; 156 MachineOperand *Row; 157 MachineOperand *Col; 158 int64_t RowImm = -1; 159 int64_t ColImm = -1; 160 // Multiple Shapes 161 SmallVector<MachineOperand *, 0> Shapes; 162 SmallVector<int64_t, 0> ImmShapes; 163 }; 164 165 } // namespace llvm 166 167 #endif 168