xref: /freebsd/contrib/llvm-project/llvm/include/llvm/CodeGen/TileShapeInfo.h (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row count and a
/// column count. In the AMX intrinsics interface the shape is passed as the
/// 1st and 2nd parameters, which are lowered to the 1st and 2nd machine
/// operands of the AMX pseudo instructions. The ShapeT class supports tile
/// configuration and register allocation; its row and column are those machine
/// operands of the AMX pseudo instructions.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20 #define LLVM_CODEGEN_TILESHAPEINFO_H
21 
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineOperand.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/Register.h"
26 
27 namespace llvm {
28 
29 class ShapeT {
30 public:
31   ShapeT(MachineOperand *Row, MachineOperand *Col,
32          const MachineRegisterInfo *MRI = nullptr)
33       : Row(Row), Col(Col) {
34     if (MRI)
35       deduceImm(MRI);
36   }
37   // When ShapeT has multiple shapes, we only use Shapes (never use Row and Col)
38   // and ImmShapes. Due to the most case is only one shape (just simply use
39   // Shape.Row or Shape.Col), so here we don't merge Row and Col into vector
40   // Shapes to keep the speed and code simplicity.
41   // TODO: The upper solution is a temporary way to minimize current tile
42   // register allocation code changes. It can not handle both Reg shape and
43   // Imm shape for different shapes (e.g. shape 1 is reg shape while shape 2
44   // is imm shape). Refine me when we have more multi-tile shape instructions!
45   ShapeT(ArrayRef<MachineOperand *> ShapesOperands,
46          const MachineRegisterInfo *MRI = nullptr)
47       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
48         ColImm(InvalidImmShape) {
49     assert(ShapesOperands.size() % 2 == 0 && "Miss row or col!");
50 
51     llvm::append_range(Shapes, ShapesOperands);
52 
53     if (MRI)
54       deduceImm(MRI);
55   }
56   ShapeT()
57       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
58         ColImm(InvalidImmShape) {}
59   // TODO: We need to extern cmp operator for multi-shapes if
60   // we have requirement in the future.
61   bool operator==(const ShapeT &Shape) const {
62     MachineOperand *R = Shape.Row;
63     MachineOperand *C = Shape.Col;
64     if (!R || !C)
65       return false;
66     if (!Row || !Col)
67       return false;
68     if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
69       return true;
70     if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
71       return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
72     return false;
73   }
74 
75   bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
76 
77   MachineOperand *getRow(unsigned I = 0) const {
78     if (Shapes.empty())
79       return Row;
80     assert(Shapes.size() / 2 >= I && "Get invalid row from id!");
81     return Shapes[I * 2];
82   }
83 
84   MachineOperand *getCol(unsigned I = 0) const {
85     if (Shapes.empty())
86       return Col;
87     assert(Shapes.size() / 2 >= I && "Get invalid col from id!");
88     return Shapes[I * 2 + 1];
89   }
90 
91   int64_t getRowImm(unsigned I = 0) const {
92     if (ImmShapes.empty())
93       return RowImm;
94     assert(ImmShapes.size() / 2 >= I && "Get invalid imm row from id!");
95     return ImmShapes[I * 2];
96   }
97 
98   int64_t getColImm(unsigned I = 0) const {
99     if (ImmShapes.empty())
100       return ColImm;
101     assert(ImmShapes.size() / 2 >= I && "Get invalid imm col from id!");
102     return ImmShapes[I * 2 + 1];
103   }
104 
105   unsigned getShapeNum() {
106     if (Shapes.empty())
107       return isValid() ? 1 : 0;
108     else
109       return Shapes.size() / 2;
110   }
111 
112   bool isValid() { return (Row != nullptr) && (Col != nullptr); }
113 
114   void deduceImm(const MachineRegisterInfo *MRI) {
115     // All def must be the same value, otherwise it is invalid MIs.
116     // Find the immediate.
117     // TODO copy propagation.
118     auto GetImm = [&](Register Reg) {
119       int64_t Imm = InvalidImmShape;
120       for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
121         const auto *MI = DefMO.getParent();
122         if (MI->isMoveImmediate()) {
123           assert(MI->getNumOperands() == 2 &&
124                  "Unsupported number of operands in instruction for setting "
125                  "row/column.");
126           if (MI->getOperand(1).isImm()) {
127             Imm = MI->getOperand(1).getImm();
128           } else {
129             assert(MI->getOperand(1).isImplicit() &&
130                    "Operand 1 is assumed to be implicit.");
131             Imm = 0;
132           }
133           break;
134         }
135       }
136       return Imm;
137     };
138     if (Shapes.empty()) { // Single Shape
139       RowImm = GetImm(Row->getReg());
140       ColImm = GetImm(Col->getReg());
141       // The number of rows of 2nd destination buffer is assigned by the one of
142       // 1st destination buffer. If the column size is equal to zero, the row
143       // size should be reset to zero too.
144       if (ColImm == 0)
145         Row = Col;
146     } else { // Multiple Shapes
147       for (auto *Shape : Shapes) {
148         int64_t ImmShape = GetImm(Shape->getReg());
149         ImmShapes.push_back(ImmShape);
150       }
151     }
152   }
153 
154 private:
155   static constexpr int64_t InvalidImmShape = -1;
156   MachineOperand *Row;
157   MachineOperand *Col;
158   int64_t RowImm = -1;
159   int64_t ColImm = -1;
160   // Multiple Shapes
161   SmallVector<MachineOperand *, 0> Shapes;
162   SmallVector<int64_t, 0> ImmShapes;
163 };
164 
165 } // namespace llvm
166 
167 #endif
168