xref: /freebsd/contrib/llvm-project/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp (revision a90b9d0159070121c221b966469c3e36d912bf82)
1 //===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/AST/Type.h"
10 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
11 #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
12 #include "clang/StaticAnalyzer/Core/Checker.h"
13 #include "clang/StaticAnalyzer/Core/CheckerManager.h"
14 #include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
15 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
16 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
18 #include "llvm/ADT/FoldingSet.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/Casting.h"
21 #include <optional>
22 #include <string_view>
23 
24 #include "TaggedUnionModeling.h"
25 
26 using namespace clang;
27 using namespace ento;
28 using namespace tagged_union_modeling;
29 
// Program state trait mapping the memory region of a tracked std::variant
// object to the type of the alternative it is currently known to hold. Entries
// are written when a variant is constructed or assigned and read back when
// std::get is checked.
REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
31 
32 namespace clang::ento::tagged_union_modeling {
33 
34 const CXXConstructorDecl *
35 getConstructorDeclarationForCall(const CallEvent &Call) {
36   const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
37   if (!ConstructorCall)
38     return nullptr;
39 
40   return ConstructorCall->getDecl();
41 }
42 
43 bool isCopyConstructorCall(const CallEvent &Call) {
44   if (const CXXConstructorDecl *ConstructorDecl =
45           getConstructorDeclarationForCall(Call))
46     return ConstructorDecl->isCopyConstructor();
47   return false;
48 }
49 
50 bool isCopyAssignmentCall(const CallEvent &Call) {
51   const Decl *CopyAssignmentDecl = Call.getDecl();
52 
53   if (const auto *AsMethodDecl =
54           dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
55     return AsMethodDecl->isCopyAssignmentOperator();
56   return false;
57 }
58 
59 bool isMoveConstructorCall(const CallEvent &Call) {
60   const CXXConstructorDecl *ConstructorDecl =
61       getConstructorDeclarationForCall(Call);
62   if (!ConstructorDecl)
63     return false;
64 
65   return ConstructorDecl->isMoveConstructor();
66 }
67 
68 bool isMoveAssignmentCall(const CallEvent &Call) {
69   const Decl *CopyAssignmentDecl = Call.getDecl();
70 
71   const auto *AsMethodDecl =
72       dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
73   if (!AsMethodDecl)
74     return false;
75 
76   return AsMethodDecl->isMoveAssignmentOperator();
77 }
78 
79 bool isStdType(const Type *Type, llvm::StringRef TypeName) {
80   auto *Decl = Type->getAsRecordDecl();
81   if (!Decl)
82     return false;
83   return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
84 }
85 
86 bool isStdVariant(const Type *Type) {
87   return isStdType(Type, llvm::StringLiteral("variant"));
88 }
89 
90 } // end of namespace clang::ento::tagged_union_modeling
91 
92 static std::optional<ArrayRef<TemplateArgument>>
93 getTemplateArgsFromVariant(const Type *VariantType) {
94   const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
95   if (!TempSpecType)
96     return {};
97 
98   return TempSpecType->template_arguments();
99 }
100 
101 static std::optional<QualType>
102 getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
103   std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
104       getTemplateArgsFromVariant(varType);
105   if (!VariantTemplates)
106     return {};
107 
108   return (*VariantTemplates)[i].getAsType();
109 }
110 
/// Returns true if \p a is one of the vowel letters a, e, i, o, u, in either
/// lower or upper case. It chooses the indefinite article in diagnostics.
/// Type names such as 'Apple' or 'Int' start with an upper-case vowel and
/// must also get "an".
static bool isVowel(char a) {
  switch (a) {
  case 'a':
  case 'e':
  case 'i':
  case 'o':
  case 'u':
  case 'A':
  case 'E':
  case 'I':
  case 'O':
  case 'U':
    return true;
  default:
    return false;
  }
}
123 
124 static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
125   if (isVowel(a))
126     return "an";
127   return "a";
128 }
129 
130 class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
131   // Call descriptors to find relevant calls
132   CallDescription VariantConstructor{{"std", "variant", "variant"}};
133   CallDescription VariantAssignmentOperator{{"std", "variant", "operator="}};
134   CallDescription StdGet{{"std", "get"}, 1, 1};
135 
136   BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
137 
138 public:
139   ProgramStateRef checkRegionChanges(ProgramStateRef State,
140                                      const InvalidatedSymbols *,
141                                      ArrayRef<const MemRegion *>,
142                                      ArrayRef<const MemRegion *> Regions,
143                                      const LocationContext *,
144                                      const CallEvent *Call) const {
145     if (!Call)
146       return State;
147 
148     return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
149         *Call, State, Regions);
150   }
151 
152   bool evalCall(const CallEvent &Call, CheckerContext &C) const {
153     // Check if the call was not made from a system header. If it was then
154     // we do an early return because it is part of the implementation.
155     if (Call.isCalledFromSystemHeader())
156       return false;
157 
158     if (StdGet.matches(Call))
159       return handleStdGetCall(Call, C);
160 
161     // First check if a constructor call is happening. If it is a
162     // constructor call, check if it is an std::variant constructor call.
163     bool IsVariantConstructor =
164         isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
165     bool IsVariantAssignmentOperatorCall =
166         isa<CXXMemberOperatorCall>(Call) &&
167         VariantAssignmentOperator.matches(Call);
168 
169     if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
170       if (Call.getNumArgs() == 0 && IsVariantConstructor) {
171         handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
172         return true;
173       }
174 
175       // FIXME Later this checker should be extended to handle constructors
176       // with multiple arguments.
177       if (Call.getNumArgs() != 1)
178         return false;
179 
180       SVal ThisSVal;
181       if (IsVariantConstructor) {
182         const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
183         ThisSVal = AsConstructorCall.getCXXThisVal();
184       } else if (IsVariantAssignmentOperatorCall) {
185         const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
186         ThisSVal = AsMemberOpCall.getCXXThisVal();
187       } else {
188         return false;
189       }
190 
191       handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
192       return true;
193     }
194     return false;
195   }
196 
197 private:
198   // The default constructed std::variant must be handled separately
199   // by default the std::variant is going to hold a default constructed instance
200   // of the first type of the possible types
201   void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
202                                 CheckerContext &C) const {
203     SVal ThisSVal = ConstructorCall->getCXXThisVal();
204 
205     const auto *const ThisMemRegion = ThisSVal.getAsRegion();
206     if (!ThisMemRegion)
207       return;
208 
209     std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
210         ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0);
211     if (!DefaultType)
212       return;
213 
214     ProgramStateRef State = ConstructorCall->getState();
215     State = State->set<VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
216     C.addTransition(State);
217   }
218 
219   bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
220     ProgramStateRef State = Call.getState();
221 
222     const auto &ArgType = Call.getArgSVal(0)
223                               .getType(C.getASTContext())
224                               ->getPointeeType()
225                               .getTypePtr();
226     // We have to make sure that the argument is an std::variant.
227     // There is another std::get with std::pair argument
228     if (!isStdVariant(ArgType))
229       return false;
230 
231     // Get the mem region of the argument std::variant and look up the type
232     // information that we know about it.
233     const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
234     const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
235     if (!StoredType)
236       return false;
237 
238     const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
239     const FunctionDecl *FD = CE->getDirectCallee();
240     if (FD->getTemplateSpecializationArgs()->size() < 1)
241       return false;
242 
243     const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
244     // std::get's first template parameter can be the type we want to get
245     // out of the std::variant or a natural number which is the position of
246     // the requested type in the argument type list of the std::variant's
247     // argument.
248     QualType RetrievedType;
249     switch (TypeOut.getKind()) {
250     case TemplateArgument::ArgKind::Type:
251       RetrievedType = TypeOut.getAsType();
252       break;
253     case TemplateArgument::ArgKind::Integral:
254       // In the natural number case we look up which type corresponds to the
255       // number.
256       if (std::optional<QualType> NthTemplate =
257               getNthTemplateTypeArgFromVariant(
258                   ArgType, TypeOut.getAsIntegral().getSExtValue())) {
259         RetrievedType = *NthTemplate;
260         break;
261       }
262       [[fallthrough]];
263     default:
264       return false;
265     }
266 
267     QualType RetrievedCanonicalType = RetrievedType.getCanonicalType();
268     QualType StoredCanonicalType = StoredType->getCanonicalType();
269     if (RetrievedCanonicalType == StoredCanonicalType)
270       return true;
271 
272     ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
273     if (!ErrNode)
274       return false;
275     llvm::SmallString<128> Str;
276     llvm::raw_svector_ostream OS(Str);
277     std::string StoredTypeName = StoredType->getAsString();
278     std::string RetrievedTypeName = RetrievedType.getAsString();
279     OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
280        << indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'"
281        << StoredTypeName << "\', not "
282        << indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'"
283        << RetrievedTypeName << "\'";
284     auto R = std::make_unique<PathSensitiveBugReport>(BadVariantType, OS.str(),
285                                                       ErrNode);
286     C.emitReport(std::move(R));
287     return true;
288   }
289 };
290 
291 bool clang::ento::shouldRegisterStdVariantChecker(
292     clang::ento::CheckerManager const &mgr) {
293   return true;
294 }
295 
296 void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
297   mgr.registerChecker<StdVariantChecker>();
298 }