xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matches on the VPlan values and recipes, based on
11 // LLVM's IR pattern matchers.
12 //
// Currently it provides generic matchers for unary and binary VPInstructions
// and widen/replicate/widen-cast recipes, and specialized matchers like m_Not,
// m_ActiveLaneMask, m_BranchOnCond and m_BranchOnCount to match specific
// VPInstructions.
16 // TODO: Add missing matchers for additional opcodes and recipes as needed.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
21 #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
22 
23 #include "VPlan.h"
24 
25 namespace llvm {
26 namespace VPlanPatternMatch {
27 
/// Top-level entry point: returns true if pattern \p P matches \p V.
/// Matchers may record bindings while matching (see bind_ty), so the pattern
/// is accepted by const reference for convenience and its constness is cast
/// away before its match() member is invoked.
template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
  auto &MutablePattern = const_cast<Pattern &>(P);
  return MutablePattern.match(V);
}
31 
32 template <typename Class> struct class_match {
matchclass_match33   template <typename ITy> bool match(ITy *V) { return isa<Class>(V); }
34 };
35 
36 /// Match an arbitrary VPValue and ignore it.
m_VPValue()37 inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
38 
39 template <typename Class> struct bind_ty {
40   Class *&VR;
41 
bind_tybind_ty42   bind_ty(Class *&V) : VR(V) {}
43 
matchbind_ty44   template <typename ITy> bool match(ITy *V) {
45     if (auto *CV = dyn_cast<Class>(V)) {
46       VR = CV;
47       return true;
48     }
49     return false;
50   }
51 };
52 
53 /// Match a specified integer value or vector of all elements of that
54 /// value. \p BitWidth optionally specifies the bitwidth the matched constant
55 /// must have. If it is 0, the matched constant can have any bitwidth.
56 template <unsigned BitWidth = 0> struct specific_intval {
57   APInt Val;
58 
specific_intvalspecific_intval59   specific_intval(APInt V) : Val(std::move(V)) {}
60 
matchspecific_intval61   bool match(VPValue *VPV) {
62     if (!VPV->isLiveIn())
63       return false;
64     Value *V = VPV->getLiveInIRValue();
65     const auto *CI = dyn_cast<ConstantInt>(V);
66     if (!CI && V->getType()->isVectorTy())
67       if (const auto *C = dyn_cast<Constant>(V))
68         CI = dyn_cast_or_null<ConstantInt>(
69             C->getSplatValue(/*AllowPoison=*/false));
70     if (!CI)
71       return false;
72 
73     assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
74            "Trying the match constant with unexpected bitwidth.");
75     return APInt::isSameValue(CI->getValue(), Val);
76   }
77 };
78 
m_SpecificInt(uint64_t V)79 inline specific_intval<0> m_SpecificInt(uint64_t V) {
80   return specific_intval<0>(APInt(64, V));
81 }
82 
m_False()83 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
84 
// Matching combinators.

