xref: /freebsd/contrib/llvm-project/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //=======- PtrTypesSemantics.cpp ---------------------------------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PtrTypesSemantics.h"
10 #include "ASTUtils.h"
11 #include "clang/AST/Attr.h"
12 #include "clang/AST/CXXInheritance.h"
13 #include "clang/AST/Decl.h"
14 #include "clang/AST/DeclCXX.h"
15 #include "clang/AST/ExprCXX.h"
16 #include "clang/AST/StmtVisitor.h"
17 #include "clang/Analysis/DomainSpecific/CocoaConventions.h"
18 #include <optional>
19 
20 using namespace clang;
21 
22 namespace {
23 
hasPublicMethodInBaseClass(const CXXRecordDecl * R,StringRef NameToMatch)24 bool hasPublicMethodInBaseClass(const CXXRecordDecl *R, StringRef NameToMatch) {
25   assert(R);
26   assert(R->hasDefinition());
27 
28   for (const CXXMethodDecl *MD : R->methods()) {
29     const auto MethodName = safeGetName(MD);
30     if (MethodName == NameToMatch && MD->getAccess() == AS_public)
31       return true;
32   }
33   return false;
34 }
35 
36 } // namespace
37 
38 namespace clang {
39 
40 std::optional<const clang::CXXRecordDecl *>
hasPublicMethodInBase(const CXXBaseSpecifier * Base,StringRef NameToMatch)41 hasPublicMethodInBase(const CXXBaseSpecifier *Base, StringRef NameToMatch) {
42   assert(Base);
43 
44   const Type *T = Base->getType().getTypePtrOrNull();
45   if (!T)
46     return std::nullopt;
47 
48   const CXXRecordDecl *R = T->getAsCXXRecordDecl();
49   if (!R) {
50     auto CT = Base->getType().getCanonicalType();
51     if (auto *TST = dyn_cast<TemplateSpecializationType>(CT)) {
52       auto TmplName = TST->getTemplateName();
53       if (!TmplName.isNull()) {
54         if (auto *TD = TmplName.getAsTemplateDecl())
55           R = dyn_cast_or_null<CXXRecordDecl>(TD->getTemplatedDecl());
56       }
57     }
58     if (!R)
59       return std::nullopt;
60   }
61   if (!R->hasDefinition())
62     return std::nullopt;
63 
64   return hasPublicMethodInBaseClass(R, NameToMatch) ? R : nullptr;
65 }
66 
isSmartPtrCompatible(const CXXRecordDecl * R,StringRef IncMethodName,StringRef DecMethodName)67 std::optional<bool> isSmartPtrCompatible(const CXXRecordDecl *R,
68                                          StringRef IncMethodName,
69                                          StringRef DecMethodName) {
70   assert(R);
71 
72   R = R->getDefinition();
73   if (!R)
74     return std::nullopt;
75 
76   bool hasRef = hasPublicMethodInBaseClass(R, IncMethodName);
77   bool hasDeref = hasPublicMethodInBaseClass(R, DecMethodName);
78   if (hasRef && hasDeref)
79     return true;
80 
81   CXXBasePaths Paths;
82   Paths.setOrigin(const_cast<CXXRecordDecl *>(R));
83 
84   bool AnyInconclusiveBase = false;
85   const auto hasPublicRefInBase = [&](const CXXBaseSpecifier *Base,
86                                       CXXBasePath &) {
87     auto hasRefInBase = clang::hasPublicMethodInBase(Base, IncMethodName);
88     if (!hasRefInBase) {
89       AnyInconclusiveBase = true;
90       return false;
91     }
92     return (*hasRefInBase) != nullptr;
93   };
94 
95   hasRef = hasRef || R->lookupInBases(hasPublicRefInBase, Paths,
96                                       /*LookupInDependent =*/true);
97   if (AnyInconclusiveBase)
98     return std::nullopt;
99 
100   Paths.clear();
101   const auto hasPublicDerefInBase = [&](const CXXBaseSpecifier *Base,
102                                         CXXBasePath &) {
103     auto hasDerefInBase = clang::hasPublicMethodInBase(Base, DecMethodName);
104     if (!hasDerefInBase) {
105       AnyInconclusiveBase = true;
106       return false;
107     }
108     return (*hasDerefInBase) != nullptr;
109   };
110   hasDeref = hasDeref || R->lookupInBases(hasPublicDerefInBase, Paths,
111                                           /*LookupInDependent =*/true);
112   if (AnyInconclusiveBase)
113     return std::nullopt;
114 
115   return hasRef && hasDeref;
116 }
117 
isRefCountable(const clang::CXXRecordDecl * R)118 std::optional<bool> isRefCountable(const clang::CXXRecordDecl *R) {
119   return isSmartPtrCompatible(R, "ref", "deref");
120 }
121 
isCheckedPtrCapable(const clang::CXXRecordDecl * R)122 std::optional<bool> isCheckedPtrCapable(const clang::CXXRecordDecl *R) {
123   return isSmartPtrCompatible(R, "incrementCheckedPtrCount",
124                               "decrementCheckedPtrCount");
125 }
126 
// Returns true if Name is exactly one of the Ref/RefPtr class names listed
// below.
bool isRefType(const std::string &Name) {
  const char *const RefTypeNames[] = {
      "Ref",
      "RefAllowingPartiallyDestroyed",
      "RefPtr",
      "RefPtrAllowingPartiallyDestroyed",
  };
  for (const char *Candidate : RefTypeNames) {
    if (Name == Candidate)
      return true;
  }
  return false;
}
131 
// Returns true if Name is exactly "RetainPtr" or "RetainPtrArc".
bool isRetainPtr(const std::string &Name) {
  if (Name == "RetainPtr")
    return true;
  return Name == "RetainPtrArc";
}
135 
// Returns true if Name is exactly "CheckedPtr" or "CheckedRef".
bool isCheckedPtr(const std::string &Name) {
  if (Name == "CheckedPtr")
    return true;
  return Name == "CheckedRef";
}
139 
isSmartPtrClass(const std::string & Name)140 bool isSmartPtrClass(const std::string &Name) {
141   return isRefType(Name) || isCheckedPtr(Name) || isRetainPtr(Name) ||
142          Name == "WeakPtr" || Name == "WeakPtrFactory" ||
143          Name == "WeakPtrFactoryWithBitField" || Name == "WeakPtrImplBase" ||
144          Name == "WeakPtrImplBaseSingleThread" || Name == "ThreadSafeWeakPtr" ||
145          Name == "ThreadSafeWeakOrStrongPtr" ||
146          Name == "ThreadSafeWeakPtrControlBlock" ||
147          Name == "ThreadSafeRefCountedAndCanMakeThreadSafeWeakPtr";
148 }
149 
isCtorOfRefCounted(const clang::FunctionDecl * F)150 bool isCtorOfRefCounted(const clang::FunctionDecl *F) {
151   assert(F);
152   const std::string &FunctionName = safeGetName(F);
153 
154   return isRefType(FunctionName) || FunctionName == "adoptRef" ||
155          FunctionName == "UniqueRef" || FunctionName == "makeUniqueRef" ||
156          FunctionName == "makeUniqueRefWithoutFastMallocCheck"
157 
158          || FunctionName == "String" || FunctionName == "AtomString" ||
159          FunctionName == "UniqueString"
160          // FIXME: Implement as attribute.
161          || FunctionName == "Identifier";
162 }
163 
isCtorOfCheckedPtr(const clang::FunctionDecl * F)164 bool isCtorOfCheckedPtr(const clang::FunctionDecl *F) {
165   assert(F);
166   return isCheckedPtr(safeGetName(F));
167 }
168 
isCtorOfRetainPtr(const clang::FunctionDecl * F)169 bool isCtorOfRetainPtr(const clang::FunctionDecl *F) {
170   const std::string &FunctionName = safeGetName(F);
171   return FunctionName == "RetainPtr" || FunctionName == "adoptNS" ||
172          FunctionName == "adoptCF" || FunctionName == "retainPtr" ||
173          FunctionName == "RetainPtrArc" || FunctionName == "adoptNSArc";
174 }
175 
isCtorOfSafePtr(const clang::FunctionDecl * F)176 bool isCtorOfSafePtr(const clang::FunctionDecl *F) {
177   return isCtorOfRefCounted(F) || isCtorOfCheckedPtr(F) || isCtorOfRetainPtr(F);
178 }
179 
180 template <typename Predicate>
isPtrOfType(const clang::QualType T,Predicate Pred)181 static bool isPtrOfType(const clang::QualType T, Predicate Pred) {
182   QualType type = T;
183   while (!type.isNull()) {
184     if (auto *elaboratedT = type->getAs<ElaboratedType>()) {
185       type = elaboratedT->desugar();
186       continue;
187     }
188     if (auto *SpecialT = type->getAs<TemplateSpecializationType>()) {
189       auto *Decl = SpecialT->getTemplateName().getAsTemplateDecl();
190       return Decl && Pred(Decl->getNameAsString());
191     } else if (auto *DTS = type->getAs<DeducedTemplateSpecializationType>()) {
192       auto *Decl = DTS->getTemplateName().getAsTemplateDecl();
193       return Decl && Pred(Decl->getNameAsString());
194     } else
195       break;
196   }
197   return false;
198 }
199 
isRefOrCheckedPtrType(const clang::QualType T)200 bool isRefOrCheckedPtrType(const clang::QualType T) {
201   return isPtrOfType(
202       T, [](auto Name) { return isRefType(Name) || isCheckedPtr(Name); });
203 }
204 
isRetainPtrType(const clang::QualType T)205 bool isRetainPtrType(const clang::QualType T) {
206   return isPtrOfType(T, [](auto Name) { return isRetainPtr(Name); });
207 }
208 
isOwnerPtrType(const clang::QualType T)209 bool isOwnerPtrType(const clang::QualType T) {
210   return isPtrOfType(T, [](auto Name) {
211     return isRefType(Name) || isCheckedPtr(Name) || Name == "unique_ptr" ||
212            Name == "UniqueRef" || Name == "LazyUniqueRef";
213   });
214 }
215 
isUncounted(const QualType T)216 std::optional<bool> isUncounted(const QualType T) {
217   if (auto *Subst = dyn_cast<SubstTemplateTypeParmType>(T)) {
218     if (auto *Decl = Subst->getAssociatedDecl()) {
219       if (isRefType(safeGetName(Decl)))
220         return false;
221     }
222   }
223   return isUncounted(T->getAsCXXRecordDecl());
224 }
225 
isUnchecked(const QualType T)226 std::optional<bool> isUnchecked(const QualType T) {
227   if (auto *Subst = dyn_cast<SubstTemplateTypeParmType>(T)) {
228     if (auto *Decl = Subst->getAssociatedDecl()) {
229       if (isCheckedPtr(safeGetName(Decl)))
230         return false;
231     }
232   }
233   return isUnchecked(T->getAsCXXRecordDecl());
234 }
235 
visitTranslationUnitDecl(const TranslationUnitDecl * TUD)236 void RetainTypeChecker::visitTranslationUnitDecl(
237     const TranslationUnitDecl *TUD) {
238   IsARCEnabled = TUD->getLangOpts().ObjCAutoRefCount;
239   DefaultSynthProperties = TUD->getLangOpts().ObjCDefaultSynthProperties;
240 }
241 
visitTypedef(const TypedefDecl * TD)242 void RetainTypeChecker::visitTypedef(const TypedefDecl *TD) {
243   auto QT = TD->getUnderlyingType();
244   if (!QT->isPointerType())
245     return;
246 
247   auto PointeeQT = QT->getPointeeType();
248   const RecordType *RT = PointeeQT->getAs<RecordType>();
249   if (!RT) {
250     if (TD->hasAttr<ObjCBridgeAttr>() || TD->hasAttr<ObjCBridgeMutableAttr>()) {
251       if (auto *Type = TD->getTypeForDecl())
252         RecordlessTypes.insert(Type);
253     }
254     return;
255   }
256 
257   for (auto *Redecl : RT->getDecl()->getMostRecentDecl()->redecls()) {
258     if (Redecl->getAttr<ObjCBridgeAttr>() ||
259         Redecl->getAttr<ObjCBridgeMutableAttr>()) {
260       CFPointees.insert(RT);
261       return;
262     }
263   }
264 }
265 
isUnretained(const QualType QT,bool ignoreARC)266 bool RetainTypeChecker::isUnretained(const QualType QT, bool ignoreARC) {
267   if (ento::cocoa::isCocoaObjectRef(QT) && (!IsARCEnabled || ignoreARC))
268     return true;
269   auto CanonicalType = QT.getCanonicalType();
270   auto PointeeType = CanonicalType->getPointeeType();
271   auto *RT = dyn_cast_or_null<RecordType>(PointeeType.getTypePtrOrNull());
272   if (!RT) {
273     auto *Type = QT.getTypePtrOrNull();
274     while (Type) {
275       if (RecordlessTypes.contains(Type))
276         return true;
277       auto *ET = dyn_cast_or_null<ElaboratedType>(Type);
278       if (!ET)
279         break;
280       Type = ET->desugar().getTypePtrOrNull();
281     }
282   }
283   return RT && CFPointees.contains(RT);
284 }
285 
isUnretained(const QualType T,bool IsARCEnabled)286 std::optional<bool> isUnretained(const QualType T, bool IsARCEnabled) {
287   if (auto *Subst = dyn_cast<SubstTemplateTypeParmType>(T)) {
288     if (auto *Decl = Subst->getAssociatedDecl()) {
289       if (isRetainPtr(safeGetName(Decl)))
290         return false;
291     }
292   }
293   if ((ento::cocoa::isCocoaObjectRef(T) && !IsARCEnabled) ||
294       ento::coreFoundation::isCFObjectRef(T))
295     return true;
296 
297   // RetainPtr strips typedef for CF*Ref. Manually check for struct __CF* types.
298   auto CanonicalType = T.getCanonicalType();
299   auto *Type = CanonicalType.getTypePtrOrNull();
300   if (!Type)
301     return false;
302   auto Pointee = Type->getPointeeType();
303   auto *PointeeType = Pointee.getTypePtrOrNull();
304   if (!PointeeType)
305     return false;
306   auto *Record = PointeeType->getAsStructureType();
307   if (!Record)
308     return false;
309   auto *Decl = Record->getDecl();
310   if (!Decl)
311     return false;
312   auto TypeName = Decl->getName();
313   return TypeName.starts_with("__CF") || TypeName.starts_with("__CG") ||
314          TypeName.starts_with("__CM");
315 }
316 
isUncounted(const CXXRecordDecl * Class)317 std::optional<bool> isUncounted(const CXXRecordDecl* Class)
318 {
319   // Keep isRefCounted first as it's cheaper.
320   if (!Class || isRefCounted(Class))
321     return false;
322 
323   std::optional<bool> IsRefCountable = isRefCountable(Class);
324   if (!IsRefCountable)
325     return std::nullopt;
326 
327   return (*IsRefCountable);
328 }
329 
isUnchecked(const CXXRecordDecl * Class)330 std::optional<bool> isUnchecked(const CXXRecordDecl *Class) {
331   if (!Class || isCheckedPtr(Class))
332     return false; // Cheaper than below
333   return isCheckedPtrCapable(Class);
334 }
335 
isUncountedPtr(const QualType T)336 std::optional<bool> isUncountedPtr(const QualType T) {
337   if (T->isPointerType() || T->isReferenceType()) {
338     if (auto *CXXRD = T->getPointeeCXXRecordDecl())
339       return isUncounted(CXXRD);
340   }
341   return false;
342 }
343 
isUncheckedPtr(const QualType T)344 std::optional<bool> isUncheckedPtr(const QualType T) {
345   if (T->isPointerType() || T->isReferenceType()) {
346     if (auto *CXXRD = T->getPointeeCXXRecordDecl())
347       return isUnchecked(CXXRD);
348   }
349   return false;
350 }
351 
isUnsafePtr(const QualType T,bool IsArcEnabled)352 std::optional<bool> isUnsafePtr(const QualType T, bool IsArcEnabled) {
353   if (T->isPointerType() || T->isReferenceType()) {
354     if (auto *CXXRD = T->getPointeeCXXRecordDecl()) {
355       auto isUncountedPtr = isUncounted(CXXRD);
356       auto isUncheckedPtr = isUnchecked(CXXRD);
357       auto isUnretainedPtr = isUnretained(T, IsArcEnabled);
358       std::optional<bool> result;
359       if (isUncountedPtr)
360         result = *isUncountedPtr;
361       if (isUncheckedPtr)
362         result = result ? *result || *isUncheckedPtr : *isUncheckedPtr;
363       if (isUnretainedPtr)
364         result = result ? *result || *isUnretainedPtr : *isUnretainedPtr;
365       return result;
366     }
367   }
368   return false;
369 }
370 
isGetterOfSafePtr(const CXXMethodDecl * M)371 std::optional<bool> isGetterOfSafePtr(const CXXMethodDecl *M) {
372   assert(M);
373 
374   if (isa<CXXMethodDecl>(M)) {
375     const CXXRecordDecl *calleeMethodsClass = M->getParent();
376     auto className = safeGetName(calleeMethodsClass);
377     auto method = safeGetName(M);
378 
379     if (isCheckedPtr(className) && (method == "get" || method == "ptr"))
380       return true;
381 
382     if ((isRefType(className) && (method == "get" || method == "ptr")) ||
383         ((className == "String" || className == "AtomString" ||
384           className == "AtomStringImpl" || className == "UniqueString" ||
385           className == "UniqueStringImpl" || className == "Identifier") &&
386          method == "impl"))
387       return true;
388 
389     if (isRetainPtr(className) && method == "get")
390       return true;
391 
392     // Ref<T> -> T conversion
393     // FIXME: Currently allowing any Ref<T> -> whatever cast.
394     if (isRefType(className)) {
395       if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(M)) {
396         auto QT = maybeRefToRawOperator->getConversionType();
397         auto *T = QT.getTypePtrOrNull();
398         return T && (T->isPointerType() || T->isReferenceType());
399       }
400     }
401 
402     if (isCheckedPtr(className)) {
403       if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(M)) {
404         auto QT = maybeRefToRawOperator->getConversionType();
405         auto *T = QT.getTypePtrOrNull();
406         return T && (T->isPointerType() || T->isReferenceType());
407       }
408     }
409 
410     if (isRetainPtr(className)) {
411       if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(M)) {
412         auto QT = maybeRefToRawOperator->getConversionType();
413         auto *T = QT.getTypePtrOrNull();
414         return T && (T->isPointerType() || T->isReferenceType() ||
415                      T->isObjCObjectPointerType());
416       }
417     }
418   }
419   return false;
420 }
421 
isRefCounted(const CXXRecordDecl * R)422 bool isRefCounted(const CXXRecordDecl *R) {
423   assert(R);
424   if (auto *TmplR = R->getTemplateInstantiationPattern()) {
425     // FIXME: String/AtomString/UniqueString
426     const auto &ClassName = safeGetName(TmplR);
427     return isRefType(ClassName);
428   }
429   return false;
430 }
431 
isCheckedPtr(const CXXRecordDecl * R)432 bool isCheckedPtr(const CXXRecordDecl *R) {
433   assert(R);
434   if (auto *TmplR = R->getTemplateInstantiationPattern()) {
435     const auto &ClassName = safeGetName(TmplR);
436     return isCheckedPtr(ClassName);
437   }
438   return false;
439 }
440 
isRetainPtr(const CXXRecordDecl * R)441 bool isRetainPtr(const CXXRecordDecl *R) {
442   assert(R);
443   if (auto *TmplR = R->getTemplateInstantiationPattern())
444     return isRetainPtr(safeGetName(TmplR));
445   return false;
446 }
447 
isSmartPtr(const CXXRecordDecl * R)448 bool isSmartPtr(const CXXRecordDecl *R) {
449   assert(R);
450   if (auto *TmplR = R->getTemplateInstantiationPattern())
451     return isSmartPtrClass(safeGetName(TmplR));
452   return false;
453 }
454 
isPtrConversion(const FunctionDecl * F)455 bool isPtrConversion(const FunctionDecl *F) {
456   assert(F);
457   if (isCtorOfRefCounted(F))
458     return true;
459 
460   // FIXME: check # of params == 1
461   const auto FunctionName = safeGetName(F);
462   if (FunctionName == "getPtr" || FunctionName == "WeakPtr" ||
463       FunctionName == "dynamicDowncast" || FunctionName == "downcast" ||
464       FunctionName == "checkedDowncast" || FunctionName == "bit_cast" ||
465       FunctionName == "uncheckedDowncast" || FunctionName == "bitwise_cast" ||
466       FunctionName == "bridge_cast" || FunctionName == "bridge_id_cast" ||
467       FunctionName == "dynamic_cf_cast" || FunctionName == "checked_cf_cast" ||
468       FunctionName == "dynamic_objc_cast" ||
469       FunctionName == "checked_objc_cast")
470     return true;
471 
472   auto ReturnType = F->getReturnType();
473   if (auto *Type = ReturnType.getTypePtrOrNull()) {
474     if (auto *AttrType = dyn_cast<AttributedType>(Type)) {
475       if (auto *Attr = AttrType->getAttr()) {
476         if (auto *AnnotateType = dyn_cast<AnnotateTypeAttr>(Attr)) {
477           if (AnnotateType->getAnnotation() == "webkit.pointerconversion")
478             return true;
479         }
480       }
481     }
482   }
483 
484   return false;
485 }
486 
isTrivialBuiltinFunction(const FunctionDecl * F)487 bool isTrivialBuiltinFunction(const FunctionDecl *F) {
488   if (!F || !F->getDeclName().isIdentifier())
489     return false;
490   auto Name = F->getName();
491   return Name.starts_with("__builtin") || Name == "__libcpp_verbose_abort" ||
492          Name.starts_with("os_log") || Name.starts_with("_os_log");
493 }
494 
isSingleton(const FunctionDecl * F)495 bool isSingleton(const FunctionDecl *F) {
496   assert(F);
497   // FIXME: check # of params == 1
498   if (auto *MethodDecl = dyn_cast<CXXMethodDecl>(F)) {
499     if (!MethodDecl->isStatic())
500       return false;
501   }
502   const auto &NameStr = safeGetName(F);
503   StringRef Name = NameStr; // FIXME: Make safeGetName return StringRef.
504   return Name == "singleton" || Name.ends_with("Singleton");
505 }
506 
507 // We only care about statements so let's use the simple
508 // (non-recursive) visitor.
509 class TrivialFunctionAnalysisVisitor
510     : public ConstStmtVisitor<TrivialFunctionAnalysisVisitor, bool> {
511 
512   // Returns false if at least one child is non-trivial.
VisitChildren(const Stmt * S)513   bool VisitChildren(const Stmt *S) {
514     for (const Stmt *Child : S->children()) {
515       if (Child && !Visit(Child))
516         return false;
517     }
518 
519     return true;
520   }
521 
522   template <typename StmtOrDecl, typename CheckFunction>
WithCachedResult(const StmtOrDecl * S,CheckFunction Function)523   bool WithCachedResult(const StmtOrDecl *S, CheckFunction Function) {
524     auto CacheIt = Cache.find(S);
525     if (CacheIt != Cache.end())
526       return CacheIt->second;
527 
528     // Treat a recursive statement to be trivial until proven otherwise.
529     auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(S, true));
530     if (!IsNew)
531       return RecursiveIt->second;
532 
533     bool Result = Function();
534 
535     if (!Result) {
536       for (auto &It : RecursiveFn)
537         It.second = false;
538     }
539     RecursiveIt = RecursiveFn.find(S);
540     assert(RecursiveIt != RecursiveFn.end());
541     Result = RecursiveIt->second;
542     RecursiveFn.erase(RecursiveIt);
543     Cache[S] = Result;
544 
545     return Result;
546   }
547 
548 public:
549   using CacheTy = TrivialFunctionAnalysis::CacheTy;
550 
TrivialFunctionAnalysisVisitor(CacheTy & Cache)551   TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}
552 
IsFunctionTrivial(const Decl * D)553   bool IsFunctionTrivial(const Decl *D) {
554     if (auto *FnDecl = dyn_cast<FunctionDecl>(D)) {
555       if (FnDecl->isVirtualAsWritten())
556         return false;
557     }
558     return WithCachedResult(D, [&]() {
559       if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
560         for (auto *CtorInit : CtorDecl->inits()) {
561           if (!Visit(CtorInit->getInit()))
562             return false;
563         }
564       }
565       const Stmt *Body = D->getBody();
566       if (!Body)
567         return false;
568       return Visit(Body);
569     });
570   }
571 
VisitStmt(const Stmt * S)572   bool VisitStmt(const Stmt *S) {
573     // All statements are non-trivial unless overriden later.
574     // Don't even recurse into children by default.
575     return false;
576   }
577 
VisitAttributedStmt(const AttributedStmt * AS)578   bool VisitAttributedStmt(const AttributedStmt *AS) {
579     // Ignore attributes.
580     return Visit(AS->getSubStmt());
581   }
582 
VisitCompoundStmt(const CompoundStmt * CS)583   bool VisitCompoundStmt(const CompoundStmt *CS) {
584     // A compound statement is allowed as long each individual sub-statement
585     // is trivial.
586     return WithCachedResult(CS, [&]() { return VisitChildren(CS); });
587   }
588 
VisitReturnStmt(const ReturnStmt * RS)589   bool VisitReturnStmt(const ReturnStmt *RS) {
590     // A return statement is allowed as long as the return value is trivial.
591     if (auto *RV = RS->getRetValue())
592       return Visit(RV);
593     return true;
594   }
595 
VisitDeclStmt(const DeclStmt * DS)596   bool VisitDeclStmt(const DeclStmt *DS) { return VisitChildren(DS); }
VisitDoStmt(const DoStmt * DS)597   bool VisitDoStmt(const DoStmt *DS) { return VisitChildren(DS); }
VisitIfStmt(const IfStmt * IS)598   bool VisitIfStmt(const IfStmt *IS) {
599     return WithCachedResult(IS, [&]() { return VisitChildren(IS); });
600   }
VisitForStmt(const ForStmt * FS)601   bool VisitForStmt(const ForStmt *FS) {
602     return WithCachedResult(FS, [&]() { return VisitChildren(FS); });
603   }
VisitCXXForRangeStmt(const CXXForRangeStmt * FS)604   bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) {
605     return WithCachedResult(FS, [&]() { return VisitChildren(FS); });
606   }
VisitWhileStmt(const WhileStmt * WS)607   bool VisitWhileStmt(const WhileStmt *WS) {
608     return WithCachedResult(WS, [&]() { return VisitChildren(WS); });
609   }
VisitSwitchStmt(const SwitchStmt * SS)610   bool VisitSwitchStmt(const SwitchStmt *SS) { return VisitChildren(SS); }
VisitCaseStmt(const CaseStmt * CS)611   bool VisitCaseStmt(const CaseStmt *CS) { return VisitChildren(CS); }
VisitDefaultStmt(const DefaultStmt * DS)612   bool VisitDefaultStmt(const DefaultStmt *DS) { return VisitChildren(DS); }
613 
614   // break, continue, goto, and label statements are always trivial.
VisitBreakStmt(const BreakStmt *)615   bool VisitBreakStmt(const BreakStmt *) { return true; }
VisitContinueStmt(const ContinueStmt *)616   bool VisitContinueStmt(const ContinueStmt *) { return true; }
VisitGotoStmt(const GotoStmt *)617   bool VisitGotoStmt(const GotoStmt *) { return true; }
VisitLabelStmt(const LabelStmt *)618   bool VisitLabelStmt(const LabelStmt *) { return true; }
619 
VisitUnaryOperator(const UnaryOperator * UO)620   bool VisitUnaryOperator(const UnaryOperator *UO) {
621     // Unary operators are trivial if its operand is trivial except co_await.
622     return UO->getOpcode() != UO_Coawait && Visit(UO->getSubExpr());
623   }
624 
VisitBinaryOperator(const BinaryOperator * BO)625   bool VisitBinaryOperator(const BinaryOperator *BO) {
626     // Binary operators are trivial if their operands are trivial.
627     return Visit(BO->getLHS()) && Visit(BO->getRHS());
628   }
629 
VisitCompoundAssignOperator(const CompoundAssignOperator * CAO)630   bool VisitCompoundAssignOperator(const CompoundAssignOperator *CAO) {
631     // Compound assignment operator such as |= is trivial if its
632     // subexpresssions are trivial.
633     return VisitChildren(CAO);
634   }
635 
VisitArraySubscriptExpr(const ArraySubscriptExpr * ASE)636   bool VisitArraySubscriptExpr(const ArraySubscriptExpr *ASE) {
637     return VisitChildren(ASE);
638   }
639 
VisitConditionalOperator(const ConditionalOperator * CO)640   bool VisitConditionalOperator(const ConditionalOperator *CO) {
641     // Ternary operators are trivial if their conditions & values are trivial.
642     return VisitChildren(CO);
643   }
644 
VisitAtomicExpr(const AtomicExpr * E)645   bool VisitAtomicExpr(const AtomicExpr *E) { return VisitChildren(E); }
646 
VisitStaticAssertDecl(const StaticAssertDecl * SAD)647   bool VisitStaticAssertDecl(const StaticAssertDecl *SAD) {
648     // Any static_assert is considered trivial.
649     return true;
650   }
651 
VisitCallExpr(const CallExpr * CE)652   bool VisitCallExpr(const CallExpr *CE) {
653     if (!checkArguments(CE))
654       return false;
655 
656     auto *Callee = CE->getDirectCallee();
657     if (!Callee)
658       return false;
659 
660     if (isPtrConversion(Callee))
661       return true;
662 
663     const auto &Name = safeGetName(Callee);
664 
665     if (Callee->isInStdNamespace() &&
666         (Name == "addressof" || Name == "forward" || Name == "move"))
667       return true;
668 
669     if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" ||
670         Name == "WTFReportBacktrace" ||
671         Name == "WTFCrashWithSecurityImplication" || Name == "WTFCrash" ||
672         Name == "WTFReportAssertionFailure" || Name == "isMainThread" ||
673         Name == "isMainThreadOrGCThread" || Name == "isMainRunLoop" ||
674         Name == "isWebThread" || Name == "isUIThread" ||
675         Name == "mayBeGCThread" || Name == "compilerFenceForCrash" ||
676         isTrivialBuiltinFunction(Callee))
677       return true;
678 
679     return IsFunctionTrivial(Callee);
680   }
681 
682   bool
VisitSubstNonTypeTemplateParmExpr(const SubstNonTypeTemplateParmExpr * E)683   VisitSubstNonTypeTemplateParmExpr(const SubstNonTypeTemplateParmExpr *E) {
684     // Non-type template paramter is compile time constant and trivial.
685     return true;
686   }
687 
VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr * E)688   bool VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *E) {
689     return VisitChildren(E);
690   }
691 
VisitPredefinedExpr(const PredefinedExpr * E)692   bool VisitPredefinedExpr(const PredefinedExpr *E) {
693     // A predefined identifier such as "func" is considered trivial.
694     return true;
695   }
696 
VisitOffsetOfExpr(const OffsetOfExpr * OE)697   bool VisitOffsetOfExpr(const OffsetOfExpr *OE) {
698     // offsetof(T, D) is considered trivial.
699     return true;
700   }
701 
VisitCXXMemberCallExpr(const CXXMemberCallExpr * MCE)702   bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) {
703     if (!checkArguments(MCE))
704       return false;
705 
706     bool TrivialThis = Visit(MCE->getImplicitObjectArgument());
707     if (!TrivialThis)
708       return false;
709 
710     auto *Callee = MCE->getMethodDecl();
711     if (!Callee)
712       return false;
713 
714     auto Name = safeGetName(Callee);
715     if (Name == "ref" || Name == "incrementCheckedPtrCount")
716       return true;
717 
718     std::optional<bool> IsGetterOfRefCounted = isGetterOfSafePtr(Callee);
719     if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
720       return true;
721 
722     // Recursively descend into the callee to confirm that it's trivial as well.
723     return IsFunctionTrivial(Callee);
724   }
725 
VisitCXXOperatorCallExpr(const CXXOperatorCallExpr * OCE)726   bool VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) {
727     if (!checkArguments(OCE))
728       return false;
729     auto *Callee = OCE->getCalleeDecl();
730     if (!Callee)
731       return false;
732     // Recursively descend into the callee to confirm that it's trivial as well.
733     return IsFunctionTrivial(Callee);
734   }
735 
VisitCXXDefaultArgExpr(const CXXDefaultArgExpr * E)736   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
737     if (auto *Expr = E->getExpr()) {
738       if (!Visit(Expr))
739         return false;
740     }
741     return true;
742   }
743 
checkArguments(const CallExpr * CE)744   bool checkArguments(const CallExpr *CE) {
745     for (const Expr *Arg : CE->arguments()) {
746       if (Arg && !Visit(Arg))
747         return false;
748     }
749     return true;
750   }
751 
VisitCXXConstructExpr(const CXXConstructExpr * CE)752   bool VisitCXXConstructExpr(const CXXConstructExpr *CE) {
753     for (const Expr *Arg : CE->arguments()) {
754       if (Arg && !Visit(Arg))
755         return false;
756     }
757 
758     // Recursively descend into the callee to confirm that it's trivial.
759     return IsFunctionTrivial(CE->getConstructor());
760   }
761 
VisitCXXInheritedCtorInitExpr(const CXXInheritedCtorInitExpr * E)762   bool VisitCXXInheritedCtorInitExpr(const CXXInheritedCtorInitExpr *E) {
763     return IsFunctionTrivial(E->getConstructor());
764   }
765 
VisitCXXNewExpr(const CXXNewExpr * NE)766   bool VisitCXXNewExpr(const CXXNewExpr *NE) { return VisitChildren(NE); }
767 
VisitImplicitCastExpr(const ImplicitCastExpr * ICE)768   bool VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
769     return Visit(ICE->getSubExpr());
770   }
771 
VisitExplicitCastExpr(const ExplicitCastExpr * ECE)772   bool VisitExplicitCastExpr(const ExplicitCastExpr *ECE) {
773     return Visit(ECE->getSubExpr());
774   }
775 
VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr * VMT)776   bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *VMT) {
777     return Visit(VMT->getSubExpr());
778   }
779 
VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr * BTE)780   bool VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *BTE) {
781     if (auto *Temp = BTE->getTemporary()) {
782       if (!TrivialFunctionAnalysis::isTrivialImpl(Temp->getDestructor(), Cache))
783         return false;
784     }
785     return Visit(BTE->getSubExpr());
786   }
787 
VisitArrayInitLoopExpr(const ArrayInitLoopExpr * AILE)788   bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *AILE) {
789     return Visit(AILE->getCommonExpr()) && Visit(AILE->getSubExpr());
790   }
791 
VisitArrayInitIndexExpr(const ArrayInitIndexExpr * AIIE)792   bool VisitArrayInitIndexExpr(const ArrayInitIndexExpr *AIIE) {
793     return true; // The current array index in VisitArrayInitLoopExpr is always
794                  // trivial.
795   }
796 
VisitOpaqueValueExpr(const OpaqueValueExpr * OVE)797   bool VisitOpaqueValueExpr(const OpaqueValueExpr *OVE) {
798     return Visit(OVE->getSourceExpr());
799   }
800 
VisitExprWithCleanups(const ExprWithCleanups * EWC)801   bool VisitExprWithCleanups(const ExprWithCleanups *EWC) {
802     return Visit(EWC->getSubExpr());
803   }
804 
VisitParenExpr(const ParenExpr * PE)805   bool VisitParenExpr(const ParenExpr *PE) { return Visit(PE->getSubExpr()); }
806 
VisitInitListExpr(const InitListExpr * ILE)807   bool VisitInitListExpr(const InitListExpr *ILE) {
808     for (const Expr *Child : ILE->inits()) {
809       if (Child && !Visit(Child))
810         return false;
811     }
812     return true;
813   }
814 
VisitMemberExpr(const MemberExpr * ME)815   bool VisitMemberExpr(const MemberExpr *ME) {
816     // Field access is allowed but the base pointer may itself be non-trivial.
817     return Visit(ME->getBase());
818   }
819 
VisitCXXThisExpr(const CXXThisExpr * CTE)820   bool VisitCXXThisExpr(const CXXThisExpr *CTE) {
821     // The expression 'this' is always trivial, be it explicit or implicit.
822     return true;
823   }
824 
VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr * E)825   bool VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr *E) {
826     // nullptr is trivial.
827     return true;
828   }
829 
VisitDeclRefExpr(const DeclRefExpr * DRE)830   bool VisitDeclRefExpr(const DeclRefExpr *DRE) {
831     // The use of a variable is trivial.
832     return true;
833   }
834 
835   // Constant literal expressions are always trivial
VisitIntegerLiteral(const IntegerLiteral * E)836   bool VisitIntegerLiteral(const IntegerLiteral *E) { return true; }
VisitFloatingLiteral(const FloatingLiteral * E)837   bool VisitFloatingLiteral(const FloatingLiteral *E) { return true; }
VisitFixedPointLiteral(const FixedPointLiteral * E)838   bool VisitFixedPointLiteral(const FixedPointLiteral *E) { return true; }
VisitCharacterLiteral(const CharacterLiteral * E)839   bool VisitCharacterLiteral(const CharacterLiteral *E) { return true; }
VisitStringLiteral(const StringLiteral * E)840   bool VisitStringLiteral(const StringLiteral *E) { return true; }
VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr * E)841   bool VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *E) { return true; }
842 
VisitConstantExpr(const ConstantExpr * CE)843   bool VisitConstantExpr(const ConstantExpr *CE) {
844     // Constant expressions are trivial.
845     return true;
846   }
847 
VisitImplicitValueInitExpr(const ImplicitValueInitExpr * IVIE)848   bool VisitImplicitValueInitExpr(const ImplicitValueInitExpr *IVIE) {
849     // An implicit value initialization is trvial.
850     return true;
851   }
852 
853 private:
854   CacheTy &Cache;
855   CacheTy RecursiveFn;
856 };
857 
isTrivialImpl(const Decl * D,TrivialFunctionAnalysis::CacheTy & Cache)858 bool TrivialFunctionAnalysis::isTrivialImpl(
859     const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
860   TrivialFunctionAnalysisVisitor V(Cache);
861   return V.IsFunctionTrivial(D);
862 }
863 
isTrivialImpl(const Stmt * S,TrivialFunctionAnalysis::CacheTy & Cache)864 bool TrivialFunctionAnalysis::isTrivialImpl(
865     const Stmt *S, TrivialFunctionAnalysis::CacheTy &Cache) {
866   TrivialFunctionAnalysisVisitor V(Cache);
867   bool Result = V.Visit(S);
868   assert(Cache.contains(S) && "Top-level statement not properly cached!");
869   return Result;
870 }
871 
872 } // namespace clang
873