xref: /freebsd/contrib/llvm-project/clang/lib/ASTMatchers/Dynamic/Marshallers.h (revision 0d8fe2373503aeac48492f28073049a8bfa4feb5)
1 //===- Marshallers.h - Generic matcher function marshallers -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
/// Function templates and classes to wrap matcher construct functions.
///
/// A collection of function templates and classes that provide a generic
/// marshalling layer on top of matcher construct functions.
14 /// These are used by the registry to export all marshaller constructors with
15 /// the same generic interface.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
20 #define LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
21 
22 #include "clang/AST/ASTTypeTraits.h"
23 #include "clang/AST/OperationKinds.h"
24 #include "clang/ASTMatchers/ASTMatchersInternal.h"
25 #include "clang/ASTMatchers/Dynamic/Diagnostics.h"
26 #include "clang/ASTMatchers/Dynamic/VariantValue.h"
27 #include "clang/Basic/AttrKinds.h"
28 #include "clang/Basic/LLVM.h"
29 #include "clang/Basic/OpenMPKinds.h"
30 #include "clang/Basic/TypeTraits.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/None.h"
33 #include "llvm/ADT/Optional.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/StringSwitch.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Support/Regex.h"
39 #include <cassert>
40 #include <cstddef>
41 #include <iterator>
42 #include <limits>
43 #include <memory>
44 #include <string>
45 #include <utility>
46 #include <vector>
47 
48 namespace clang {
49 namespace ast_matchers {
50 namespace dynamic {
51 namespace internal {
52 
/// Helper template class to jump from an argument type to the right is/get
///   functions in VariantValue.
/// Used to verify and extract the matcher arguments below.
template <typename T> struct ArgTypeTraits;

/// An argument taken by const reference uses the traits of the referenced
/// type.
template <typename T> struct ArgTypeTraits<const T &> : ArgTypeTraits<T> {};
59 
60 template <> struct ArgTypeTraits<std::string> {
61   static bool hasCorrectType(const VariantValue &Value) {
62     return Value.isString();
63   }
64   static bool hasCorrectValue(const VariantValue &Value) { return true; }
65 
66   static const std::string &get(const VariantValue &Value) {
67     return Value.getString();
68   }
69 
70   static ArgKind getKind() {
71     return ArgKind(ArgKind::AK_String);
72   }
73 
74   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
75     return llvm::None;
76   }
77 };
78 
79 template <>
80 struct ArgTypeTraits<StringRef> : public ArgTypeTraits<std::string> {
81 };
82 
83 template <class T> struct ArgTypeTraits<ast_matchers::internal::Matcher<T>> {
84   static bool hasCorrectType(const VariantValue& Value) {
85     return Value.isMatcher();
86   }
87   static bool hasCorrectValue(const VariantValue &Value) {
88     return Value.getMatcher().hasTypedMatcher<T>();
89   }
90 
91   static ast_matchers::internal::Matcher<T> get(const VariantValue &Value) {
92     return Value.getMatcher().getTypedMatcher<T>();
93   }
94 
95   static ArgKind getKind() {
96     return ArgKind(ASTNodeKind::getFromNodeKind<T>());
97   }
98 
99   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
100     return llvm::None;
101   }
102 };
103 
104 template <> struct ArgTypeTraits<bool> {
105   static bool hasCorrectType(const VariantValue &Value) {
106     return Value.isBoolean();
107   }
108   static bool hasCorrectValue(const VariantValue &Value) { return true; }
109 
110   static bool get(const VariantValue &Value) {
111     return Value.getBoolean();
112   }
113 
114   static ArgKind getKind() {
115     return ArgKind(ArgKind::AK_Boolean);
116   }
117 
118   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
119     return llvm::None;
120   }
121 };
122 
123 template <> struct ArgTypeTraits<double> {
124   static bool hasCorrectType(const VariantValue &Value) {
125     return Value.isDouble();
126   }
127   static bool hasCorrectValue(const VariantValue &Value) { return true; }
128 
129   static double get(const VariantValue &Value) {
130     return Value.getDouble();
131   }
132 
133   static ArgKind getKind() {
134     return ArgKind(ArgKind::AK_Double);
135   }
136 
137   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
138     return llvm::None;
139   }
140 };
141 
142 template <> struct ArgTypeTraits<unsigned> {
143   static bool hasCorrectType(const VariantValue &Value) {
144     return Value.isUnsigned();
145   }
146   static bool hasCorrectValue(const VariantValue &Value) { return true; }
147 
148   static unsigned get(const VariantValue &Value) {
149     return Value.getUnsigned();
150   }
151 
152   static ArgKind getKind() {
153     return ArgKind(ArgKind::AK_Unsigned);
154   }
155 
156   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
157     return llvm::None;
158   }
159 };
160 
161 template <> struct ArgTypeTraits<attr::Kind> {
162 private:
163   static Optional<attr::Kind> getAttrKind(llvm::StringRef AttrKind) {
164     if (!AttrKind.consume_front("attr::"))
165       return llvm::None;
166     return llvm::StringSwitch<Optional<attr::Kind>>(AttrKind)
167 #define ATTR(X) .Case(#X, attr::X)
168 #include "clang/Basic/AttrList.inc"
169         .Default(llvm::None);
170   }
171 
172 public:
173   static bool hasCorrectType(const VariantValue &Value) {
174     return Value.isString();
175   }
176   static bool hasCorrectValue(const VariantValue& Value) {
177     return getAttrKind(Value.getString()).hasValue();
178   }
179 
180   static attr::Kind get(const VariantValue &Value) {
181     return *getAttrKind(Value.getString());
182   }
183 
184   static ArgKind getKind() {
185     return ArgKind(ArgKind::AK_String);
186   }
187 
188   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
189 };
190 
191 template <> struct ArgTypeTraits<CastKind> {
192 private:
193   static Optional<CastKind> getCastKind(llvm::StringRef AttrKind) {
194     if (!AttrKind.consume_front("CK_"))
195       return llvm::None;
196     return llvm::StringSwitch<Optional<CastKind>>(AttrKind)
197 #define CAST_OPERATION(Name) .Case(#Name, CK_##Name)
198 #include "clang/AST/OperationKinds.def"
199         .Default(llvm::None);
200   }
201 
202 public:
203   static bool hasCorrectType(const VariantValue &Value) {
204     return Value.isString();
205   }
206   static bool hasCorrectValue(const VariantValue& Value) {
207     return getCastKind(Value.getString()).hasValue();
208   }
209 
210   static CastKind get(const VariantValue &Value) {
211     return *getCastKind(Value.getString());
212   }
213 
214   static ArgKind getKind() {
215     return ArgKind(ArgKind::AK_String);
216   }
217 
218   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
219 };
220 
221 template <> struct ArgTypeTraits<llvm::Regex::RegexFlags> {
222 private:
223   static Optional<llvm::Regex::RegexFlags> getFlags(llvm::StringRef Flags);
224 
225 public:
226   static bool hasCorrectType(const VariantValue &Value) {
227     return Value.isString();
228   }
229   static bool hasCorrectValue(const VariantValue& Value) {
230     return getFlags(Value.getString()).hasValue();
231   }
232 
233   static llvm::Regex::RegexFlags get(const VariantValue &Value) {
234     return *getFlags(Value.getString());
235   }
236 
237   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
238 
239   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
240 };
241 
242 template <> struct ArgTypeTraits<OpenMPClauseKind> {
243 private:
244   static Optional<OpenMPClauseKind> getClauseKind(llvm::StringRef ClauseKind) {
245     return llvm::StringSwitch<Optional<OpenMPClauseKind>>(ClauseKind)
246 #define GEN_CLANG_CLAUSE_CLASS
247 #define CLAUSE_CLASS(Enum, Str, Class) .Case(#Enum, llvm::omp::Clause::Enum)
248 #include "llvm/Frontend/OpenMP/OMP.inc"
249         .Default(llvm::None);
250   }
251 
252 public:
253   static bool hasCorrectType(const VariantValue &Value) {
254     return Value.isString();
255   }
256   static bool hasCorrectValue(const VariantValue& Value) {
257     return getClauseKind(Value.getString()).hasValue();
258   }
259 
260   static OpenMPClauseKind get(const VariantValue &Value) {
261     return *getClauseKind(Value.getString());
262   }
263 
264   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
265 
266   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
267 };
268 
269 template <> struct ArgTypeTraits<UnaryExprOrTypeTrait> {
270 private:
271   static Optional<UnaryExprOrTypeTrait>
272   getUnaryOrTypeTraitKind(llvm::StringRef ClauseKind) {
273     if (!ClauseKind.consume_front("UETT_"))
274       return llvm::None;
275     return llvm::StringSwitch<Optional<UnaryExprOrTypeTrait>>(ClauseKind)
276 #define UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key) .Case(#Name, UETT_##Name)
277 #define CXX11_UNARY_EXPR_OR_TYPE_TRAIT(Spelling, Name, Key)                    \
278   .Case(#Name, UETT_##Name)
279 #include "clang/Basic/TokenKinds.def"
280         .Default(llvm::None);
281   }
282 
283 public:
284   static bool hasCorrectType(const VariantValue &Value) {
285     return Value.isString();
286   }
287   static bool hasCorrectValue(const VariantValue& Value) {
288     return getUnaryOrTypeTraitKind(Value.getString()).hasValue();
289   }
290 
291   static UnaryExprOrTypeTrait get(const VariantValue &Value) {
292     return *getUnaryOrTypeTraitKind(Value.getString());
293   }
294 
295   static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
296 
297   static llvm::Optional<std::string> getBestGuess(const VariantValue &Value);
298 };
299 
300 /// Matcher descriptor interface.
301 ///
302 /// Provides a \c create() method that constructs the matcher from the provided
303 /// arguments, and various other methods for type introspection.
304 class MatcherDescriptor {
305 public:
306   virtual ~MatcherDescriptor() = default;
307 
308   virtual VariantMatcher create(SourceRange NameRange,
309                                 ArrayRef<ParserValue> Args,
310                                 Diagnostics *Error) const = 0;
311 
312   /// Returns whether the matcher is variadic. Variadic matchers can take any
313   /// number of arguments, but they must be of the same type.
314   virtual bool isVariadic() const = 0;
315 
316   /// Returns the number of arguments accepted by the matcher if not variadic.
317   virtual unsigned getNumArgs() const = 0;
318 
319   /// Given that the matcher is being converted to type \p ThisKind, append the
320   /// set of argument types accepted for argument \p ArgNo to \p ArgKinds.
321   // FIXME: We should provide the ability to constrain the output of this
322   // function based on the types of other matcher arguments.
323   virtual void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
324                            std::vector<ArgKind> &ArgKinds) const = 0;
325 
326   /// Returns whether this matcher is convertible to the given type.  If it is
327   /// so convertible, store in *Specificity a value corresponding to the
328   /// "specificity" of the converted matcher to the given context, and in
329   /// *LeastDerivedKind the least derived matcher kind which would result in the
330   /// same matcher overload.  Zero specificity indicates that this conversion
331   /// would produce a trivial matcher that will either always or never match.
332   /// Such matchers are excluded from code completion results.
333   virtual bool
334   isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity = nullptr,
335                   ASTNodeKind *LeastDerivedKind = nullptr) const = 0;
336 
337   /// Returns whether the matcher will, given a matcher of any type T, yield a
338   /// matcher of type T.
339   virtual bool isPolymorphic() const { return false; }
340 };
341 
342 inline bool isRetKindConvertibleTo(ArrayRef<ASTNodeKind> RetKinds,
343                                    ASTNodeKind Kind, unsigned *Specificity,
344                                    ASTNodeKind *LeastDerivedKind) {
345   for (const ASTNodeKind &NodeKind : RetKinds) {
346     if (ArgKind(NodeKind).isConvertibleTo(Kind, Specificity)) {
347       if (LeastDerivedKind)
348         *LeastDerivedKind = NodeKind;
349       return true;
350     }
351   }
352   return false;
353 }
354 
355 /// Simple callback implementation. Marshaller and function are provided.
356 ///
357 /// This class wraps a function of arbitrary signature and a marshaller
358 /// function into a MatcherDescriptor.
359 /// The marshaller is in charge of taking the VariantValue arguments, checking
360 /// their types, unpacking them and calling the underlying function.
361 class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
362 public:
363   using MarshallerType = VariantMatcher (*)(void (*Func)(),
364                                             StringRef MatcherName,
365                                             SourceRange NameRange,
366                                             ArrayRef<ParserValue> Args,
367                                             Diagnostics *Error);
368 
369   /// \param Marshaller Function to unpack the arguments and call \c Func
370   /// \param Func Matcher construct function. This is the function that
371   ///   compile-time matcher expressions would use to create the matcher.
372   /// \param RetKinds The list of matcher types to which the matcher is
373   ///   convertible.
374   /// \param ArgKinds The types of the arguments this matcher takes.
375   FixedArgCountMatcherDescriptor(MarshallerType Marshaller, void (*Func)(),
376                                  StringRef MatcherName,
377                                  ArrayRef<ASTNodeKind> RetKinds,
378                                  ArrayRef<ArgKind> ArgKinds)
379       : Marshaller(Marshaller), Func(Func), MatcherName(MatcherName),
380         RetKinds(RetKinds.begin(), RetKinds.end()),
381         ArgKinds(ArgKinds.begin(), ArgKinds.end()) {}
382 
383   VariantMatcher create(SourceRange NameRange,
384                         ArrayRef<ParserValue> Args,
385                         Diagnostics *Error) const override {
386     return Marshaller(Func, MatcherName, NameRange, Args, Error);
387   }
388 
389   bool isVariadic() const override { return false; }
390   unsigned getNumArgs() const override { return ArgKinds.size(); }
391 
392   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
393                    std::vector<ArgKind> &Kinds) const override {
394     Kinds.push_back(ArgKinds[ArgNo]);
395   }
396 
397   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
398                        ASTNodeKind *LeastDerivedKind) const override {
399     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
400                                   LeastDerivedKind);
401   }
402 
403 private:
404   const MarshallerType Marshaller;
405   void (* const Func)();
406   const std::string MatcherName;
407   const std::vector<ASTNodeKind> RetKinds;
408   const std::vector<ArgKind> ArgKinds;
409 };
410 
411 /// Helper methods to extract and merge all possible typed matchers
412 /// out of the polymorphic object.
413 template <class PolyMatcher>
414 static void mergePolyMatchers(const PolyMatcher &Poly,
415                               std::vector<DynTypedMatcher> &Out,
416                               ast_matchers::internal::EmptyTypeList) {}
417 
418 template <class PolyMatcher, class TypeList>
419 static void mergePolyMatchers(const PolyMatcher &Poly,
420                               std::vector<DynTypedMatcher> &Out, TypeList) {
421   Out.push_back(ast_matchers::internal::Matcher<typename TypeList::head>(Poly));
422   mergePolyMatchers(Poly, Out, typename TypeList::tail());
423 }
424 
425 /// Convert the return values of the functions into a VariantMatcher.
426 ///
427 /// There are 2 cases right now: The return value is a Matcher<T> or is a
428 /// polymorphic matcher. For the former, we just construct the VariantMatcher.
429 /// For the latter, we instantiate all the possible Matcher<T> of the poly
430 /// matcher.
431 inline VariantMatcher outvalueToVariantMatcher(const DynTypedMatcher &Matcher) {
432   return VariantMatcher::SingleMatcher(Matcher);
433 }
434 
435 template <typename T>
436 static VariantMatcher outvalueToVariantMatcher(const T &PolyMatcher,
437                                                typename T::ReturnTypes * =
438                                                    nullptr) {
439   std::vector<DynTypedMatcher> Matchers;
440   mergePolyMatchers(PolyMatcher, Matchers, typename T::ReturnTypes());
441   VariantMatcher Out = VariantMatcher::PolymorphicMatcher(std::move(Matchers));
442   return Out;
443 }
444 
445 template <typename T>
446 inline void
447 buildReturnTypeVectorFromTypeList(std::vector<ASTNodeKind> &RetTypes) {
448   RetTypes.push_back(ASTNodeKind::getFromNodeKind<typename T::head>());
449   buildReturnTypeVectorFromTypeList<typename T::tail>(RetTypes);
450 }
451 
452 template <>
453 inline void
454 buildReturnTypeVectorFromTypeList<ast_matchers::internal::EmptyTypeList>(
455     std::vector<ASTNodeKind> &RetTypes) {}
456 
457 template <typename T>
458 struct BuildReturnTypeVector {
459   static void build(std::vector<ASTNodeKind> &RetTypes) {
460     buildReturnTypeVectorFromTypeList<typename T::ReturnTypes>(RetTypes);
461   }
462 };
463 
464 template <typename T>
465 struct BuildReturnTypeVector<ast_matchers::internal::Matcher<T>> {
466   static void build(std::vector<ASTNodeKind> &RetTypes) {
467     RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>());
468   }
469 };
470 
471 template <typename T>
472 struct BuildReturnTypeVector<ast_matchers::internal::BindableMatcher<T>> {
473   static void build(std::vector<ASTNodeKind> &RetTypes) {
474     RetTypes.push_back(ASTNodeKind::getFromNodeKind<T>());
475   }
476 };
477 
478 /// Variadic marshaller function.
479 template <typename ResultT, typename ArgT,
480           ResultT (*Func)(ArrayRef<const ArgT *>)>
481 VariantMatcher
482 variadicMatcherDescriptor(StringRef MatcherName, SourceRange NameRange,
483                           ArrayRef<ParserValue> Args, Diagnostics *Error) {
484   ArgT **InnerArgs = new ArgT *[Args.size()]();
485 
486   bool HasError = false;
487   for (size_t i = 0, e = Args.size(); i != e; ++i) {
488     using ArgTraits = ArgTypeTraits<ArgT>;
489 
490     const ParserValue &Arg = Args[i];
491     const VariantValue &Value = Arg.Value;
492     if (!ArgTraits::hasCorrectType(Value)) {
493       Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
494           << (i + 1) << ArgTraits::getKind().asString() << Value.getTypeAsString();
495       HasError = true;
496       break;
497     }
498     if (!ArgTraits::hasCorrectValue(Value)) {
499       if (llvm::Optional<std::string> BestGuess =
500               ArgTraits::getBestGuess(Value)) {
501         Error->addError(Arg.Range, Error->ET_RegistryUnknownEnumWithReplace)
502             << i + 1 << Value.getString() << *BestGuess;
503       } else if (Value.isString()) {
504         Error->addError(Arg.Range, Error->ET_RegistryValueNotFound)
505             << Value.getString();
506       } else {
507         // This isn't ideal, but it's better than reporting an empty string as
508         // the error in this case.
509         Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
510             << (i + 1) << ArgTraits::getKind().asString()
511             << Value.getTypeAsString();
512       }
513       HasError = true;
514       break;
515     }
516 
517     InnerArgs[i] = new ArgT(ArgTraits::get(Value));
518   }
519 
520   VariantMatcher Out;
521   if (!HasError) {
522     Out = outvalueToVariantMatcher(Func(llvm::makeArrayRef(InnerArgs,
523                                                            Args.size())));
524   }
525 
526   for (size_t i = 0, e = Args.size(); i != e; ++i) {
527     delete InnerArgs[i];
528   }
529   delete[] InnerArgs;
530   return Out;
531 }
532 
533 /// Matcher descriptor for variadic functions.
534 ///
535 /// This class simply wraps a VariadicFunction with the right signature to export
536 /// it as a MatcherDescriptor.
537 /// This allows us to have one implementation of the interface for as many free
538 /// functions as we want, reducing the number of symbols and size of the
539 /// object file.
540 class VariadicFuncMatcherDescriptor : public MatcherDescriptor {
541 public:
542   using RunFunc = VariantMatcher (*)(StringRef MatcherName,
543                                      SourceRange NameRange,
544                                      ArrayRef<ParserValue> Args,
545                                      Diagnostics *Error);
546 
547   template <typename ResultT, typename ArgT,
548             ResultT (*F)(ArrayRef<const ArgT *>)>
549   VariadicFuncMatcherDescriptor(
550       ast_matchers::internal::VariadicFunction<ResultT, ArgT, F> Func,
551       StringRef MatcherName)
552       : Func(&variadicMatcherDescriptor<ResultT, ArgT, F>),
553         MatcherName(MatcherName.str()),
554         ArgsKind(ArgTypeTraits<ArgT>::getKind()) {
555     BuildReturnTypeVector<ResultT>::build(RetKinds);
556   }
557 
558   VariantMatcher create(SourceRange NameRange,
559                         ArrayRef<ParserValue> Args,
560                         Diagnostics *Error) const override {
561     return Func(MatcherName, NameRange, Args, Error);
562   }
563 
564   bool isVariadic() const override { return true; }
565   unsigned getNumArgs() const override { return 0; }
566 
567   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
568                    std::vector<ArgKind> &Kinds) const override {
569     Kinds.push_back(ArgsKind);
570   }
571 
572   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
573                        ASTNodeKind *LeastDerivedKind) const override {
574     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
575                                   LeastDerivedKind);
576   }
577 
578 private:
579   const RunFunc Func;
580   const std::string MatcherName;
581   std::vector<ASTNodeKind> RetKinds;
582   const ArgKind ArgsKind;
583 };
584 
585 /// Return CK_Trivial when appropriate for VariadicDynCastAllOfMatchers.
586 class DynCastAllOfMatcherDescriptor : public VariadicFuncMatcherDescriptor {
587 public:
588   template <typename BaseT, typename DerivedT>
589   DynCastAllOfMatcherDescriptor(
590       ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT> Func,
591       StringRef MatcherName)
592       : VariadicFuncMatcherDescriptor(Func, MatcherName),
593         DerivedKind(ASTNodeKind::getFromNodeKind<DerivedT>()) {}
594 
595   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
596                        ASTNodeKind *LeastDerivedKind) const override {
597     // If Kind is not a base of DerivedKind, either DerivedKind is a base of
598     // Kind (in which case the match will always succeed) or Kind and
599     // DerivedKind are unrelated (in which case it will always fail), so set
600     // Specificity to 0.
601     if (VariadicFuncMatcherDescriptor::isConvertibleTo(Kind, Specificity,
602                                                  LeastDerivedKind)) {
603       if (Kind.isSame(DerivedKind) || !Kind.isBaseOf(DerivedKind)) {
604         if (Specificity)
605           *Specificity = 0;
606       }
607       return true;
608     } else {
609       return false;
610     }
611   }
612 
613 private:
614   const ASTNodeKind DerivedKind;
615 };
616 
/// Helper macros to check the arguments on all marshaller functions.
///
/// Both macros expect \c Args, \c NameRange and \c Error to be in scope, and
/// make the enclosing function return a null VariantMatcher on failure.
///
/// CHECK_ARG_COUNT reports ET_RegistryWrongArgCount unless exactly \c count
/// arguments were supplied.
#define CHECK_ARG_COUNT(count)                                                 \
  if (Args.size() != (count)) {                                                \
    Error->addError(NameRange, Error->ET_RegistryWrongArgCount)                \
        << (count) << Args.size();                                             \
    return VariantMatcher();                                                   \
  }

