1 //===---------------- llvm/CodeGen/MatchContext.h --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the EmptyMatchContext class and VPMatchContext class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H 14 #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H 15 16 #include "llvm/CodeGen/SelectionDAG.h" 17 #include "llvm/CodeGen/TargetLowering.h" 18 19 using namespace llvm; 20 21 namespace { 22 class EmptyMatchContext { 23 SelectionDAG &DAG; 24 const TargetLowering &TLI; 25 SDNode *Root; 26 27 public: 28 EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) 29 : DAG(DAG), TLI(TLI), Root(Root) {} 30 31 unsigned getRootBaseOpcode() { return Root->getOpcode(); } 32 bool match(SDValue OpN, unsigned Opcode) const { 33 return Opcode == OpN->getOpcode(); 34 } 35 36 // Same as SelectionDAG::getNode(). 37 template <typename... ArgT> SDValue getNode(ArgT &&...Args) { 38 return DAG.getNode(std::forward<ArgT>(Args)...); 39 } 40 41 bool isOperationLegal(unsigned Op, EVT VT) const { 42 return TLI.isOperationLegal(Op, VT); 43 } 44 45 bool isOperationLegalOrCustom(unsigned Op, EVT VT, 46 bool LegalOnly = false) const { 47 return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); 48 } 49 }; 50 51 class VPMatchContext { 52 SelectionDAG &DAG; 53 const TargetLowering &TLI; 54 SDValue RootMaskOp; 55 SDValue RootVectorLenOp; 56 SDNode *Root; 57 58 public: 59 VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root) 60 : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { 61 Root = _Root; 62 assert(Root->isVPOpcode()); 63 if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode())) 64 RootMaskOp = Root->getOperand(*RootMaskPos); 65 else if (Root->getOpcode() == ISD::VP_SELECT) 66 RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root), 67 Root->getOperand(0).getValueType()); 68 69 if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Root->getOpcode())) 70 RootVectorLenOp = Root->getOperand(*RootVLenPos); 71 } 72 73 unsigned getRootBaseOpcode() { 74 std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP( 75 Root->getOpcode(), !Root->getFlags().hasNoFPExcept()); 76 assert(Opcode.has_value()); 77 return *Opcode; 78 } 79 80 /// whether \p OpVal is a node that is functionally compatible with the 81 /// NodeType \p Opc 82 bool match(SDValue OpVal, unsigned Opc) const { 83 if (!OpVal->isVPOpcode()) 84 return OpVal->getOpcode() == Opc; 85 86 auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), 87 !OpVal->getFlags().hasNoFPExcept()); 88 if (BaseOpc != Opc) 89 return false; 90 91 // Make sure the mask of OpVal is true mask or is same as Root's. 92 unsigned VPOpcode = OpVal->getOpcode(); 93 if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) { 94 SDValue MaskOp = OpVal.getOperand(*MaskPos); 95 if (RootMaskOp != MaskOp && 96 !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode())) 97 return false; 98 } 99 100 // Make sure the EVL of OpVal is same as Root's. 101 if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode)) 102 if (RootVectorLenOp != OpVal.getOperand(*VLenPos)) 103 return false; 104 return true; 105 } 106 107 // Specialize based on number of operands. 108 // TODO emit VP intrinsics where MaskOp/VectorLenOp != null 109 // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return 110 // DAG.getNode(Opcode, DL, VT); } 111 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) { 112 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); 113 assert(ISD::getVPMaskIdx(VPOpcode) == 1 && 114 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); 115 return DAG.getNode(VPOpcode, DL, VT, 116 {Operand, RootMaskOp, RootVectorLenOp}); 117 } 118 119 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 120 SDValue N2) { 121 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); 122 assert(ISD::getVPMaskIdx(VPOpcode) == 2 && 123 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); 124 return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}); 125 } 126 127 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 128 SDValue N2, SDValue N3) { 129 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); 130 assert(ISD::getVPMaskIdx(VPOpcode) == 3 && 131 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); 132 return DAG.getNode(VPOpcode, DL, VT, 133 {N1, N2, N3, RootMaskOp, RootVectorLenOp}); 134 } 135 136 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, 137 SDNodeFlags Flags) { 138 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); 139 assert(ISD::getVPMaskIdx(VPOpcode) == 1 && 140 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); 141 return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, 142 Flags); 143 } 144 145 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 146 SDValue N2, SDNodeFlags Flags) { 147 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); 148 assert(ISD::getVPMaskIdx(VPOpcode) == 2 && 149 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); 150 return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, 151 Flags); 152 } 153 154 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, 155 SDValue N2, SDValue N3, SDNodeFlags Flags) { 156 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); 157 assert(ISD::getVPMaskIdx(VPOpcode) == 3 && 158 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); 159 return DAG.getNode(VPOpcode, DL, VT, 160 {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); 161 } 162 163 bool isOperationLegal(unsigned Op, EVT VT) const { 164 unsigned VPOp = ISD::getVPForBaseOpcode(Op); 165 return TLI.isOperationLegal(VPOp, VT); 166 } 167 168 bool isOperationLegalOrCustom(unsigned Op, EVT VT, 169 bool LegalOnly = false) const { 170 unsigned VPOp = ISD::getVPForBaseOpcode(Op); 171 return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly); 172 } 173 }; 174 } // end anonymous namespace 175 #endif 176