// xref: /freebsd/contrib/llvm-project/llvm/include/llvm/CodeGen/SDPatternMatch.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// Contains matchers for matching SelectionDAG nodes and values.
///
//===----------------------------------------------------------------------===//
12 
#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
#define LLVM_CODEGEN_SDPATTERNMATCH_H

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include <type_traits>
#include <utility>

23 
24 namespace llvm {
25 namespace SDPatternMatch {
26 
27 /// MatchContext can repurpose existing patterns to behave differently under
28 /// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
29 /// in normal circumstances, but matches VP_ADD nodes under a custom
30 /// VPMatchContext. This design is meant to facilitate code / pattern reusing.
31 class BasicMatchContext {
32   const SelectionDAG *DAG;
33   const TargetLowering *TLI;
34 
35 public:
BasicMatchContext(const SelectionDAG * DAG)36   explicit BasicMatchContext(const SelectionDAG *DAG)
37       : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
38 
BasicMatchContext(const TargetLowering * TLI)39   explicit BasicMatchContext(const TargetLowering *TLI)
40       : DAG(nullptr), TLI(TLI) {}
41 
42   // A valid MatchContext has to implement the following functions.
43 
getDAG()44   const SelectionDAG *getDAG() const { return DAG; }
45 
getTLI()46   const TargetLowering *getTLI() const { return TLI; }
47 
48   /// Return true if N effectively has opcode Opcode.
match(SDValue N,unsigned Opcode)49   bool match(SDValue N, unsigned Opcode) const {
50     return N->getOpcode() == Opcode;
51   }
52 
getNumOperands(SDValue N)53   unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
54 };
55 
56 template <typename Pattern, typename MatchContext>
sd_context_match(SDValue N,const MatchContext & Ctx,Pattern && P)57 [[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
58                                     Pattern &&P) {
59   return P.match(Ctx, N);
60 }
61 
62 template <typename Pattern, typename MatchContext>
sd_context_match(SDNode * N,const MatchContext & Ctx,Pattern && P)63 [[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
64                                     Pattern &&P) {
65   return sd_context_match(SDValue(N, 0), Ctx, P);
66 }
67 
68 template <typename Pattern>
sd_match(SDNode * N,const SelectionDAG * DAG,Pattern && P)69 [[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
70   return sd_context_match(N, BasicMatchContext(DAG), P);
71 }
72 
73 template <typename Pattern>
sd_match(SDValue N,const SelectionDAG * DAG,Pattern && P)74 [[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
75   return sd_context_match(N, BasicMatchContext(DAG), P);
76 }
77 
78 template <typename Pattern>
sd_match(SDNode * N,Pattern && P)79 [[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
80   return sd_match(N, nullptr, P);
81 }
82 
83 template <typename Pattern>
sd_match(SDValue N,Pattern && P)84 [[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
85   return sd_match(N, nullptr, P);
86 }
87 
88 // === Utilities ===
89 struct Value_match {
90   SDValue MatchVal;
91 
92   Value_match() = default;
93 
Value_matchValue_match94   explicit Value_match(SDValue Match) : MatchVal(Match) {}
95 
matchValue_match96   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
97     if (MatchVal)
98       return MatchVal == N;
99     return N.getNode();
100   }
101 };
102 
103 /// Match any valid SDValue.
m_Value()104 inline Value_match m_Value() { return Value_match(); }
105 
m_Specific(SDValue N)106 inline Value_match m_Specific(SDValue N) {
107   assert(N);
108   return Value_match(N);
109 }
110 
111 template <unsigned ResNo, typename Pattern> struct Result_match {
112   Pattern P;
113 
Result_matchResult_match114   explicit Result_match(const Pattern &P) : P(P) {}
115 
116   template <typename MatchContext>
matchResult_match117   bool match(const MatchContext &Ctx, SDValue N) {
118     return N.getResNo() == ResNo && P.match(Ctx, N);
119   }
120 };
121 
122 /// Match only if the SDValue is a certain result at ResNo.
123 template <unsigned ResNo, typename Pattern>
m_Result(const Pattern & P)124 inline Result_match<ResNo, Pattern> m_Result(const Pattern &P) {
125   return Result_match<ResNo, Pattern>(P);
126 }
127 
128 struct DeferredValue_match {
129   SDValue &MatchVal;
130 
DeferredValue_matchDeferredValue_match131   explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
132 
matchDeferredValue_match133   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
134     return N == MatchVal;
135   }
136 };
137 
138 /// Similar to m_Specific, but the specific value to match is determined by
139 /// another sub-pattern in the same sd_match() expression. For instance,
140 /// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
141 /// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
142 /// we should use `m_Add(m_Value(X), m_Deferred(X))`.
m_Deferred(SDValue & V)143 inline DeferredValue_match m_Deferred(SDValue &V) {
144   return DeferredValue_match(V);
145 }
146 
147 struct Opcode_match {
148   unsigned Opcode;
149 
Opcode_matchOpcode_match150   explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
151 
152   template <typename MatchContext>
matchOpcode_match153   bool match(const MatchContext &Ctx, SDValue N) {
154     return Ctx.match(N, Opcode);
155   }
156 };
157 
m_Opc(unsigned Opcode)158 inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
159 
m_Undef()160 inline Opcode_match m_Undef() { return Opcode_match(ISD::UNDEF); }
161 
m_Poison()162 inline Opcode_match m_Poison() { return Opcode_match(ISD::POISON); }
163 
164 template <unsigned NumUses, typename Pattern> struct NUses_match {
165   Pattern P;
166 
NUses_matchNUses_match167   explicit NUses_match(const Pattern &P) : P(P) {}
168 
169   template <typename MatchContext>
matchNUses_match170   bool match(const MatchContext &Ctx, SDValue N) {
171     // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
172     // multiple results, hence we check the subsequent pattern here before
173     // checking the number of value users.
174     return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
175   }
176 };
177 
178 template <typename Pattern>
m_OneUse(const Pattern & P)179 inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
180   return NUses_match<1, Pattern>(P);
181 }
182 template <unsigned N, typename Pattern>
m_NUses(const Pattern & P)183 inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
184   return NUses_match<N, Pattern>(P);
185 }
186 
m_OneUse()187 inline NUses_match<1, Value_match> m_OneUse() {
188   return NUses_match<1, Value_match>(m_Value());
189 }
m_NUses()190 template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
191   return NUses_match<N, Value_match>(m_Value());
192 }
193 
194 struct Value_bind {
195   SDValue &BindVal;
196 
Value_bindValue_bind197   explicit Value_bind(SDValue &N) : BindVal(N) {}
198 
matchValue_bind199   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
200     BindVal = N;
201     return true;
202   }
203 };
204 
m_Value(SDValue & N)205 inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
206 
207 template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
208   Pattern P;
209   PredFuncT PredFunc;
210 
TLI_pred_matchTLI_pred_match211   TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
212       : P(P), PredFunc(Pred) {}
213 
214   template <typename MatchContext>
matchTLI_pred_match215   bool match(const MatchContext &Ctx, SDValue N) {
216     assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
217     return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
218   }
219 };
220 
221 // Explicit deduction guide.
222 template <typename PredFuncT, typename Pattern>
223 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
224     -> TLI_pred_match<Pattern, PredFuncT>;
225 
226 /// Match legal SDNodes based on the information provided by TargetLowering.
m_LegalOp(const Pattern & P)227 template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
228   return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
229                           return TLI.isOperationLegal(N->getOpcode(),
230                                                       N.getValueType());
231                         },
232                         P};
233 }
234 
235 /// Switch to a different MatchContext for subsequent patterns.
236 template <typename NewMatchContext, typename Pattern> struct SwitchContext {
237   const NewMatchContext &Ctx;
238   Pattern P;
239 
240   template <typename OrigMatchContext>
matchSwitchContext241   bool match(const OrigMatchContext &, SDValue N) {
242     return P.match(Ctx, N);
243   }
244 };
245 
246 template <typename MatchContext, typename Pattern>
m_Context(const MatchContext & Ctx,Pattern && P)247 inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
248                                                       Pattern &&P) {
249   return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
250 }
251 
252 // === Value type ===
253 struct ValueType_bind {
254   EVT &BindVT;
255 
ValueType_bindValueType_bind256   explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
257 
matchValueType_bind258   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
259     BindVT = N.getValueType();
260     return true;
261   }
262 };
263 
264 /// Retreive the ValueType of the current SDValue.
m_VT(EVT & VT)265 inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
266 
267 template <typename Pattern, typename PredFuncT> struct ValueType_match {
268   PredFuncT PredFunc;
269   Pattern P;
270 
ValueType_matchValueType_match271   ValueType_match(const PredFuncT &Pred, const Pattern &P)
272       : PredFunc(Pred), P(P) {}
273 
274   template <typename MatchContext>
matchValueType_match275   bool match(const MatchContext &Ctx, SDValue N) {
276     return PredFunc(N.getValueType()) && P.match(Ctx, N);
277   }
278 };
279 
280 // Explicit deduction guide.
281 template <typename PredFuncT, typename Pattern>
282 ValueType_match(const PredFuncT &Pred, const Pattern &P)
283     -> ValueType_match<Pattern, PredFuncT>;
284 
285 /// Match a specific ValueType.
286 template <typename Pattern>
m_SpecificVT(EVT RefVT,const Pattern & P)287 inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
288   return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
289 }
m_SpecificVT(EVT RefVT)290 inline auto m_SpecificVT(EVT RefVT) {
291   return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
292 }
293 
m_Glue()294 inline auto m_Glue() { return m_SpecificVT(MVT::Glue); }
m_OtherVT()295 inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); }
296 
297 /// Match a scalar ValueType.
298 template <typename Pattern>
m_SpecificScalarVT(EVT RefVT,const Pattern & P)299 inline auto m_SpecificScalarVT(EVT RefVT, const Pattern &P) {
300   return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
301                          P};
302 }
m_SpecificScalarVT(EVT RefVT)303 inline auto m_SpecificScalarVT(EVT RefVT) {
304   return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
305                          m_Value()};
306 }
307 
308 /// Match a vector ValueType.
309 template <typename Pattern>
m_SpecificVectorElementVT(EVT RefVT,const Pattern & P)310 inline auto m_SpecificVectorElementVT(EVT RefVT, const Pattern &P) {
311   return ValueType_match{[=](EVT VT) {
312                            return VT.isVector() &&
313                                   VT.getVectorElementType() == RefVT;
314                          },
315                          P};
316 }
m_SpecificVectorElementVT(EVT RefVT)317 inline auto m_SpecificVectorElementVT(EVT RefVT) {
318   return ValueType_match{[=](EVT VT) {
319                            return VT.isVector() &&
320                                   VT.getVectorElementType() == RefVT;
321                          },
322                          m_Value()};
323 }
324 
325 /// Match any integer ValueTypes.
m_IntegerVT(const Pattern & P)326 template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
327   return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
328 }
m_IntegerVT()329 inline auto m_IntegerVT() {
330   return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
331 }
332 
333 /// Match any floating point ValueTypes.
m_FloatingPointVT(const Pattern & P)334 template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
335   return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
336 }
m_FloatingPointVT()337 inline auto m_FloatingPointVT() {
338   return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
339                          m_Value()};
340 }
341 
342 /// Match any vector ValueTypes.
m_VectorVT(const Pattern & P)343 template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
344   return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
345 }
m_VectorVT()346 inline auto m_VectorVT() {
347   return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
348 }
349 
350 /// Match fixed-length vector ValueTypes.
m_FixedVectorVT(const Pattern & P)351 template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
352   return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
353 }
m_FixedVectorVT()354 inline auto m_FixedVectorVT() {
355   return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
356                          m_Value()};
357 }
358 
359 /// Match scalable vector ValueTypes.
m_ScalableVectorVT(const Pattern & P)360 template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
361   return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
362 }
m_ScalableVectorVT()363 inline auto m_ScalableVectorVT() {
364   return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
365                          m_Value()};
366 }
367 
368 /// Match legal ValueTypes based on the information provided by TargetLowering.
m_LegalType(const Pattern & P)369 template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
370   return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
371                           return TLI.isTypeLegal(N.getValueType());
372                         },
373                         P};
374 }
375 
376 // === Patterns combinators ===
377 template <typename... Preds> struct And {
matchAnd378   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
379     return true;
380   }
381 };
382 
383 template <typename Pred, typename... Preds>
384 struct And<Pred, Preds...> : And<Preds...> {
385   Pred P;
386   And(const Pred &p, const Preds &...preds) : And<Preds...>(preds...), P(p) {}
387 
388   template <typename MatchContext>
389   bool match(const MatchContext &Ctx, SDValue N) {
390     return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
391   }
392 };
393 
394 template <typename... Preds> struct Or {
395   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
396     return false;
397   }
398 };
399 
400 template <typename Pred, typename... Preds>
401 struct Or<Pred, Preds...> : Or<Preds...> {
402   Pred P;
403   Or(const Pred &p, const Preds &...preds) : Or<Preds...>(preds...), P(p) {}
404 
405   template <typename MatchContext>
406   bool match(const MatchContext &Ctx, SDValue N) {
407     return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
408   }
409 };
410 
411 template <typename Pred> struct Not {
412   Pred P;
413 
414   explicit Not(const Pred &P) : P(P) {}
415 
416   template <typename MatchContext>
417   bool match(const MatchContext &Ctx, SDValue N) {
418     return !P.match(Ctx, N);
419   }
420 };
421 // Explicit deduction guide.
422 template <typename Pred> Not(const Pred &P) -> Not<Pred>;
423 
424 /// Match if the inner pattern does NOT match.
425 template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
426   return Not{P};
427 }
428 
429 template <typename... Preds> And<Preds...> m_AllOf(const Preds &...preds) {
430   return And<Preds...>(preds...);
431 }
432 
433 template <typename... Preds> Or<Preds...> m_AnyOf(const Preds &...preds) {
434   return Or<Preds...>(preds...);
435 }
436 
437 template <typename... Preds> auto m_NoneOf(const Preds &...preds) {
438   return m_Unless(m_AnyOf(preds...));
439 }
440 
441 // === Generic node matching ===
442 template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
443   template <typename MatchContext>
444   bool match(const MatchContext &Ctx, SDValue N) {
445     // Returns false if there are more operands than predicates;
446     // Ignores the last two operands if both the Context and the Node are VP
447     return Ctx.getNumOperands(N) == OpIdx;
448   }
449 };
450 
451 template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
452 struct Operands_match<OpIdx, OpndPred, OpndPreds...>
453     : Operands_match<OpIdx + 1, OpndPreds...> {
454   OpndPred P;
455 
456   Operands_match(const OpndPred &p, const OpndPreds &...preds)
457       : Operands_match<OpIdx + 1, OpndPreds...>(preds...), P(p) {}
458 
459   template <typename MatchContext>
460   bool match(const MatchContext &Ctx, SDValue N) {
461     if (OpIdx < N->getNumOperands())
462       return P.match(Ctx, N->getOperand(OpIdx)) &&
463              Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
464 
465     // This is the case where there are more predicates than operands.
466     return false;
467   }
468 };
469 
470 template <typename... OpndPreds>
471 auto m_Node(unsigned Opcode, const OpndPreds &...preds) {
472   return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(preds...));
473 }
474 
475 /// Provide number of operands that are not chain or glue, as well as the first
476 /// index of such operand.
477 template <bool ExcludeChain> struct EffectiveOperands {
478   unsigned Size = 0;
479   unsigned FirstIndex = 0;
480 
481   template <typename MatchContext>
482   explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
483     const unsigned TotalNumOps = Ctx.getNumOperands(N);
484     FirstIndex = TotalNumOps;
485     for (unsigned I = 0; I < TotalNumOps; ++I) {
486       // Count the number of non-chain and non-glue nodes (we ignore chain
487       // and glue by default) and retreive the operand index offset.
488       EVT VT = N->getOperand(I).getValueType();
489       if (VT != MVT::Glue && VT != MVT::Other) {
490         ++Size;
491         if (FirstIndex == TotalNumOps)
492           FirstIndex = I;
493       }
494     }
495   }
496 };
497 
498 template <> struct EffectiveOperands<false> {
499   unsigned Size = 0;
500   unsigned FirstIndex = 0;
501 
502   template <typename MatchContext>
503   explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
504       : Size(Ctx.getNumOperands(N)) {}
505 };
506 
507 // === Ternary operations ===
508 template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
509           bool ExcludeChain = false>
510 struct TernaryOpc_match {
511   unsigned Opcode;
512   T0_P Op0;
513   T1_P Op1;
514   T2_P Op2;
515 
516   TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
517                    const T2_P &Op2)
518       : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
519 
520   template <typename MatchContext>
521   bool match(const MatchContext &Ctx, SDValue N) {
522     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
523       EffectiveOperands<ExcludeChain> EO(N, Ctx);
524       assert(EO.Size == 3);
525       return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
526                Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
527               (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
528                Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
529              Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
530     }
531 
532     return false;
533   }
534 };
535 
536 template <typename T0_P, typename T1_P, typename T2_P>
537 inline TernaryOpc_match<T0_P, T1_P, T2_P>
538 m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
539   return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SETCC, LHS, RHS, CC);
540 }
541 
542 template <typename T0_P, typename T1_P, typename T2_P>
543 inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
544 m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
545   return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS,
546                                                          CC);
547 }
548 
549 template <typename T0_P, typename T1_P, typename T2_P>
550 inline TernaryOpc_match<T0_P, T1_P, T2_P>
551 m_Select(const T0_P &Cond, const T1_P &T, const T2_P &F) {
552   return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SELECT, Cond, T, F);
553 }
554 
555 template <typename T0_P, typename T1_P, typename T2_P>
556 inline TernaryOpc_match<T0_P, T1_P, T2_P>
557 m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
558   return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
559 }
560 
561 template <typename T0_P, typename T1_P, typename T2_P>
562 inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
563 m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
564   return m_Result<0>(
565       TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::LOAD, Ch, Ptr, Offset));
566 }
567 
568 template <typename T0_P, typename T1_P, typename T2_P>
569 inline TernaryOpc_match<T0_P, T1_P, T2_P>
570 m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) {
571   return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val,
572                                             Idx);
573 }
574 
575 template <typename LHS, typename RHS, typename IDX>
576 inline TernaryOpc_match<LHS, RHS, IDX>
577 m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) {
578   return TernaryOpc_match<LHS, RHS, IDX>(ISD::INSERT_SUBVECTOR, Base, Sub, Idx);
579 }
580 
581 // === Binary operations ===
582 template <typename LHS_P, typename RHS_P, bool Commutable = false,
583           bool ExcludeChain = false>
584 struct BinaryOpc_match {
585   unsigned Opcode;
586   LHS_P LHS;
587   RHS_P RHS;
588   std::optional<SDNodeFlags> Flags;
589   BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
590                   std::optional<SDNodeFlags> Flgs = std::nullopt)
591       : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
592 
593   template <typename MatchContext>
594   bool match(const MatchContext &Ctx, SDValue N) {
595     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
596       EffectiveOperands<ExcludeChain> EO(N, Ctx);
597       assert(EO.Size == 2);
598       if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
599              RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
600             (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
601              RHS.match(Ctx, N->getOperand(EO.FirstIndex)))))
602         return false;
603 
604       if (!Flags.has_value())
605         return true;
606 
607       return (*Flags & N->getFlags()) == *Flags;
608     }
609 
610     return false;
611   }
612 };
613 
614 /// Matching while capturing mask
615 template <typename T0, typename T1, typename T2> struct SDShuffle_match {
616   T0 Op1;
617   T1 Op2;
618   T2 Mask;
619 
620   SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
621       : Op1(Op1), Op2(Op2), Mask(Mask) {}
622 
623   template <typename MatchContext>
624   bool match(const MatchContext &Ctx, SDValue N) {
625     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
626       return Op1.match(Ctx, I->getOperand(0)) &&
627              Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask());
628     }
629     return false;
630   }
631 };
632 struct m_Mask {
633   ArrayRef<int> &MaskRef;
634   m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
635   bool match(ArrayRef<int> Mask) {
636     MaskRef = Mask;
637     return true;
638   }
639 };
640 
641 struct m_SpecificMask {
642   ArrayRef<int> MaskRef;
643   m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {}
644   bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
645 };
646 
647 template <typename LHS_P, typename RHS_P, typename Pred_t,
648           bool Commutable = false, bool ExcludeChain = false>
649 struct MaxMin_match {
650   using PredType = Pred_t;
651   LHS_P LHS;
652   RHS_P RHS;
653 
654   MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
655 
656   template <typename MatchContext>
657   bool match(const MatchContext &Ctx, SDValue N) {
658     auto MatchMinMax = [&](SDValue L, SDValue R, SDValue TrueValue,
659                            SDValue FalseValue, ISD::CondCode CC) {
660       if ((TrueValue != L || FalseValue != R) &&
661           (TrueValue != R || FalseValue != L))
662         return false;
663 
664       ISD::CondCode Cond =
665           TrueValue == L ? CC : getSetCCInverse(CC, L.getValueType());
666       if (!Pred_t::match(Cond))
667         return false;
668 
669       return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
670              (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
671     };
672 
673     if (sd_context_match(N, Ctx, m_Opc(ISD::SELECT)) ||
674         sd_context_match(N, Ctx, m_Opc(ISD::VSELECT))) {
675       EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
676       assert(EO_SELECT.Size == 3);
677       SDValue Cond = N->getOperand(EO_SELECT.FirstIndex);
678       SDValue TrueValue = N->getOperand(EO_SELECT.FirstIndex + 1);
679       SDValue FalseValue = N->getOperand(EO_SELECT.FirstIndex + 2);
680 
681       if (sd_context_match(Cond, Ctx, m_Opc(ISD::SETCC))) {
682         EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
683         assert(EO_SETCC.Size == 3);
684         SDValue L = Cond->getOperand(EO_SETCC.FirstIndex);
685         SDValue R = Cond->getOperand(EO_SETCC.FirstIndex + 1);
686         auto *CondNode =
687             cast<CondCodeSDNode>(Cond->getOperand(EO_SETCC.FirstIndex + 2));
688         return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
689       }
690     }
691 
692     if (sd_context_match(N, Ctx, m_Opc(ISD::SELECT_CC))) {
693       EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
694       assert(EO_SELECT.Size == 5);
695       SDValue L = N->getOperand(EO_SELECT.FirstIndex);
696       SDValue R = N->getOperand(EO_SELECT.FirstIndex + 1);
697       SDValue TrueValue = N->getOperand(EO_SELECT.FirstIndex + 2);
698       SDValue FalseValue = N->getOperand(EO_SELECT.FirstIndex + 3);
699       auto *CondNode =
700           cast<CondCodeSDNode>(N->getOperand(EO_SELECT.FirstIndex + 4));
701       return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
702     }
703 
704     return false;
705   }
706 };
707 
708 // Helper class for identifying signed max predicates.
709 struct smax_pred_ty {
710   static bool match(ISD::CondCode Cond) {
711     return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE;
712   }
713 };
714 
715 // Helper class for identifying unsigned max predicates.
716 struct umax_pred_ty {
717   static bool match(ISD::CondCode Cond) {
718     return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE;
719   }
720 };
721 
722 // Helper class for identifying signed min predicates.
723 struct smin_pred_ty {
724   static bool match(ISD::CondCode Cond) {
725     return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE;
726   }
727 };
728 
729 // Helper class for identifying unsigned min predicates.
730 struct umin_pred_ty {
731   static bool match(ISD::CondCode Cond) {
732     return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE;
733   }
734 };
735 
736 template <typename LHS, typename RHS>
737 inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
738                                          const RHS &R) {
739   return BinaryOpc_match<LHS, RHS>(Opc, L, R);
740 }
741 template <typename LHS, typename RHS>
742 inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L,
743                                                  const RHS &R) {
744   return BinaryOpc_match<LHS, RHS, true>(Opc, L, R);
745 }
746 
747 template <typename LHS, typename RHS>
748 inline BinaryOpc_match<LHS, RHS, false, true>
749 m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
750   return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
751 }
752 template <typename LHS, typename RHS>
753 inline BinaryOpc_match<LHS, RHS, true, true>
754 m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
755   return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
756 }
757 
758 // Common binary operations
759 template <typename LHS, typename RHS>
760 inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
761   return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
762 }
763 
764 template <typename LHS, typename RHS>
765 inline BinaryOpc_match<LHS, RHS> m_Sub(const LHS &L, const RHS &R) {
766   return BinaryOpc_match<LHS, RHS>(ISD::SUB, L, R);
767 }
768 
769 template <typename LHS, typename RHS>
770 inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
771   return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
772 }
773 
774 template <typename LHS, typename RHS>
775 inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) {
776   return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R);
777 }
778 
779 template <typename LHS, typename RHS>
780 inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
781   return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
782 }
783 
784 template <typename LHS, typename RHS>
785 inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
786                                                     const RHS &R) {
787   return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
788 }
789 
790 template <typename LHS, typename RHS>
791 inline auto m_AddLike(const LHS &L, const RHS &R) {
792   return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
793 }
794 
795 template <typename LHS, typename RHS>
796 inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
797   return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
798 }
799 
800 template <typename LHS, typename RHS>
801 inline auto m_BitwiseLogic(const LHS &L, const RHS &R) {
802   return m_AnyOf(m_And(L, R), m_Or(L, R), m_Xor(L, R));
803 }
804 
805 template <typename LHS, typename RHS>
806 inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
807   return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
808 }
809 
810 template <typename LHS, typename RHS>
811 inline auto m_SMinLike(const LHS &L, const RHS &R) {
812   return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R),
813                  MaxMin_match<LHS, RHS, smin_pred_ty, true>(L, R));
814 }
815 
816 template <typename LHS, typename RHS>
817 inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
818   return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
819 }
820 
821 template <typename LHS, typename RHS>
822 inline auto m_SMaxLike(const LHS &L, const RHS &R) {
823   return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R),
824                  MaxMin_match<LHS, RHS, smax_pred_ty, true>(L, R));
825 }
826 
827 template <typename LHS, typename RHS>
828 inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
829   return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
830 }
831 
832 template <typename LHS, typename RHS>
833 inline auto m_UMinLike(const LHS &L, const RHS &R) {
834   return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R),
835                  MaxMin_match<LHS, RHS, umin_pred_ty, true>(L, R));
836 }
837 
838 template <typename LHS, typename RHS>
839 inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
840   return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
841 }
842 
843 template <typename LHS, typename RHS>
844 inline auto m_UMaxLike(const LHS &L, const RHS &R) {
845   return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R),
846                  MaxMin_match<LHS, RHS, umax_pred_ty, true>(L, R));
847 }
848 
849 template <typename LHS, typename RHS>
850 inline BinaryOpc_match<LHS, RHS> m_UDiv(const LHS &L, const RHS &R) {
851   return BinaryOpc_match<LHS, RHS>(ISD::UDIV, L, R);
852 }
853 template <typename LHS, typename RHS>
854 inline BinaryOpc_match<LHS, RHS> m_SDiv(const LHS &L, const RHS &R) {
855   return BinaryOpc_match<LHS, RHS>(ISD::SDIV, L, R);
856 }
857 
858 template <typename LHS, typename RHS>
859 inline BinaryOpc_match<LHS, RHS> m_URem(const LHS &L, const RHS &R) {
860   return BinaryOpc_match<LHS, RHS>(ISD::UREM, L, R);
861 }
862 template <typename LHS, typename RHS>
863 inline BinaryOpc_match<LHS, RHS> m_SRem(const LHS &L, const RHS &R) {
864   return BinaryOpc_match<LHS, RHS>(ISD::SREM, L, R);
865 }
866 
867 template <typename LHS, typename RHS>
868 inline BinaryOpc_match<LHS, RHS> m_Shl(const LHS &L, const RHS &R) {
869   return BinaryOpc_match<LHS, RHS>(ISD::SHL, L, R);
870 }
871 
872 template <typename LHS, typename RHS>
873 inline BinaryOpc_match<LHS, RHS> m_Sra(const LHS &L, const RHS &R) {
874   return BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R);
875 }
876 template <typename LHS, typename RHS>
877 inline BinaryOpc_match<LHS, RHS> m_Srl(const LHS &L, const RHS &R) {
878   return BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R);
879 }
880 
881 template <typename LHS, typename RHS>
882 inline BinaryOpc_match<LHS, RHS> m_Rotl(const LHS &L, const RHS &R) {
883   return BinaryOpc_match<LHS, RHS>(ISD::ROTL, L, R);
884 }
885 
886 template <typename LHS, typename RHS>
887 inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
888   return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
889 }
890 
891 template <typename LHS, typename RHS>
892 inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
893   return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
894 }
895 
896 template <typename LHS, typename RHS>
897 inline BinaryOpc_match<LHS, RHS> m_FSub(const LHS &L, const RHS &R) {
898   return BinaryOpc_match<LHS, RHS>(ISD::FSUB, L, R);
899 }
900 
901 template <typename LHS, typename RHS>
902 inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
903   return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
904 }
905 
906 template <typename LHS, typename RHS>
907 inline BinaryOpc_match<LHS, RHS> m_FDiv(const LHS &L, const RHS &R) {
908   return BinaryOpc_match<LHS, RHS>(ISD::FDIV, L, R);
909 }
910 
911 template <typename LHS, typename RHS>
912 inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
913   return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
914 }
915 
916 template <typename V1_t, typename V2_t>
917 inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) {
918   return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
919 }
920 
921 template <typename V1_t, typename V2_t, typename Mask_t>
922 inline SDShuffle_match<V1_t, V2_t, Mask_t>
923 m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
924   return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
925 }
926 
927 template <typename LHS, typename RHS>
928 inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) {
929   return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
930 }
931 
932 template <typename LHS, typename RHS>
933 inline BinaryOpc_match<LHS, RHS> m_ExtractSubvector(const LHS &Vec,
934                                                     const RHS &Idx) {
935   return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_SUBVECTOR, Vec, Idx);
936 }
937 
938 // === Unary operations ===
939 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
940   unsigned Opcode;
941   Opnd_P Opnd;
942   std::optional<SDNodeFlags> Flags;
943   UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
944                  std::optional<SDNodeFlags> Flgs = std::nullopt)
945       : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
946 
947   template <typename MatchContext>
948   bool match(const MatchContext &Ctx, SDValue N) {
949     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
950       EffectiveOperands<ExcludeChain> EO(N, Ctx);
951       assert(EO.Size == 1);
952       if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex)))
953         return false;
954       if (!Flags.has_value())
955         return true;
956 
957       return (*Flags & N->getFlags()) == *Flags;
958     }
959 
960     return false;
961   }
962 };
963 
964 template <typename Opnd>
965 inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
966   return UnaryOpc_match<Opnd>(Opc, Op);
967 }
968 template <typename Opnd>
969 inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
970                                                    const Opnd &Op) {
971   return UnaryOpc_match<Opnd, true>(Opc, Op);
972 }
973 
974 template <typename Opnd> inline UnaryOpc_match<Opnd> m_BitCast(const Opnd &Op) {
975   return UnaryOpc_match<Opnd>(ISD::BITCAST, Op);
976 }
977 
978 template <typename Opnd>
979 inline UnaryOpc_match<Opnd> m_BSwap(const Opnd &Op) {
980   return UnaryOpc_match<Opnd>(ISD::BSWAP, Op);
981 }
982 
983 template <typename Opnd>
984 inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) {
985   return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op);
986 }
987 
988 template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
989   return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
990 }
991 
992 template <typename Opnd>
993 inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
994   return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
995 }
996 
997 template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
998   return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
999 }
1000 
1001 template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
1002   return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
1003 }
1004 
1005 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
1006   return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
1007 }
1008 
1009 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Abs(const Opnd &Op) {
1010   return UnaryOpc_match<Opnd>(ISD::ABS, Op);
1011 }
1012 
1013 /// Match a zext or identity
1014 /// Allows to peek through optional extensions
1015 template <typename Opnd> inline auto m_ZExtOrSelf(const Opnd &Op) {
1016   return m_AnyOf(m_ZExt(Op), Op);
1017 }
1018 
1019 /// Match a sext or identity
1020 /// Allows to peek through optional extensions
1021 template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
1022   return m_AnyOf(m_SExt(Op), Op);
1023 }
1024 
1025 template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
1026   return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
1027 }
1028 
1029 /// Match a aext or identity
1030 /// Allows to peek through optional extensions
1031 template <typename Opnd>
1032 inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(const Opnd &Op) {
1033   return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(Op), Op);
1034 }
1035 
1036 /// Match a trunc or identity
1037 /// Allows to peek through optional truncations
1038 template <typename Opnd>
1039 inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(const Opnd &Op) {
1040   return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(Op), Op);
1041 }
1042 
1043 template <typename Opnd> inline UnaryOpc_match<Opnd> m_VScale(const Opnd &Op) {
1044   return UnaryOpc_match<Opnd>(ISD::VSCALE, Op);
1045 }
1046 
1047 template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToUI(const Opnd &Op) {
1048   return UnaryOpc_match<Opnd>(ISD::FP_TO_UINT, Op);
1049 }
1050 
1051 template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) {
1052   return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op);
1053 }
1054 
1055 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctpop(const Opnd &Op) {
1056   return UnaryOpc_match<Opnd>(ISD::CTPOP, Op);
1057 }
1058 
1059 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) {
1060   return UnaryOpc_match<Opnd>(ISD::CTLZ, Op);
1061 }
1062 
1063 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Cttz(const Opnd &Op) {
1064   return UnaryOpc_match<Opnd>(ISD::CTTZ, Op);
1065 }
1066 
1067 // === Constants ===
1068 struct ConstantInt_match {
1069   APInt *BindVal;
1070 
1071   explicit ConstantInt_match(APInt *V) : BindVal(V) {}
1072 
1073   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1074     // The logics here are similar to that in
1075     // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
1076     // treats GlobalAddressSDNode as a constant, which is difficult to turn into
1077     // APInt.
1078     if (auto *C = dyn_cast_or_null<ConstantSDNode>(N.getNode())) {
1079       if (BindVal)
1080         *BindVal = C->getAPIntValue();
1081       return true;
1082     }
1083 
1084     APInt Discard;
1085     return ISD::isConstantSplatVector(N.getNode(),
1086                                       BindVal ? *BindVal : Discard);
1087   }
1088 };
1089 /// Match any interger constants or splat of an integer constant.
1090 inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
1091 /// Match any interger constants or splat of an integer constant; return the
1092 /// specific constant or constant splat value.
1093 inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
1094 
1095 struct SpecificInt_match {
1096   APInt IntVal;
1097 
1098   explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
1099 
1100   template <typename MatchContext>
1101   bool match(const MatchContext &Ctx, SDValue N) {
1102     APInt ConstInt;
1103     if (sd_context_match(N, Ctx, m_ConstInt(ConstInt)))
1104       return APInt::isSameValue(IntVal, ConstInt);
1105     return false;
1106   }
1107 };
1108 
1109 /// Match a specific integer constant or constant splat value.
1110 inline SpecificInt_match m_SpecificInt(APInt V) {
1111   return SpecificInt_match(std::move(V));
1112 }
1113 inline SpecificInt_match m_SpecificInt(uint64_t V) {
1114   return SpecificInt_match(APInt(64, V));
1115 }
1116 
1117 struct Zero_match {
1118   bool AllowUndefs;
1119 
1120   explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1121 
1122   template <typename MatchContext>
1123   bool match(const MatchContext &, SDValue N) const {
1124     return isZeroOrZeroSplat(N, AllowUndefs);
1125   }
1126 };
1127 
1128 struct Ones_match {
1129   bool AllowUndefs;
1130 
1131   Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1132 
1133   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1134     return isOnesOrOnesSplat(N, AllowUndefs);
1135   }
1136 };
1137 
1138 struct AllOnes_match {
1139   bool AllowUndefs;
1140 
1141   AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1142 
1143   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1144     return isAllOnesOrAllOnesSplat(N, AllowUndefs);
1145   }
1146 };
1147 
1148 inline Ones_match m_One(bool AllowUndefs = false) {
1149   return Ones_match(AllowUndefs);
1150 }
1151 inline Zero_match m_Zero(bool AllowUndefs = false) {
1152   return Zero_match(AllowUndefs);
1153 }
1154 inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1155   return AllOnes_match(AllowUndefs);
1156 }
1157 
1158 /// Match true boolean value based on the information provided by
1159 /// TargetLowering.
1160 inline auto m_True() {
1161   return TLI_pred_match{
1162       [](const TargetLowering &TLI, SDValue N) {
1163         APInt ConstVal;
1164         if (sd_match(N, m_ConstInt(ConstVal)))
1165           switch (TLI.getBooleanContents(N.getValueType())) {
1166           case TargetLowering::ZeroOrOneBooleanContent:
1167             return ConstVal.isOne();
1168           case TargetLowering::ZeroOrNegativeOneBooleanContent:
1169             return ConstVal.isAllOnes();
1170           case TargetLowering::UndefinedBooleanContent:
1171             return (ConstVal & 0x01) == 1;
1172           }
1173 
1174         return false;
1175       },
1176       m_Value()};
1177 }
1178 /// Match false boolean value based on the information provided by
1179 /// TargetLowering.
1180 inline auto m_False() {
1181   return TLI_pred_match{
1182       [](const TargetLowering &TLI, SDValue N) {
1183         APInt ConstVal;
1184         if (sd_match(N, m_ConstInt(ConstVal)))
1185           switch (TLI.getBooleanContents(N.getValueType())) {
1186           case TargetLowering::ZeroOrOneBooleanContent:
1187           case TargetLowering::ZeroOrNegativeOneBooleanContent:
1188             return ConstVal.isZero();
1189           case TargetLowering::UndefinedBooleanContent:
1190             return (ConstVal & 0x01) == 0;
1191           }
1192 
1193         return false;
1194       },
1195       m_Value()};
1196 }
1197 
1198 struct CondCode_match {
1199   std::optional<ISD::CondCode> CCToMatch;
1200   ISD::CondCode *BindCC = nullptr;
1201 
1202   explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
1203 
1204   explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
1205 
1206   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1207     if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) {
1208       if (CCToMatch && *CCToMatch != CC->get())
1209         return false;
1210 
1211       if (BindCC)
1212         *BindCC = CC->get();
1213       return true;
1214     }
1215 
1216     return false;
1217   }
1218 };
1219 
1220 /// Match any conditional code SDNode.
1221 inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
1222 /// Match any conditional code SDNode and return its ISD::CondCode value.
1223 inline CondCode_match m_CondCode(ISD::CondCode &CC) {
1224   return CondCode_match(&CC);
1225 }
1226 /// Match a conditional code SDNode with a specific ISD::CondCode.
1227 inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
1228   return CondCode_match(CC);
1229 }
1230 
1231 /// Match a negate as a sub(0, v)
1232 template <typename ValTy>
1233 inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
1234   return m_Sub(m_Zero(), V);
1235 }
1236 
1237 /// Match a Not as a xor(v, -1) or xor(-1, v)
1238 template <typename ValTy>
1239 inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
1240   return m_Xor(V, m_AllOnes());
1241 }
1242 
1243 template <typename... PatternTs> struct ReassociatableOpc_match {
1244   unsigned Opcode;
1245   std::tuple<PatternTs...> Patterns;
1246 
1247   ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1248       : Opcode(Opcode), Patterns(Patterns...) {}
1249 
1250   template <typename MatchContext>
1251   bool match(const MatchContext &Ctx, SDValue N) {
1252     constexpr size_t NumPatterns = std::tuple_size_v<std::tuple<PatternTs...>>;
1253 
1254     SmallVector<SDValue> Leaves;
1255     collectLeaves(N, Leaves);
1256     if (Leaves.size() != NumPatterns)
1257       return false;
1258 
1259     // Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
1260     // std::get<J>(Patterns)) == true
1261     std::array<SmallBitVector, NumPatterns> Matches;
1262     for (size_t I = 0; I != NumPatterns; I++) {
1263       std::apply(
1264           [&](auto &...P) {
1265             (Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
1266           },
1267           Patterns);
1268     }
1269 
1270     SmallBitVector Used(NumPatterns);
1271     return reassociatableMatchHelper(Matches, Used);
1272   }
1273 
1274   void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
1275     if (V->getOpcode() == Opcode) {
1276       for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
1277         collectLeaves(V->getOperand(I), Leaves);
1278     } else {
1279       Leaves.emplace_back(V);
1280     }
1281   }
1282 
1283   [[nodiscard]] inline bool
1284   reassociatableMatchHelper(const ArrayRef<SmallBitVector> Matches,
1285                             SmallBitVector &Used, size_t Curr = 0) {
1286     if (Curr == Matches.size())
1287       return true;
1288     for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
1289       if (!Matches[Curr][Match] || Used[Match])
1290         continue;
1291       Used[Match] = true;
1292       if (reassociatableMatchHelper(Matches, Used, Curr + 1))
1293         return true;
1294       Used[Match] = false;
1295     }
1296     return false;
1297   }
1298 };
1299 
1300 template <typename... PatternTs>
1301 inline ReassociatableOpc_match<PatternTs...>
1302 m_ReassociatableAdd(const PatternTs &...Patterns) {
1303   return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1304 }
1305 
1306 template <typename... PatternTs>
1307 inline ReassociatableOpc_match<PatternTs...>
1308 m_ReassociatableOr(const PatternTs &...Patterns) {
1309   return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1310 }
1311 
1312 template <typename... PatternTs>
1313 inline ReassociatableOpc_match<PatternTs...>
1314 m_ReassociatableAnd(const PatternTs &...Patterns) {
1315   return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1316 }
1317 
1318 template <typename... PatternTs>
1319 inline ReassociatableOpc_match<PatternTs...>
1320 m_ReassociatableMul(const PatternTs &...Patterns) {
1321   return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1322 }
1323 
1324 } // namespace SDPatternMatch
1325 } // namespace llvm
1326 #endif
1327