/// CHECK_ARG_TYPE validates argument \c index against \c ArgTypeTraits<type>.
/// A wrong type is reported as ET_RegistryWrongArgType. A wrong value is
/// reported with a replacement suggestion when one exists, as
/// ET_RegistryValueNotFound when the value is a string, and otherwise as
/// ET_RegistryWrongArgType so that no failure path is left without a
/// diagnostic (the same fallback the variadic marshaller uses).
#define CHECK_ARG_TYPE(index, type)                                            \
  if (!ArgTypeTraits<type>::hasCorrectType(Args[index].Value)) {               \
    Error->addError(Args[index].Range, Error->ET_RegistryWrongArgType)         \
        << ((index) + 1) << ArgTypeTraits<type>::getKind().asString()          \
        << Args[index].Value.getTypeAsString();                                \
    return VariantMatcher();                                                   \
  }                                                                            \
  if (!ArgTypeTraits<type>::hasCorrectValue(Args[index].Value)) {              \
    if (llvm::Optional<std::string> BestGuess =                                \
            ArgTypeTraits<type>::getBestGuess(Args[index].Value)) {            \
      Error->addError(Args[index].Range,                                       \
                      Error->ET_RegistryUnknownEnumWithReplace)                \
          << ((index) + 1) << Args[index].Value.getString() << *BestGuess;     \
    } else if (Args[index].Value.isString()) {                                 \
      Error->addError(Args[index].Range, Error->ET_RegistryValueNotFound)      \
          << Args[index].Value.getString();                                    \
    } else {                                                                   \
      Error->addError(Args[index].Range, Error->ET_RegistryWrongArgType)       \
          << ((index) + 1) << ArgTypeTraits<type>::getKind().asString()        \
          << Args[index].Value.getTypeAsString();                              \
    }                                                                          \
    return VariantMatcher();                                                   \
  }
