xref: /freebsd/contrib/llvm-project/llvm/include/llvm/CodeGen/TileShapeInfo.h (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1e8d8bef9SDimitry Andric //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2e8d8bef9SDimitry Andric //
3e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e8d8bef9SDimitry Andric //
7e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
8e8d8bef9SDimitry Andric //
9e8d8bef9SDimitry Andric /// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row count and a
/// column count. In the AMX intrinsics interface the shape is passed as the 1st
/// and 2nd parameters, which are lowered to the 1st and 2nd machine operands of
/// the AMX pseudo instructions. The ShapeT class wraps those row and column
/// machine operands to facilitate tile configuration and register allocation.
16e8d8bef9SDimitry Andric //
17e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
18e8d8bef9SDimitry Andric 
19e8d8bef9SDimitry Andric #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20e8d8bef9SDimitry Andric #define LLVM_CODEGEN_TILESHAPEINFO_H
21e8d8bef9SDimitry Andric 
22e8d8bef9SDimitry Andric #include "llvm/ADT/DenseMapInfo.h"
23e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineInstr.h"
24e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineOperand.h"
25e8d8bef9SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h"
26e8d8bef9SDimitry Andric #include "llvm/CodeGen/Register.h"
27e8d8bef9SDimitry Andric 
28e8d8bef9SDimitry Andric namespace llvm {
29e8d8bef9SDimitry Andric 
30e8d8bef9SDimitry Andric class ShapeT {
31e8d8bef9SDimitry Andric public:
32e8d8bef9SDimitry Andric   ShapeT(MachineOperand *Row, MachineOperand *Col,
33e8d8bef9SDimitry Andric          const MachineRegisterInfo *MRI = nullptr)
34e8d8bef9SDimitry Andric       : Row(Row), Col(Col) {
35e8d8bef9SDimitry Andric     if (MRI)
36e8d8bef9SDimitry Andric       deduceImm(MRI);
37e8d8bef9SDimitry Andric   }
38e8d8bef9SDimitry Andric   ShapeT()
39e8d8bef9SDimitry Andric       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
40e8d8bef9SDimitry Andric         ColImm(InvalidImmShape) {}
41*81ad6265SDimitry Andric   bool operator==(const ShapeT &Shape) const {
42e8d8bef9SDimitry Andric     MachineOperand *R = Shape.Row;
43e8d8bef9SDimitry Andric     MachineOperand *C = Shape.Col;
44e8d8bef9SDimitry Andric     if (!R || !C)
45e8d8bef9SDimitry Andric       return false;
46e8d8bef9SDimitry Andric     if (!Row || !Col)
47e8d8bef9SDimitry Andric       return false;
48e8d8bef9SDimitry Andric     if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
49e8d8bef9SDimitry Andric       return true;
50e8d8bef9SDimitry Andric     if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
51e8d8bef9SDimitry Andric       return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
52e8d8bef9SDimitry Andric     return false;
53e8d8bef9SDimitry Andric   }
54e8d8bef9SDimitry Andric 
55*81ad6265SDimitry Andric   bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
56e8d8bef9SDimitry Andric 
57e8d8bef9SDimitry Andric   MachineOperand *getRow() const { return Row; }
58e8d8bef9SDimitry Andric 
59e8d8bef9SDimitry Andric   MachineOperand *getCol() const { return Col; }
60e8d8bef9SDimitry Andric 
61e8d8bef9SDimitry Andric   int64_t getRowImm() const { return RowImm; }
62e8d8bef9SDimitry Andric 
63e8d8bef9SDimitry Andric   int64_t getColImm() const { return ColImm; }
64e8d8bef9SDimitry Andric 
65e8d8bef9SDimitry Andric   bool isValid() { return (Row != nullptr) && (Col != nullptr); }
66e8d8bef9SDimitry Andric 
67e8d8bef9SDimitry Andric   void deduceImm(const MachineRegisterInfo *MRI) {
68e8d8bef9SDimitry Andric     // All def must be the same value, otherwise it is invalid MIs.
69e8d8bef9SDimitry Andric     // Find the immediate.
70e8d8bef9SDimitry Andric     // TODO copy propagation.
71e8d8bef9SDimitry Andric     auto GetImm = [&](Register Reg) {
72e8d8bef9SDimitry Andric       int64_t Imm = InvalidImmShape;
73e8d8bef9SDimitry Andric       for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
74e8d8bef9SDimitry Andric         const auto *MI = DefMO.getParent();
75e8d8bef9SDimitry Andric         if (MI->isMoveImmediate()) {
76e8d8bef9SDimitry Andric           Imm = MI->getOperand(1).getImm();
77e8d8bef9SDimitry Andric           break;
78e8d8bef9SDimitry Andric         }
79e8d8bef9SDimitry Andric       }
80e8d8bef9SDimitry Andric       return Imm;
81e8d8bef9SDimitry Andric     };
82e8d8bef9SDimitry Andric     RowImm = GetImm(Row->getReg());
83e8d8bef9SDimitry Andric     ColImm = GetImm(Col->getReg());
84e8d8bef9SDimitry Andric   }
85e8d8bef9SDimitry Andric 
86e8d8bef9SDimitry Andric private:
87e8d8bef9SDimitry Andric   static constexpr int64_t InvalidImmShape = -1;
88e8d8bef9SDimitry Andric   MachineOperand *Row;
89e8d8bef9SDimitry Andric   MachineOperand *Col;
90e8d8bef9SDimitry Andric   int64_t RowImm;
91e8d8bef9SDimitry Andric   int64_t ColImm;
92e8d8bef9SDimitry Andric };
93e8d8bef9SDimitry Andric 
94e8d8bef9SDimitry Andric } // namespace llvm
95e8d8bef9SDimitry Andric 
96e8d8bef9SDimitry Andric #endif
97