1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a simple and efficient mechanism for performing general 10 // tree-based pattern matches on the VPlan values and recipes, based on 11 // LLVM's IR pattern matchers. 12 // 13 // Currently it provides generic matchers for unary and binary VPInstructions, 14 // and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond, 15 // m_BranchOnCount to match specific VPInstructions. 16 // TODO: Add missing matchers for additional opcodes and recipes as needed. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H 21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H 22 23 #include "VPlan.h" 24 25 namespace llvm { 26 namespace VPlanPatternMatch { 27 28 template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) { 29 return const_cast<Pattern &>(P).match(V); 30 } 31 32 template <typename Class> struct class_match { 33 template <typename ITy> bool match(ITy *V) { return isa<Class>(V); } 34 }; 35 36 /// Match an arbitrary VPValue and ignore it. 37 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); } 38 39 template <typename Class> struct bind_ty { 40 Class *&VR; 41 42 bind_ty(Class *&V) : VR(V) {} 43 44 template <typename ITy> bool match(ITy *V) { 45 if (auto *CV = dyn_cast<Class>(V)) { 46 VR = CV; 47 return true; 48 } 49 return false; 50 } 51 }; 52 53 /// Match a specified integer value or vector of all elements of that 54 /// value. \p BitWidth optionally specifies the bitwidth the matched constant 55 /// must have. If it is 0, the matched constant can have any bitwidth. 56 template <unsigned BitWidth = 0> struct specific_intval { 57 APInt Val; 58 59 specific_intval(APInt V) : Val(std::move(V)) {} 60 61 bool match(VPValue *VPV) { 62 if (!VPV->isLiveIn()) 63 return false; 64 Value *V = VPV->getLiveInIRValue(); 65 const auto *CI = dyn_cast<ConstantInt>(V); 66 if (!CI && V->getType()->isVectorTy()) 67 if (const auto *C = dyn_cast<Constant>(V)) 68 CI = dyn_cast_or_null<ConstantInt>( 69 C->getSplatValue(/*AllowPoison=*/false)); 70 if (!CI) 71 return false; 72 73 assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) && 74 "Trying the match constant with unexpected bitwidth."); 75 return APInt::isSameValue(CI->getValue(), Val); 76 } 77 }; 78 79 inline specific_intval<0> m_SpecificInt(uint64_t V) { 80 return specific_intval<0>(APInt(64, V)); 81 } 82 83 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); } 84 85 /// Matching combinators 86 template <typename LTy, typename RTy> struct match_combine_or { 87 LTy L; 88 RTy R; 89 90 match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} 91 92 template <typename ITy> bool match(ITy *V) { 93 if (L.match(V)) 94 return true; 95 if (R.match(V)) 96 return true; 97 return false; 98 } 99 }; 100 101 template <typename LTy, typename RTy> 102 inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) { 103 return match_combine_or<LTy, RTy>(L, R); 104 } 105 106 /// Match a VPValue, capturing it if we match. 107 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; } 108 109 namespace detail { 110 111 /// A helper to match an opcode against multiple recipe types. 112 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {}; 113 114 template <unsigned Opcode, typename RecipeTy> 115 struct MatchRecipeAndOpcode<Opcode, RecipeTy> { 116 static bool match(const VPRecipeBase *R) { 117 auto *DefR = dyn_cast<RecipeTy>(R); 118 return DefR && DefR->getOpcode() == Opcode; 119 } 120 }; 121 122 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys> 123 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> { 124 static bool match(const VPRecipeBase *R) { 125 return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) || 126 MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R); 127 } 128 }; 129 } // namespace detail 130 131 template <typename Op0_t, unsigned Opcode, typename... RecipeTys> 132 struct UnaryRecipe_match { 133 Op0_t Op0; 134 135 UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {} 136 137 bool match(const VPValue *V) { 138 auto *DefR = V->getDefiningRecipe(); 139 return DefR && match(DefR); 140 } 141 142 bool match(const VPRecipeBase *R) { 143 if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) 144 return false; 145 assert(R->getNumOperands() == 1 && 146 "recipe with matched opcode does not have 1 operands"); 147 return Op0.match(R->getOperand(0)); 148 } 149 }; 150 151 template <typename Op0_t, unsigned Opcode> 152 using UnaryVPInstruction_match = 153 UnaryRecipe_match<Op0_t, Opcode, VPInstruction>; 154 155 template <typename Op0_t, unsigned Opcode> 156 using AllUnaryRecipe_match = 157 UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe, 158 VPWidenCastRecipe, VPInstruction>; 159 160 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative, 161 typename... RecipeTys> 162 struct BinaryRecipe_match { 163 Op0_t Op0; 164 Op1_t Op1; 165 166 BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} 167 168 bool match(const VPValue *V) { 169 auto *DefR = V->getDefiningRecipe(); 170 return DefR && match(DefR); 171 } 172 173 bool match(const VPSingleDefRecipe *R) { 174 return match(static_cast<const VPRecipeBase *>(R)); 175 } 176 177 bool match(const VPRecipeBase *R) { 178 if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) 179 return false; 180 assert(R->getNumOperands() == 2 && 181 "recipe with matched opcode does not have 2 operands"); 182 if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1))) 183 return true; 184 return Commutative && Op0.match(R->getOperand(1)) && 185 Op1.match(R->getOperand(0)); 186 } 187 }; 188 189 template <typename Op0_t, typename Op1_t, unsigned Opcode> 190 using BinaryVPInstruction_match = 191 BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false, 192 VPInstruction>; 193 194 template <typename Op0_t, typename Op1_t, unsigned Opcode, 195 bool Commutative = false> 196 using AllBinaryRecipe_match = 197 BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe, 198 VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>; 199 200 template <unsigned Opcode, typename Op0_t> 201 inline UnaryVPInstruction_match<Op0_t, Opcode> 202 m_VPInstruction(const Op0_t &Op0) { 203 return UnaryVPInstruction_match<Op0_t, Opcode>(Op0); 204 } 205 206 template <unsigned Opcode, typename Op0_t, typename Op1_t> 207 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode> 208 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) { 209 return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1); 210 } 211 212 template <typename Op0_t> 213 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not> 214 m_Not(const Op0_t &Op0) { 215 return m_VPInstruction<VPInstruction::Not>(Op0); 216 } 217 218 template <typename Op0_t> 219 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond> 220 m_BranchOnCond(const Op0_t &Op0) { 221 return m_VPInstruction<VPInstruction::BranchOnCond>(Op0); 222 } 223 224 template <typename Op0_t, typename Op1_t> 225 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask> 226 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) { 227 return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1); 228 } 229 230 template <typename Op0_t, typename Op1_t> 231 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount> 232 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { 233 return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1); 234 } 235 236 template <unsigned Opcode, typename Op0_t> 237 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) { 238 return AllUnaryRecipe_match<Op0_t, Opcode>(Op0); 239 } 240 241 template <typename Op0_t> 242 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc> 243 m_Trunc(const Op0_t &Op0) { 244 return m_Unary<Instruction::Trunc, Op0_t>(Op0); 245 } 246 247 template <typename Op0_t> 248 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) { 249 return m_Unary<Instruction::ZExt, Op0_t>(Op0); 250 } 251 252 template <typename Op0_t> 253 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) { 254 return m_Unary<Instruction::SExt, Op0_t>(Op0); 255 } 256 257 template <typename Op0_t> 258 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>, 259 AllUnaryRecipe_match<Op0_t, Instruction::SExt>> 260 m_ZExtOrSExt(const Op0_t &Op0) { 261 return m_CombineOr(m_ZExt(Op0), m_SExt(Op0)); 262 } 263 264 template <unsigned Opcode, typename Op0_t, typename Op1_t, 265 bool Commutative = false> 266 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative> 267 m_Binary(const Op0_t &Op0, const Op1_t &Op1) { 268 return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1); 269 } 270 271 template <typename Op0_t, typename Op1_t> 272 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul> 273 m_Mul(const Op0_t &Op0, const Op1_t &Op1) { 274 return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1); 275 } 276 277 template <typename Op0_t, typename Op1_t> 278 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul, 279 /* Commutative =*/true> 280 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) { 281 return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1); 282 } 283 284 /// Match a binary OR operation. Note that while conceptually the operands can 285 /// be matched commutatively, \p Commutative defaults to false in line with the 286 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative 287 /// version of the matcher. 288 template <typename Op0_t, typename Op1_t, bool Commutative = false> 289 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative> 290 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { 291 return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1); 292 } 293 294 template <typename Op0_t, typename Op1_t> 295 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, 296 /*Commutative*/ true> 297 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { 298 return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1); 299 } 300 301 template <typename Op0_t, typename Op1_t> 302 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd> 303 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) { 304 return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1); 305 } 306 307 struct VPCanonicalIVPHI_match { 308 bool match(const VPValue *V) { 309 auto *DefR = V->getDefiningRecipe(); 310 return DefR && match(DefR); 311 } 312 313 bool match(const VPRecipeBase *R) { return isa<VPCanonicalIVPHIRecipe>(R); } 314 }; 315 316 inline VPCanonicalIVPHI_match m_CanonicalIV() { 317 return VPCanonicalIVPHI_match(); 318 } 319 320 template <typename Op0_t, typename Op1_t> struct VPScalarIVSteps_match { 321 Op0_t Op0; 322 Op1_t Op1; 323 324 VPScalarIVSteps_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} 325 326 bool match(const VPValue *V) { 327 auto *DefR = V->getDefiningRecipe(); 328 return DefR && match(DefR); 329 } 330 331 bool match(const VPRecipeBase *R) { 332 if (!isa<VPScalarIVStepsRecipe>(R)) 333 return false; 334 assert(R->getNumOperands() == 2 && 335 "VPScalarIVSteps must have exactly 2 operands"); 336 return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)); 337 } 338 }; 339 340 template <typename Op0_t, typename Op1_t> 341 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0, 342 const Op1_t &Op1) { 343 return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1); 344 } 345 346 } // namespace VPlanPatternMatch 347 } // namespace llvm 348 349 #endif 350