644 
645 /// 0-arg marshaller function.
646 template <typename ReturnType>
647 static VariantMatcher matcherMarshall0(void (*Func)(), StringRef MatcherName,
648                                        SourceRange NameRange,
649                                        ArrayRef<ParserValue> Args,
650                                        Diagnostics *Error) {
651   using FuncType = ReturnType (*)();
652   CHECK_ARG_COUNT(0);
653   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)());
654 }
655 
656 /// 1-arg marshaller function.
657 template <typename ReturnType, typename ArgType1>
658 static VariantMatcher matcherMarshall1(void (*Func)(), StringRef MatcherName,
659                                        SourceRange NameRange,
660                                        ArrayRef<ParserValue> Args,
661                                        Diagnostics *Error) {
662   using FuncType = ReturnType (*)(ArgType1);
663   CHECK_ARG_COUNT(1);
664   CHECK_ARG_TYPE(0, ArgType1);
665   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)(
666       ArgTypeTraits<ArgType1>::get(Args[0].Value)));
667 }
668 
669 /// 2-arg marshaller function.
670 template <typename ReturnType, typename ArgType1, typename ArgType2>
671 static VariantMatcher matcherMarshall2(void (*Func)(), StringRef MatcherName,
672                                        SourceRange NameRange,
673                                        ArrayRef<ParserValue> Args,
674                                        Diagnostics *Error) {
675   using FuncType = ReturnType (*)(ArgType1, ArgType2);
676   CHECK_ARG_COUNT(2);
677   CHECK_ARG_TYPE(0, ArgType1);
678   CHECK_ARG_TYPE(1, ArgType2);
679   return outvalueToVariantMatcher(reinterpret_cast<FuncType>(Func)(
680       ArgTypeTraits<ArgType1>::get(Args[0].Value),
681       ArgTypeTraits<ArgType2>::get(Args[1].Value)));
682 }
683 
684 #undef CHECK_ARG_COUNT
685 #undef CHECK_ARG_TYPE
686 
687 /// Helper class used to collect all the possible overloads of an
688 ///   argument adaptative matcher function.
689 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
690           typename FromTypes, typename ToTypes>
691 class AdaptativeOverloadCollector {
692 public:
693   AdaptativeOverloadCollector(
694       StringRef Name, std::vector<std::unique_ptr<MatcherDescriptor>> &Out)
695       : Name(Name), Out(Out) {
696     collect(FromTypes());
697   }
698 
699 private:
700   using AdaptativeFunc = ast_matchers::internal::ArgumentAdaptingMatcherFunc<
701       ArgumentAdapterT, FromTypes, ToTypes>;
702 
703   /// End case for the recursion
704   static void collect(ast_matchers::internal::EmptyTypeList) {}
705 
706   /// Recursive case. Get the overload for the head of the list, and
707   ///   recurse to the tail.
708   template <typename FromTypeList>
709   inline void collect(FromTypeList);
710 
711   StringRef Name;
712   std::vector<std::unique_ptr<MatcherDescriptor>> &Out;
713 };
714 
715 /// MatcherDescriptor that wraps multiple "overloads" of the same
716 ///   matcher.
717 ///
718 /// It will try every overload and generate appropriate errors for when none or
719 /// more than one overloads match the arguments.
720 class OverloadedMatcherDescriptor : public MatcherDescriptor {
721 public:
722   OverloadedMatcherDescriptor(
723       MutableArrayRef<std::unique_ptr<MatcherDescriptor>> Callbacks)
724       : Overloads(std::make_move_iterator(Callbacks.begin()),
725                   std::make_move_iterator(Callbacks.end())) {}
726 
727   ~OverloadedMatcherDescriptor() override = default;
728 
729   VariantMatcher create(SourceRange NameRange,
730                         ArrayRef<ParserValue> Args,
731                         Diagnostics *Error) const override {
732     std::vector<VariantMatcher> Constructed;
733     Diagnostics::OverloadContext Ctx(Error);
734     for (const auto &O : Overloads) {
735       VariantMatcher SubMatcher = O->create(NameRange, Args, Error);
736       if (!SubMatcher.isNull()) {
737         Constructed.push_back(SubMatcher);
738       }
739     }
740 
741     if (Constructed.empty()) return VariantMatcher(); // No overload matched.
742     // We ignore the errors if any matcher succeeded.
743     Ctx.revertErrors();
744     if (Constructed.size() > 1) {
745       // More than one constructed. It is ambiguous.
746       Error->addError(NameRange, Error->ET_RegistryAmbiguousOverload);
747       return VariantMatcher();
748     }
749     return Constructed[0];
750   }
751 
752   bool isVariadic() const override {
753     bool Overload0Variadic = Overloads[0]->isVariadic();
754 #ifndef NDEBUG
755     for (const auto &O : Overloads) {
756       assert(Overload0Variadic == O->isVariadic());
757     }
758 #endif
759     return Overload0Variadic;
760   }
761 
762   unsigned getNumArgs() const override {
763     unsigned Overload0NumArgs = Overloads[0]->getNumArgs();
764 #ifndef NDEBUG
765     for (const auto &O : Overloads) {
766       assert(Overload0NumArgs == O->getNumArgs());
767     }
768 #endif
769     return Overload0NumArgs;
770   }
771 
772   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
773                    std::vector<ArgKind> &Kinds) const override {
774     for (const auto &O : Overloads) {
775       if (O->isConvertibleTo(ThisKind))
776         O->getArgKinds(ThisKind, ArgNo, Kinds);
777     }
778   }
779 
780   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
781                        ASTNodeKind *LeastDerivedKind) const override {
782     for (const auto &O : Overloads) {
783       if (O->isConvertibleTo(Kind, Specificity, LeastDerivedKind))
784         return true;
785     }
786     return false;
787   }
788 
789 private:
790   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
791 };
792 
793 template <typename ReturnType>
794 class RegexMatcherDescriptor : public MatcherDescriptor {
795 public:
796   RegexMatcherDescriptor(ReturnType (*WithFlags)(StringRef,
797                                                  llvm::Regex::RegexFlags),
798                          ReturnType (*NoFlags)(StringRef),
799                          ArrayRef<ASTNodeKind> RetKinds)
800       : WithFlags(WithFlags), NoFlags(NoFlags),
801         RetKinds(RetKinds.begin(), RetKinds.end()) {}
802   bool isVariadic() const override { return true; }
803   unsigned getNumArgs() const override { return 0; }
804 
805   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
806                    std::vector<ArgKind> &Kinds) const override {
807     assert(ArgNo < 2);
808     Kinds.push_back(ArgKind::AK_String);
809   }
810 
811   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
812                        ASTNodeKind *LeastDerivedKind) const override {
813     return isRetKindConvertibleTo(RetKinds, Kind, Specificity,
814                                   LeastDerivedKind);
815   }
816 
817   VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args,
818                         Diagnostics *Error) const override {
819     if (Args.size() < 1 || Args.size() > 2) {
820       Error->addError(NameRange, Diagnostics::ET_RegistryWrongArgCount)
821           << "1 or 2" << Args.size();
822       return VariantMatcher();
823     }
824     if (!ArgTypeTraits<StringRef>::hasCorrectType(Args[0].Value)) {
825       Error->addError(Args[0].Range, Error->ET_RegistryWrongArgType)
826           << 1 << ArgTypeTraits<StringRef>::getKind().asString()
827           << Args[0].Value.getTypeAsString();
828       return VariantMatcher();
829     }
830     if (Args.size() == 1) {
831       return outvalueToVariantMatcher(
832           NoFlags(ArgTypeTraits<StringRef>::get(Args[0].Value)));
833     }
834     if (!ArgTypeTraits<llvm::Regex::RegexFlags>::hasCorrectType(
835             Args[1].Value)) {
836       Error->addError(Args[1].Range, Error->ET_RegistryWrongArgType)
837           << 2 << ArgTypeTraits<llvm::Regex::RegexFlags>::getKind().asString()
838           << Args[1].Value.getTypeAsString();
839       return VariantMatcher();
840     }
841     if (!ArgTypeTraits<llvm::Regex::RegexFlags>::hasCorrectValue(
842             Args[1].Value)) {
843       if (llvm::Optional<std::string> BestGuess =
844               ArgTypeTraits<llvm::Regex::RegexFlags>::getBestGuess(
845                   Args[1].Value)) {
846         Error->addError(Args[1].Range, Error->ET_RegistryUnknownEnumWithReplace)
847             << 2 << Args[1].Value.getString() << *BestGuess;
848       } else {
849         Error->addError(Args[1].Range, Error->ET_RegistryValueNotFound)
850             << Args[1].Value.getString();
851       }
852       return VariantMatcher();
853     }
854     return outvalueToVariantMatcher(
855         WithFlags(ArgTypeTraits<StringRef>::get(Args[0].Value),
856                   ArgTypeTraits<llvm::Regex::RegexFlags>::get(Args[1].Value)));
857   }
858 
859 private:
860   ReturnType (*const WithFlags)(StringRef, llvm::Regex::RegexFlags);
861   ReturnType (*const NoFlags)(StringRef);
862   const std::vector<ASTNodeKind> RetKinds;
863 };
864 
865 /// Variadic operator marshaller function.
866 class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
867 public:
868   using VarOp = DynTypedMatcher::VariadicOperator;
869 
870   VariadicOperatorMatcherDescriptor(unsigned MinCount, unsigned MaxCount,
871                                     VarOp Op, StringRef MatcherName)
872       : MinCount(MinCount), MaxCount(MaxCount), Op(Op),
873         MatcherName(MatcherName) {}
874 
875   VariantMatcher create(SourceRange NameRange,
876                         ArrayRef<ParserValue> Args,
877                         Diagnostics *Error) const override {
878     if (Args.size() < MinCount || MaxCount < Args.size()) {
879       const std::string MaxStr =
880           (MaxCount == std::numeric_limits<unsigned>::max() ? ""
881                                                             : Twine(MaxCount))
882               .str();
883       Error->addError(NameRange, Error->ET_RegistryWrongArgCount)
884           << ("(" + Twine(MinCount) + ", " + MaxStr + ")") << Args.size();
885       return VariantMatcher();
886     }
887 
888     std::vector<VariantMatcher> InnerArgs;
889     for (size_t i = 0, e = Args.size(); i != e; ++i) {
890       const ParserValue &Arg = Args[i];
891       const VariantValue &Value = Arg.Value;
892       if (!Value.isMatcher()) {
893         Error->addError(Arg.Range, Error->ET_RegistryWrongArgType)
894             << (i + 1) << "Matcher<>" << Value.getTypeAsString();
895         return VariantMatcher();
896       }
897       InnerArgs.push_back(Value.getMatcher());
898     }
899     return VariantMatcher::VariadicOperatorMatcher(Op, std::move(InnerArgs));
900   }
901 
902   bool isVariadic() const override { return true; }
903   unsigned getNumArgs() const override { return 0; }
904 
905   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
906                    std::vector<ArgKind> &Kinds) const override {
907     Kinds.push_back(ThisKind);
908   }
909 
910   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
911                        ASTNodeKind *LeastDerivedKind) const override {
912     if (Specificity)
913       *Specificity = 1;
914     if (LeastDerivedKind)
915       *LeastDerivedKind = Kind;
916     return true;
917   }
918 
919   bool isPolymorphic() const override { return true; }
920 
921 private:
922   const unsigned MinCount;
923   const unsigned MaxCount;
924   const VarOp Op;
925   const StringRef MatcherName;
926 };
927 
928 class MapAnyOfMatcherDescriptor : public MatcherDescriptor {
929   ASTNodeKind CladeNodeKind;
930   std::vector<ASTNodeKind> NodeKinds;
931 
932 public:
933   MapAnyOfMatcherDescriptor(ASTNodeKind CladeNodeKind,
934                             std::vector<ASTNodeKind> NodeKinds)
935       : CladeNodeKind(CladeNodeKind), NodeKinds(NodeKinds) {}
936 
937   VariantMatcher create(SourceRange NameRange, ArrayRef<ParserValue> Args,
938                         Diagnostics *Error) const override {
939 
940     std::vector<DynTypedMatcher> NodeArgs;
941 
942     for (auto NK : NodeKinds) {
943       std::vector<DynTypedMatcher> InnerArgs;
944 
945       for (const auto &Arg : Args) {
946         if (!Arg.Value.isMatcher())
947           return {};
948         const VariantMatcher &VM = Arg.Value.getMatcher();
949         if (VM.hasTypedMatcher(NK)) {
950           auto DM = VM.getTypedMatcher(NK);
951           InnerArgs.push_back(DM);
952         }
953       }
954 
955       if (InnerArgs.empty()) {
956         NodeArgs.push_back(
957             DynTypedMatcher::trueMatcher(NK).dynCastTo(CladeNodeKind));
958       } else {
959         NodeArgs.push_back(
960             DynTypedMatcher::constructVariadic(
961                 ast_matchers::internal::DynTypedMatcher::VO_AllOf, NK,
962                 InnerArgs)
963                 .dynCastTo(CladeNodeKind));
964       }
965     }
966 
967     auto Result = DynTypedMatcher::constructVariadic(
968         ast_matchers::internal::DynTypedMatcher::VO_AnyOf, CladeNodeKind,
969         NodeArgs);
970     Result.setAllowBind(true);
971     return VariantMatcher::SingleMatcher(Result);
972   }
973 
974   bool isVariadic() const override { return true; }
975   unsigned getNumArgs() const override { return 0; }
976 
977   void getArgKinds(ASTNodeKind ThisKind, unsigned,
978                    std::vector<ArgKind> &Kinds) const override {
979     Kinds.push_back(ThisKind);
980   }
981 
982   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
983                        ASTNodeKind *LeastDerivedKind) const override {
984     if (Specificity)
985       *Specificity = 1;
986     if (LeastDerivedKind)
987       *LeastDerivedKind = CladeNodeKind;
988     return true;
989   }
990 };
991 
/// Helper functions to select the appropriate marshaller functions.
/// They detect the number of arguments, the argument types and the return
/// type.
994 
995 /// 0-arg overload
996 template <typename ReturnType>
997 std::unique_ptr<MatcherDescriptor>
998 makeMatcherAutoMarshall(ReturnType (*Func)(), StringRef MatcherName) {
999   std::vector<ASTNodeKind> RetTypes;
1000   BuildReturnTypeVector<ReturnType>::build(RetTypes);
1001   return std::make_unique<FixedArgCountMatcherDescriptor>(
1002       matcherMarshall0<ReturnType>, reinterpret_cast<void (*)()>(Func),
1003       MatcherName, RetTypes, None);
1004 }
1005 
1006 /// 1-arg overload
1007 template <typename ReturnType, typename ArgType1>
1008 std::unique_ptr<MatcherDescriptor>
1009 makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1), StringRef MatcherName) {
1010   std::vector<ASTNodeKind> RetTypes;
1011   BuildReturnTypeVector<ReturnType>::build(RetTypes);
1012   ArgKind AK = ArgTypeTraits<ArgType1>::getKind();
1013   return std::make_unique<FixedArgCountMatcherDescriptor>(
1014       matcherMarshall1<ReturnType, ArgType1>,
1015       reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AK);
1016 }
1017 
1018 /// 2-arg overload
1019 template <typename ReturnType, typename ArgType1, typename ArgType2>
1020 std::unique_ptr<MatcherDescriptor>
1021 makeMatcherAutoMarshall(ReturnType (*Func)(ArgType1, ArgType2),
1022                         StringRef MatcherName) {
1023   std::vector<ASTNodeKind> RetTypes;
1024   BuildReturnTypeVector<ReturnType>::build(RetTypes);
1025   ArgKind AKs[] = { ArgTypeTraits<ArgType1>::getKind(),
1026                     ArgTypeTraits<ArgType2>::getKind() };
1027   return std::make_unique<FixedArgCountMatcherDescriptor>(
1028       matcherMarshall2<ReturnType, ArgType1, ArgType2>,
1029       reinterpret_cast<void (*)()>(Func), MatcherName, RetTypes, AKs);
1030 }
1031 
1032 template <typename ReturnType>
1033 std::unique_ptr<MatcherDescriptor> makeMatcherRegexMarshall(
1034     ReturnType (*FuncFlags)(llvm::StringRef, llvm::Regex::RegexFlags),
1035     ReturnType (*Func)(llvm::StringRef)) {
1036   std::vector<ASTNodeKind> RetTypes;
1037   BuildReturnTypeVector<ReturnType>::build(RetTypes);
1038   return std::make_unique<RegexMatcherDescriptor<ReturnType>>(FuncFlags, Func,
1039                                                               RetTypes);
1040 }
1041 
1042 /// Variadic overload.
1043 template <typename ResultT, typename ArgT,
1044           ResultT (*Func)(ArrayRef<const ArgT *>)>
1045 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
1046     ast_matchers::internal::VariadicFunction<ResultT, ArgT, Func> VarFunc,
1047     StringRef MatcherName) {
1048   return std::make_unique<VariadicFuncMatcherDescriptor>(VarFunc, MatcherName);
1049 }
1050 
1051 /// Overload for VariadicDynCastAllOfMatchers.
1052 ///
1053 /// Not strictly necessary, but DynCastAllOfMatcherDescriptor gives us better
1054 /// completion results for that type of matcher.
1055 template <typename BaseT, typename DerivedT>
1056 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
1057     ast_matchers::internal::VariadicDynCastAllOfMatcher<BaseT, DerivedT>
1058         VarFunc,
1059     StringRef MatcherName) {
1060   return std::make_unique<DynCastAllOfMatcherDescriptor>(VarFunc, MatcherName);
1061 }
1062 
1063 /// Argument adaptative overload.
1064 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
1065           typename FromTypes, typename ToTypes>
1066 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
1067     ast_matchers::internal::ArgumentAdaptingMatcherFunc<ArgumentAdapterT,
1068                                                         FromTypes, ToTypes>,
1069     StringRef MatcherName) {
1070   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
1071   AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes, ToTypes>(MatcherName,
1072                                                                     Overloads);
1073   return std::make_unique<OverloadedMatcherDescriptor>(Overloads);
1074 }
1075 
1076 template <template <typename ToArg, typename FromArg> class ArgumentAdapterT,
1077           typename FromTypes, typename ToTypes>
1078 template <typename FromTypeList>
1079 inline void AdaptativeOverloadCollector<ArgumentAdapterT, FromTypes,
1080                                         ToTypes>::collect(FromTypeList) {
1081   Out.push_back(makeMatcherAutoMarshall(
1082       &AdaptativeFunc::template create<typename FromTypeList::head>, Name));
1083   collect(typename FromTypeList::tail());
1084 }
1085 
1086 /// Variadic operator overload.
1087 template <unsigned MinCount, unsigned MaxCount>
1088 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
1089     ast_matchers::internal::VariadicOperatorMatcherFunc<MinCount, MaxCount>
1090         Func,
1091     StringRef MatcherName) {
1092   return std::make_unique<VariadicOperatorMatcherDescriptor>(
1093       MinCount, MaxCount, Func.Op, MatcherName);
1094 }
1095 
1096 template <typename CladeType, typename... MatcherT>
1097 std::unique_ptr<MatcherDescriptor> makeMatcherAutoMarshall(
1098     ast_matchers::internal::MapAnyOfMatcherImpl<CladeType, MatcherT...>,
1099     StringRef MatcherName) {
1100   return std::make_unique<MapAnyOfMatcherDescriptor>(
1101       ASTNodeKind::getFromNodeKind<CladeType>(),
1102       std::vector<ASTNodeKind>{ASTNodeKind::getFromNodeKind<MatcherT>()...});
1103 }
1104 
1105 } // namespace internal
1106 } // namespace dynamic
1107 } // namespace ast_matchers
1108 } // namespace clang
1109 
#endif // LLVM_CLANG_LIB_ASTMATCHERS_DYNAMIC_MARSHALLERS_H
1111