//===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row and a column.
/// In the AMX intrinsics interface the shape is passed as the 1st and 2nd
/// parameters, which are lowered to the 1st and 2nd machine operands of the
/// AMX pseudo instructions. The ShapeT class facilitates tile configuration
/// and register allocation. The row and column are machine operands of the
/// AMX pseudo instructions.
16e8d8bef9SDimitry Andric // 17e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===// 18e8d8bef9SDimitry Andric 19e8d8bef9SDimitry Andric #ifndef LLVM_CODEGEN_TILESHAPEINFO_H 20e8d8bef9SDimitry Andric #define LLVM_CODEGEN_TILESHAPEINFO_H 21e8d8bef9SDimitry Andric 22e8d8bef9SDimitry Andric #include "llvm/ADT/DenseMapInfo.h" 23e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineInstr.h" 24e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineOperand.h" 25e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h" 26e8d8bef9SDimitry Andric #include "llvm/CodeGen/Register.h" 27e8d8bef9SDimitry Andric 28e8d8bef9SDimitry Andric namespace llvm { 29e8d8bef9SDimitry Andric 30e8d8bef9SDimitry Andric class ShapeT { 31e8d8bef9SDimitry Andric public: 32e8d8bef9SDimitry Andric ShapeT(MachineOperand *Row, MachineOperand *Col, 33e8d8bef9SDimitry Andric const MachineRegisterInfo *MRI = nullptr) 34e8d8bef9SDimitry Andric : Row(Row), Col(Col) { 35e8d8bef9SDimitry Andric if (MRI) 36e8d8bef9SDimitry Andric deduceImm(MRI); 37e8d8bef9SDimitry Andric } 38e8d8bef9SDimitry Andric ShapeT() 39e8d8bef9SDimitry Andric : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 40e8d8bef9SDimitry Andric ColImm(InvalidImmShape) {} 41*81ad6265SDimitry Andric bool operator==(const ShapeT &Shape) const { 42e8d8bef9SDimitry Andric MachineOperand *R = Shape.Row; 43e8d8bef9SDimitry Andric MachineOperand *C = Shape.Col; 44e8d8bef9SDimitry Andric if (!R || !C) 45e8d8bef9SDimitry Andric return false; 46e8d8bef9SDimitry Andric if (!Row || !Col) 47e8d8bef9SDimitry Andric return false; 48e8d8bef9SDimitry Andric if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) 49e8d8bef9SDimitry Andric return true; 50e8d8bef9SDimitry Andric if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) 51e8d8bef9SDimitry Andric return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); 52e8d8bef9SDimitry Andric return false; 53e8d8bef9SDimitry 
Andric } 54e8d8bef9SDimitry Andric 55*81ad6265SDimitry Andric bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); } 56e8d8bef9SDimitry Andric 57e8d8bef9SDimitry Andric MachineOperand *getRow() const { return Row; } 58e8d8bef9SDimitry Andric 59e8d8bef9SDimitry Andric MachineOperand *getCol() const { return Col; } 60e8d8bef9SDimitry Andric 61e8d8bef9SDimitry Andric int64_t getRowImm() const { return RowImm; } 62e8d8bef9SDimitry Andric 63e8d8bef9SDimitry Andric int64_t getColImm() const { return ColImm; } 64e8d8bef9SDimitry Andric 65e8d8bef9SDimitry Andric bool isValid() { return (Row != nullptr) && (Col != nullptr); } 66e8d8bef9SDimitry Andric 67e8d8bef9SDimitry Andric void deduceImm(const MachineRegisterInfo *MRI) { 68e8d8bef9SDimitry Andric // All def must be the same value, otherwise it is invalid MIs. 69e8d8bef9SDimitry Andric // Find the immediate. 70e8d8bef9SDimitry Andric // TODO copy propagation. 71e8d8bef9SDimitry Andric auto GetImm = [&](Register Reg) { 72e8d8bef9SDimitry Andric int64_t Imm = InvalidImmShape; 73e8d8bef9SDimitry Andric for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { 74e8d8bef9SDimitry Andric const auto *MI = DefMO.getParent(); 75e8d8bef9SDimitry Andric if (MI->isMoveImmediate()) { 76e8d8bef9SDimitry Andric Imm = MI->getOperand(1).getImm(); 77e8d8bef9SDimitry Andric break; 78e8d8bef9SDimitry Andric } 79e8d8bef9SDimitry Andric } 80e8d8bef9SDimitry Andric return Imm; 81e8d8bef9SDimitry Andric }; 82e8d8bef9SDimitry Andric RowImm = GetImm(Row->getReg()); 83e8d8bef9SDimitry Andric ColImm = GetImm(Col->getReg()); 84e8d8bef9SDimitry Andric } 85e8d8bef9SDimitry Andric 86e8d8bef9SDimitry Andric private: 87e8d8bef9SDimitry Andric static constexpr int64_t InvalidImmShape = -1; 88e8d8bef9SDimitry Andric MachineOperand *Row; 89e8d8bef9SDimitry Andric MachineOperand *Col; 90e8d8bef9SDimitry Andric int64_t RowImm; 91e8d8bef9SDimitry Andric int64_t ColImm; 92e8d8bef9SDimitry Andric }; 93e8d8bef9SDimitry Andric 
94e8d8bef9SDimitry Andric } // namespace llvm 95e8d8bef9SDimitry Andric 96e8d8bef9SDimitry Andric #endif 97