xref: /freebsd/contrib/llvm-project/llvm/include/llvm/CodeGen/TileShapeInfo.h (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
1e8d8bef9SDimitry Andric //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2e8d8bef9SDimitry Andric //
3e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e8d8bef9SDimitry Andric //
7e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
8e8d8bef9SDimitry Andric //
9e8d8bef9SDimitry Andric /// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row count and a
/// column count. In the AMX intrinsics interface the shape is passed as the
/// 1st and 2nd parameters, which are lowered to the 1st and 2nd machine
/// operands of the AMX pseudo instructions. The ShapeT class facilitates tile
/// configuration and register allocation; its row and column are those
/// machine operands of the AMX pseudo instructions.
16e8d8bef9SDimitry Andric //
17e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
18e8d8bef9SDimitry Andric 
19e8d8bef9SDimitry Andric #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20e8d8bef9SDimitry Andric #define LLVM_CODEGEN_TILESHAPEINFO_H
21e8d8bef9SDimitry Andric 
22e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineInstr.h"
23e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineOperand.h"
24e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h"
25e8d8bef9SDimitry Andric #include "llvm/CodeGen/Register.h"
26e8d8bef9SDimitry Andric 
27e8d8bef9SDimitry Andric namespace llvm {
28e8d8bef9SDimitry Andric 
29e8d8bef9SDimitry Andric class ShapeT {
30e8d8bef9SDimitry Andric public:
31e8d8bef9SDimitry Andric   ShapeT(MachineOperand *Row, MachineOperand *Col,
32e8d8bef9SDimitry Andric          const MachineRegisterInfo *MRI = nullptr)
33e8d8bef9SDimitry Andric       : Row(Row), Col(Col) {
34e8d8bef9SDimitry Andric     if (MRI)
35e8d8bef9SDimitry Andric       deduceImm(MRI);
36e8d8bef9SDimitry Andric   }
37e8d8bef9SDimitry Andric   ShapeT()
38e8d8bef9SDimitry Andric       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
39e8d8bef9SDimitry Andric         ColImm(InvalidImmShape) {}
4081ad6265SDimitry Andric   bool operator==(const ShapeT &Shape) const {
41e8d8bef9SDimitry Andric     MachineOperand *R = Shape.Row;
42e8d8bef9SDimitry Andric     MachineOperand *C = Shape.Col;
43e8d8bef9SDimitry Andric     if (!R || !C)
44e8d8bef9SDimitry Andric       return false;
45e8d8bef9SDimitry Andric     if (!Row || !Col)
46e8d8bef9SDimitry Andric       return false;
47e8d8bef9SDimitry Andric     if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
48e8d8bef9SDimitry Andric       return true;
49e8d8bef9SDimitry Andric     if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
50e8d8bef9SDimitry Andric       return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
51e8d8bef9SDimitry Andric     return false;
52e8d8bef9SDimitry Andric   }
53e8d8bef9SDimitry Andric 
5481ad6265SDimitry Andric   bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
55e8d8bef9SDimitry Andric 
56e8d8bef9SDimitry Andric   MachineOperand *getRow() const { return Row; }
57e8d8bef9SDimitry Andric 
58e8d8bef9SDimitry Andric   MachineOperand *getCol() const { return Col; }
59e8d8bef9SDimitry Andric 
60e8d8bef9SDimitry Andric   int64_t getRowImm() const { return RowImm; }
61e8d8bef9SDimitry Andric 
62e8d8bef9SDimitry Andric   int64_t getColImm() const { return ColImm; }
63e8d8bef9SDimitry Andric 
64e8d8bef9SDimitry Andric   bool isValid() { return (Row != nullptr) && (Col != nullptr); }
65e8d8bef9SDimitry Andric 
66e8d8bef9SDimitry Andric   void deduceImm(const MachineRegisterInfo *MRI) {
67e8d8bef9SDimitry Andric     // All def must be the same value, otherwise it is invalid MIs.
68e8d8bef9SDimitry Andric     // Find the immediate.
69e8d8bef9SDimitry Andric     // TODO copy propagation.
70e8d8bef9SDimitry Andric     auto GetImm = [&](Register Reg) {
71e8d8bef9SDimitry Andric       int64_t Imm = InvalidImmShape;
72e8d8bef9SDimitry Andric       for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
73e8d8bef9SDimitry Andric         const auto *MI = DefMO.getParent();
74e8d8bef9SDimitry Andric         if (MI->isMoveImmediate()) {
75e8d8bef9SDimitry Andric           Imm = MI->getOperand(1).getImm();
76e8d8bef9SDimitry Andric           break;
77e8d8bef9SDimitry Andric         }
78e8d8bef9SDimitry Andric       }
79e8d8bef9SDimitry Andric       return Imm;
80e8d8bef9SDimitry Andric     };
81e8d8bef9SDimitry Andric     RowImm = GetImm(Row->getReg());
82e8d8bef9SDimitry Andric     ColImm = GetImm(Col->getReg());
83e8d8bef9SDimitry Andric   }
84e8d8bef9SDimitry Andric 
85e8d8bef9SDimitry Andric private:
86e8d8bef9SDimitry Andric   static constexpr int64_t InvalidImmShape = -1;
87e8d8bef9SDimitry Andric   MachineOperand *Row;
88e8d8bef9SDimitry Andric   MachineOperand *Col;
89*06c3fb27SDimitry Andric   int64_t RowImm = -1;
90*06c3fb27SDimitry Andric   int64_t ColImm = -1;
91e8d8bef9SDimitry Andric };
92e8d8bef9SDimitry Andric 
93e8d8bef9SDimitry Andric } // namespace llvm
94e8d8bef9SDimitry Andric 
95e8d8bef9SDimitry Andric #endif
96