/// Matcher that succeeds if either sub-matcher does. \p L is tried first and
/// \p R is only consulted when \p L fails.
template <typename LTy, typename RTy> struct match_combine_or {
  LTy L;
  RTy R;

  match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) {
    return L.match(V) || R.match(V);
  }
};
100 
101 template <typename LTy, typename RTy>
m_CombineOr(const LTy & L,const RTy & R)102 inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
103   return match_combine_or<LTy, RTy>(L, R);
104 }
105 
106 /// Match a VPValue, capturing it if we match.
m_VPValue(VPValue * & V)107 inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
108 
109 namespace detail {
110 
111 /// A helper to match an opcode against multiple recipe types.
112 template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
113 
114 template <unsigned Opcode, typename RecipeTy>
115 struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
116   static bool match(const VPRecipeBase *R) {
117     auto *DefR = dyn_cast<RecipeTy>(R);
118     return DefR && DefR->getOpcode() == Opcode;
119   }
120 };
121 
122 template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
123 struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
124   static bool match(const VPRecipeBase *R) {
125     return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
126            MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
127   }
128 };
129 } // namespace detail
130 
131 template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
132 struct UnaryRecipe_match {
133   Op0_t Op0;
134 
135   UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {}
136 
137   bool match(const VPValue *V) {
138     auto *DefR = V->getDefiningRecipe();
139     return DefR && match(DefR);
140   }
141 
142   bool match(const VPRecipeBase *R) {
143     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
144       return false;
145     assert(R->getNumOperands() == 1 &&
146            "recipe with matched opcode does not have 1 operands");
147     return Op0.match(R->getOperand(0));
148   }
149 };
150 
151 template <typename Op0_t, unsigned Opcode>
152 using UnaryVPInstruction_match =
153     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
154 
155 template <typename Op0_t, unsigned Opcode>
156 using AllUnaryRecipe_match =
157     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
158                       VPWidenCastRecipe, VPInstruction>;
159 
160 template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
161           typename... RecipeTys>
162 struct BinaryRecipe_match {
163   Op0_t Op0;
164   Op1_t Op1;
165 
166   BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
167 
168   bool match(const VPValue *V) {
169     auto *DefR = V->getDefiningRecipe();
170     return DefR && match(DefR);
171   }
172 
173   bool match(const VPSingleDefRecipe *R) {
174     return match(static_cast<const VPRecipeBase *>(R));
175   }
176 
177   bool match(const VPRecipeBase *R) {
178     if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
179       return false;
180     assert(R->getNumOperands() == 2 &&
181            "recipe with matched opcode does not have 2 operands");
182     if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)))
183       return true;
184     return Commutative && Op0.match(R->getOperand(1)) &&
185            Op1.match(R->getOperand(0));
186   }
187 };
188 
189 template <typename Op0_t, typename Op1_t, unsigned Opcode>
190 using BinaryVPInstruction_match =
191     BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
192                        VPInstruction>;
193 
194 template <typename Op0_t, typename Op1_t, unsigned Opcode,
195           bool Commutative = false>
196 using AllBinaryRecipe_match =
197     BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
198                        VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;
199 
200 template <unsigned Opcode, typename Op0_t>
201 inline UnaryVPInstruction_match<Op0_t, Opcode>
202 m_VPInstruction(const Op0_t &Op0) {
203   return UnaryVPInstruction_match<Op0_t, Opcode>(Op0);
204 }
205 
206 template <unsigned Opcode, typename Op0_t, typename Op1_t>
207 inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>
208 m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
209   return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
210 }
211 
212 template <typename Op0_t>
213 inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
214 m_Not(const Op0_t &Op0) {
215   return m_VPInstruction<VPInstruction::Not>(Op0);
216 }
217 
218 template <typename Op0_t>
219 inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond>
220 m_BranchOnCond(const Op0_t &Op0) {
221   return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
222 }
223 
224 template <typename Op0_t, typename Op1_t>
225 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask>
226 m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) {
227   return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1);
228 }
229 
230 template <typename Op0_t, typename Op1_t>
231 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
232 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
233   return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
234 }
235 
236 template <unsigned Opcode, typename Op0_t>
237 inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
238   return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
239 }
240 
241 template <typename Op0_t>
242 inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
243 m_Trunc(const Op0_t &Op0) {
244   return m_Unary<Instruction::Trunc, Op0_t>(Op0);
245 }
246 
247 template <typename Op0_t>
248 inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
249   return m_Unary<Instruction::ZExt, Op0_t>(Op0);
250 }
251 
252 template <typename Op0_t>
253 inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
254   return m_Unary<Instruction::SExt, Op0_t>(Op0);
255 }
256 
257 template <typename Op0_t>
258 inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
259                         AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
260 m_ZExtOrSExt(const Op0_t &Op0) {
261   return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
262 }
263 
264 template <unsigned Opcode, typename Op0_t, typename Op1_t,
265           bool Commutative = false>
266 inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
267 m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
268   return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
269 }
270 
271 template <typename Op0_t, typename Op1_t>
272 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
273 m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
274   return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
275 }
276 
277 template <typename Op0_t, typename Op1_t>
278 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul,
279                              /* Commutative =*/true>
280 m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
281   return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1);
282 }
283 
284 /// Match a binary OR operation. Note that while conceptually the operands can
285 /// be matched commutatively, \p Commutative defaults to false in line with the
286 /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
287 /// version of the matcher.
288 template <typename Op0_t, typename Op1_t, bool Commutative = false>
289 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
290 m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
291   return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
292 }
293 
294 template <typename Op0_t, typename Op1_t>
295 inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
296                              /*Commutative*/ true>
297 m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
298   return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
299 }
300 
301 template <typename Op0_t, typename Op1_t>
302 inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
303 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
304   return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
305 }
306 
307 struct VPCanonicalIVPHI_match {
308   bool match(const VPValue *V) {
309     auto *DefR = V->getDefiningRecipe();
310     return DefR && match(DefR);
311   }
312 
313   bool match(const VPRecipeBase *R) { return isa<VPCanonicalIVPHIRecipe>(R); }
314 };
315 
316 inline VPCanonicalIVPHI_match m_CanonicalIV() {
317   return VPCanonicalIVPHI_match();
318 }
319 
320 template <typename Op0_t, typename Op1_t> struct VPScalarIVSteps_match {
321   Op0_t Op0;
322   Op1_t Op1;
323 
324   VPScalarIVSteps_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
325 
326   bool match(const VPValue *V) {
327     auto *DefR = V->getDefiningRecipe();
328     return DefR && match(DefR);
329   }
330 
331   bool match(const VPRecipeBase *R) {
332     if (!isa<VPScalarIVStepsRecipe>(R))
333       return false;
334     assert(R->getNumOperands() == 2 &&
335            "VPScalarIVSteps must have exactly 2 operands");
336     return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
337   }
338 };
339 
340 template <typename Op0_t, typename Op1_t>
341 inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
342                                                            const Op1_t &Op1) {
343   return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
344 }
345 
346 } // namespace VPlanPatternMatch
347 } // namespace llvm
348 
349 #endif
350