xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/MatchContext.h (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===---------------- llvm/CodeGen/MatchContext.h  --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the EmptyMatchContext class and VPMatchContext class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
14 #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
15 
16 #include "llvm/CodeGen/SelectionDAG.h"
17 #include "llvm/CodeGen/TargetLowering.h"
18 
19 using namespace llvm;
20 
21 namespace {
22 class EmptyMatchContext {
23   SelectionDAG &DAG;
24   const TargetLowering &TLI;
25   SDNode *Root;
26 
27 public:
28   EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
29       : DAG(DAG), TLI(TLI), Root(Root) {}
30 
31   unsigned getRootBaseOpcode() { return Root->getOpcode(); }
32   bool match(SDValue OpN, unsigned Opcode) const {
33     return Opcode == OpN->getOpcode();
34   }
35 
36   // Same as SelectionDAG::getNode().
37   template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
38     return DAG.getNode(std::forward<ArgT>(Args)...);
39   }
40 
41   bool isOperationLegal(unsigned Op, EVT VT) const {
42     return TLI.isOperationLegal(Op, VT);
43   }
44 
45   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
46                                 bool LegalOnly = false) const {
47     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
48   }
49 };
50 
51 class VPMatchContext {
52   SelectionDAG &DAG;
53   const TargetLowering &TLI;
54   SDValue RootMaskOp;
55   SDValue RootVectorLenOp;
56   SDNode *Root;
57 
58 public:
59   VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root)
60       : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
61     Root = _Root;
62     assert(Root->isVPOpcode());
63     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
64       RootMaskOp = Root->getOperand(*RootMaskPos);
65     else if (Root->getOpcode() == ISD::VP_SELECT)
66       RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
67                                           Root->getOperand(0).getValueType());
68 
69     if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
70       RootVectorLenOp = Root->getOperand(*RootVLenPos);
71   }
72 
73   unsigned getRootBaseOpcode() {
74     std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP(
75         Root->getOpcode(), !Root->getFlags().hasNoFPExcept());
76     assert(Opcode.has_value());
77     return *Opcode;
78   }
79 
80   /// whether \p OpVal is a node that is functionally compatible with the
81   /// NodeType \p Opc
82   bool match(SDValue OpVal, unsigned Opc) const {
83     if (!OpVal->isVPOpcode())
84       return OpVal->getOpcode() == Opc;
85 
86     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
87                                            !OpVal->getFlags().hasNoFPExcept());
88     if (BaseOpc != Opc)
89       return false;
90 
91     // Make sure the mask of OpVal is true mask or is same as Root's.
92     unsigned VPOpcode = OpVal->getOpcode();
93     if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
94       SDValue MaskOp = OpVal.getOperand(*MaskPos);
95       if (RootMaskOp != MaskOp &&
96           !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
97         return false;
98     }
99 
100     // Make sure the EVL of OpVal is same as Root's.
101     if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
102       if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
103         return false;
104     return true;
105   }
106 
107   // Specialize based on number of operands.
108   // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
109   // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
110   // DAG.getNode(Opcode, DL, VT); }
111   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
112     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
113     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
114            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
115     return DAG.getNode(VPOpcode, DL, VT,
116                        {Operand, RootMaskOp, RootVectorLenOp});
117   }
118 
119   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
120                   SDValue N2) {
121     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
122     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
123            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
124     return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp});
125   }
126 
127   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
128                   SDValue N2, SDValue N3) {
129     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
130     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
131            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
132     return DAG.getNode(VPOpcode, DL, VT,
133                        {N1, N2, N3, RootMaskOp, RootVectorLenOp});
134   }
135 
136   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
137                   SDNodeFlags Flags) {
138     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
139     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
140            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
141     return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
142                        Flags);
143   }
144 
145   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
146                   SDValue N2, SDNodeFlags Flags) {
147     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
148     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
149            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
150     return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
151                        Flags);
152   }
153 
154   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
155                   SDValue N2, SDValue N3, SDNodeFlags Flags) {
156     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
157     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
158            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
159     return DAG.getNode(VPOpcode, DL, VT,
160                        {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
161   }
162 
163   bool isOperationLegal(unsigned Op, EVT VT) const {
164     unsigned VPOp = ISD::getVPForBaseOpcode(Op);
165     return TLI.isOperationLegal(VPOp, VT);
166   }
167 
168   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
169                                 bool LegalOnly = false) const {
170     unsigned VPOp = ISD::getVPForBaseOpcode(Op);
171     return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
172   }
173 };
174 } // end anonymous namespace
175 #endif
176