xref: /freebsd/contrib/llvm-project/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/AST/Type.h"
10 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
11 #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
12 #include "clang/StaticAnalyzer/Core/Checker.h"
13 #include "clang/StaticAnalyzer/Core/CheckerManager.h"
14 #include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
15 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
16 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
18 #include "llvm/ADT/FoldingSet.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/Casting.h"
21 #include <optional>
22 #include <string_view>
23 
24 #include "TaggedUnionModeling.h"
25 
26 using namespace clang;
27 using namespace ento;
28 using namespace tagged_union_modeling;
29 
// Program-state map from the memory region of a std::variant object to the
// type this checker believes the variant currently holds. Entries are written
// when a variant is constructed/assigned and read back when std::get is called.
REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
31 
32 namespace clang::ento::tagged_union_modeling {
33 
34 const CXXConstructorDecl *
35 getConstructorDeclarationForCall(const CallEvent &Call) {
36   const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
37   if (!ConstructorCall)
38     return nullptr;
39 
40   return ConstructorCall->getDecl();
41 }
42 
43 bool isCopyConstructorCall(const CallEvent &Call) {
44   if (const CXXConstructorDecl *ConstructorDecl =
45           getConstructorDeclarationForCall(Call))
46     return ConstructorDecl->isCopyConstructor();
47   return false;
48 }
49 
50 bool isCopyAssignmentCall(const CallEvent &Call) {
51   const Decl *CopyAssignmentDecl = Call.getDecl();
52 
53   if (const auto *AsMethodDecl =
54           dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
55     return AsMethodDecl->isCopyAssignmentOperator();
56   return false;
57 }
58 
59 bool isMoveConstructorCall(const CallEvent &Call) {
60   const CXXConstructorDecl *ConstructorDecl =
61       getConstructorDeclarationForCall(Call);
62   if (!ConstructorDecl)
63     return false;
64 
65   return ConstructorDecl->isMoveConstructor();
66 }
67 
68 bool isMoveAssignmentCall(const CallEvent &Call) {
69   const Decl *CopyAssignmentDecl = Call.getDecl();
70 
71   const auto *AsMethodDecl =
72       dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
73   if (!AsMethodDecl)
74     return false;
75 
76   return AsMethodDecl->isMoveAssignmentOperator();
77 }
78 
79 bool isStdType(const Type *Type, llvm::StringRef TypeName) {
80   auto *Decl = Type->getAsRecordDecl();
81   if (!Decl)
82     return false;
83   return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
84 }
85 
86 bool isStdVariant(const Type *Type) {
87   return isStdType(Type, llvm::StringLiteral("variant"));
88 }
89 
90 } // end of namespace clang::ento::tagged_union_modeling
91 
92 static std::optional<ArrayRef<TemplateArgument>>
93 getTemplateArgsFromVariant(const Type *VariantType) {
94   const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
95   if (!TempSpecType)
96     return {};
97 
98   return TempSpecType->template_arguments();
99 }
100 
101 static std::optional<QualType>
102 getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
103   std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
104       getTemplateArgsFromVariant(varType);
105   if (!VariantTemplates)
106     return {};
107 
108   return (*VariantTemplates)[i].getAsType();
109 }
110 
/// Reports whether \p a is one of the lowercase ASCII vowels a, e, i, o, u.
/// Uppercase letters are not treated as vowels.
static bool isVowel(char a) {
  return a == 'a' || a == 'e' || a == 'i' || a == 'o' || a == 'u';
}
123 
124 static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
125   if (isVowel(a))
126     return "an";
127   return "a";
128 }
129 
130 class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
131   // Call descriptors to find relevant calls
132   CallDescription VariantConstructor{CDM::CXXMethod,
133                                      {"std", "variant", "variant"}};
134   CallDescription VariantAssignmentOperator{CDM::CXXMethod,
135                                             {"std", "variant", "operator="}};
136   CallDescription StdGet{CDM::SimpleFunc, {"std", "get"}, 1, 1};
137 
138   BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
139 
140 public:
141   ProgramStateRef checkRegionChanges(ProgramStateRef State,
142                                      const InvalidatedSymbols *,
143                                      ArrayRef<const MemRegion *>,
144                                      ArrayRef<const MemRegion *> Regions,
145                                      const LocationContext *,
146                                      const CallEvent *Call) const {
147     if (!Call)
148       return State;
149 
150     return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
151         *Call, State, Regions);
152   }
153 
154   bool evalCall(const CallEvent &Call, CheckerContext &C) const {
155     // Check if the call was not made from a system header. If it was then
156     // we do an early return because it is part of the implementation.
157     if (Call.isCalledFromSystemHeader())
158       return false;
159 
160     if (StdGet.matches(Call))
161       return handleStdGetCall(Call, C);
162 
163     // First check if a constructor call is happening. If it is a
164     // constructor call, check if it is an std::variant constructor call.
165     bool IsVariantConstructor =
166         isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
167     bool IsVariantAssignmentOperatorCall =
168         isa<CXXMemberOperatorCall>(Call) &&
169         VariantAssignmentOperator.matches(Call);
170 
171     if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
172       if (Call.getNumArgs() == 0 && IsVariantConstructor) {
173         handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
174         return true;
175       }
176 
177       // FIXME Later this checker should be extended to handle constructors
178       // with multiple arguments.
179       if (Call.getNumArgs() != 1)
180         return false;
181 
182       SVal ThisSVal;
183       if (IsVariantConstructor) {
184         const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
185         ThisSVal = AsConstructorCall.getCXXThisVal();
186       } else if (IsVariantAssignmentOperatorCall) {
187         const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
188         ThisSVal = AsMemberOpCall.getCXXThisVal();
189       } else {
190         return false;
191       }
192 
193       handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
194       return true;
195     }
196     return false;
197   }
198 
199 private:
200   // The default constructed std::variant must be handled separately
201   // by default the std::variant is going to hold a default constructed instance
202   // of the first type of the possible types
203   void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
204                                 CheckerContext &C) const {
205     SVal ThisSVal = ConstructorCall->getCXXThisVal();
206 
207     const auto *const ThisMemRegion = ThisSVal.getAsRegion();
208     if (!ThisMemRegion)
209       return;
210 
211     std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
212         ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0);
213     if (!DefaultType)
214       return;
215 
216     ProgramStateRef State = ConstructorCall->getState();
217     State = State->set<VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
218     C.addTransition(State);
219   }
220 
221   bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
222     ProgramStateRef State = Call.getState();
223 
224     const auto &ArgType = Call.getArgSVal(0)
225                               .getType(C.getASTContext())
226                               ->getPointeeType()
227                               .getTypePtr();
228     // We have to make sure that the argument is an std::variant.
229     // There is another std::get with std::pair argument
230     if (!isStdVariant(ArgType))
231       return false;
232 
233     // Get the mem region of the argument std::variant and look up the type
234     // information that we know about it.
235     const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
236     const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
237     if (!StoredType)
238       return false;
239 
240     const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
241     const FunctionDecl *FD = CE->getDirectCallee();
242     if (FD->getTemplateSpecializationArgs()->size() < 1)
243       return false;
244 
245     const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
246     // std::get's first template parameter can be the type we want to get
247     // out of the std::variant or a natural number which is the position of
248     // the requested type in the argument type list of the std::variant's
249     // argument.
250     QualType RetrievedType;
251     switch (TypeOut.getKind()) {
252     case TemplateArgument::ArgKind::Type:
253       RetrievedType = TypeOut.getAsType();
254       break;
255     case TemplateArgument::ArgKind::Integral:
256       // In the natural number case we look up which type corresponds to the
257       // number.
258       if (std::optional<QualType> NthTemplate =
259               getNthTemplateTypeArgFromVariant(
260                   ArgType, TypeOut.getAsIntegral().getSExtValue())) {
261         RetrievedType = *NthTemplate;
262         break;
263       }
264       [[fallthrough]];
265     default:
266       return false;
267     }
268 
269     QualType RetrievedCanonicalType = RetrievedType.getCanonicalType();
270     QualType StoredCanonicalType = StoredType->getCanonicalType();
271     if (RetrievedCanonicalType == StoredCanonicalType)
272       return true;
273 
274     ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
275     if (!ErrNode)
276       return false;
277     llvm::SmallString<128> Str;
278     llvm::raw_svector_ostream OS(Str);
279     std::string StoredTypeName = StoredType->getAsString();
280     std::string RetrievedTypeName = RetrievedType.getAsString();
281     OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
282        << indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'"
283        << StoredTypeName << "\', not "
284        << indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'"
285        << RetrievedTypeName << "\'";
286     auto R = std::make_unique<PathSensitiveBugReport>(BadVariantType, OS.str(),
287                                                       ErrNode);
288     C.emitReport(std::move(R));
289     return true;
290   }
291 };
292 
293 bool clang::ento::shouldRegisterStdVariantChecker(
294     clang::ento::CheckerManager const &mgr) {
295   return true;
296 }
297 
298 void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
299   mgr.registerChecker<StdVariantChecker>();
300 }
301