xref: /freebsd/contrib/llvm-project/llvm/include/llvm/CodeGen/TileShapeInfo.h (revision e1e636193db45630c7881246d25902e57c43d24e)
1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row count and a
/// column count. In the AMX intrinsics interface the shape is passed as the
/// 1st and 2nd parameters, which are lowered to the 1st and 2nd machine
/// operands of the AMX pseudo instructions. The ShapeT class facilitates tile
/// configuration and register allocation; its row and column are those
/// machine operands of the AMX pseudo instructions.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20 #define LLVM_CODEGEN_TILESHAPEINFO_H
21 
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineOperand.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/Register.h"
26 
27 namespace llvm {
28 
29 class ShapeT {
30 public:
31   ShapeT(MachineOperand *Row, MachineOperand *Col,
32          const MachineRegisterInfo *MRI = nullptr)
33       : Row(Row), Col(Col) {
34     if (MRI)
35       deduceImm(MRI);
36   }
37   ShapeT()
38       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
39         ColImm(InvalidImmShape) {}
40   bool operator==(const ShapeT &Shape) const {
41     MachineOperand *R = Shape.Row;
42     MachineOperand *C = Shape.Col;
43     if (!R || !C)
44       return false;
45     if (!Row || !Col)
46       return false;
47     if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
48       return true;
49     if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
50       return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
51     return false;
52   }
53 
54   bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
55 
56   MachineOperand *getRow() const { return Row; }
57 
58   MachineOperand *getCol() const { return Col; }
59 
60   int64_t getRowImm() const { return RowImm; }
61 
62   int64_t getColImm() const { return ColImm; }
63 
64   bool isValid() { return (Row != nullptr) && (Col != nullptr); }
65 
66   void deduceImm(const MachineRegisterInfo *MRI) {
67     // All def must be the same value, otherwise it is invalid MIs.
68     // Find the immediate.
69     // TODO copy propagation.
70     auto GetImm = [&](Register Reg) {
71       int64_t Imm = InvalidImmShape;
72       for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
73         const auto *MI = DefMO.getParent();
74         if (MI->isMoveImmediate()) {
75           Imm = MI->getOperand(1).getImm();
76           break;
77         }
78       }
79       return Imm;
80     };
81     RowImm = GetImm(Row->getReg());
82     ColImm = GetImm(Col->getReg());
83   }
84 
85 private:
86   static constexpr int64_t InvalidImmShape = -1;
87   MachineOperand *Row;
88   MachineOperand *Col;
89   int64_t RowImm = -1;
90   int64_t ColImm = -1;
91 };
92 
93 } // namespace llvm
94 
95 #endif
96