1 //===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "clang/AST/Type.h"
10 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
11 #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
12 #include "clang/StaticAnalyzer/Core/Checker.h"
13 #include "clang/StaticAnalyzer/Core/CheckerManager.h"
14 #include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
15 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
16 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
18 #include "llvm/ADT/FoldingSet.h"
19 #include "llvm/ADT/StringRef.h"
20 #include <optional>
21
22 #include "TaggedUnionModeling.h"
23
24 using namespace clang;
25 using namespace ento;
26 using namespace tagged_union_modeling;
27
28 REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
29
30 namespace clang::ento::tagged_union_modeling {
31
32 static const CXXConstructorDecl *
getConstructorDeclarationForCall(const CallEvent & Call)33 getConstructorDeclarationForCall(const CallEvent &Call) {
34 const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
35 if (!ConstructorCall)
36 return nullptr;
37
38 return ConstructorCall->getDecl();
39 }
40
isCopyConstructorCall(const CallEvent & Call)41 bool isCopyConstructorCall(const CallEvent &Call) {
42 if (const CXXConstructorDecl *ConstructorDecl =
43 getConstructorDeclarationForCall(Call))
44 return ConstructorDecl->isCopyConstructor();
45 return false;
46 }
47
isCopyAssignmentCall(const CallEvent & Call)48 bool isCopyAssignmentCall(const CallEvent &Call) {
49 const Decl *CopyAssignmentDecl = Call.getDecl();
50
51 if (const auto *AsMethodDecl =
52 dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
53 return AsMethodDecl->isCopyAssignmentOperator();
54 return false;
55 }
56
isMoveConstructorCall(const CallEvent & Call)57 bool isMoveConstructorCall(const CallEvent &Call) {
58 const CXXConstructorDecl *ConstructorDecl =
59 getConstructorDeclarationForCall(Call);
60 if (!ConstructorDecl)
61 return false;
62
63 return ConstructorDecl->isMoveConstructor();
64 }
65
isMoveAssignmentCall(const CallEvent & Call)66 bool isMoveAssignmentCall(const CallEvent &Call) {
67 const Decl *CopyAssignmentDecl = Call.getDecl();
68
69 const auto *AsMethodDecl =
70 dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
71 if (!AsMethodDecl)
72 return false;
73
74 return AsMethodDecl->isMoveAssignmentOperator();
75 }
76
isStdType(const Type * Type,llvm::StringRef TypeName)77 static bool isStdType(const Type *Type, llvm::StringRef TypeName) {
78 auto *Decl = Type->getAsRecordDecl();
79 if (!Decl)
80 return false;
81 return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
82 }
83
isStdVariant(const Type * Type)84 bool isStdVariant(const Type *Type) {
85 return isStdType(Type, llvm::StringLiteral("variant"));
86 }
87
88 } // end of namespace clang::ento::tagged_union_modeling
89
90 static std::optional<ArrayRef<TemplateArgument>>
getTemplateArgsFromVariant(const Type * VariantType)91 getTemplateArgsFromVariant(const Type *VariantType) {
92 const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
93 if (!TempSpecType)
94 return {};
95
96 return TempSpecType->template_arguments();
97 }
98
99 static std::optional<QualType>
getNthTemplateTypeArgFromVariant(const Type * varType,unsigned i)100 getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
101 std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
102 getTemplateArgsFromVariant(varType);
103 if (!VariantTemplates)
104 return {};
105
106 return (*VariantTemplates)[i].getAsType();
107 }
108
/// Returns true when \p a is one of the ASCII vowels a, e, i, o, u in either
/// letter case.
///
/// Uppercase letters are accepted so that names starting with a capital vowel
/// are classified the same way as their lowercase counterparts.
static bool isVowel(char a) {
  switch (a) {
  case 'a':
  case 'e':
  case 'i':
  case 'o':
  case 'u':
  case 'A':
  case 'E':
  case 'I':
  case 'O':
  case 'U':
    return true;
  default:
    return false;
  }
}
121
indefiniteArticleBasedOnVowel(char a)122 static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
123 if (isVowel(a))
124 return "an";
125 return "a";
126 }
127
128 class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
129 // Call descriptors to find relevant calls
130 CallDescription VariantConstructor{CDM::CXXMethod,
131 {"std", "variant", "variant"}};
132 CallDescription VariantAssignmentOperator{CDM::CXXMethod,
133 {"std", "variant", "operator="}};
134 CallDescription StdGet{CDM::SimpleFunc, {"std", "get"}, 1, 1};
135
136 BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
137
138 public:
checkRegionChanges(ProgramStateRef State,const InvalidatedSymbols *,ArrayRef<const MemRegion * >,ArrayRef<const MemRegion * > Regions,const LocationContext *,const CallEvent * Call) const139 ProgramStateRef checkRegionChanges(ProgramStateRef State,
140 const InvalidatedSymbols *,
141 ArrayRef<const MemRegion *>,
142 ArrayRef<const MemRegion *> Regions,
143 const LocationContext *,
144 const CallEvent *Call) const {
145 if (!Call)
146 return State;
147
148 return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
149 *Call, State, Regions);
150 }
151
evalCall(const CallEvent & Call,CheckerContext & C) const152 bool evalCall(const CallEvent &Call, CheckerContext &C) const {
153 // Check if the call was not made from a system header. If it was then
154 // we do an early return because it is part of the implementation.
155 if (Call.isCalledFromSystemHeader())
156 return false;
157
158 if (StdGet.matches(Call))
159 return handleStdGetCall(Call, C);
160
161 // First check if a constructor call is happening. If it is a
162 // constructor call, check if it is an std::variant constructor call.
163 bool IsVariantConstructor =
164 isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
165 bool IsVariantAssignmentOperatorCall =
166 isa<CXXMemberOperatorCall>(Call) &&
167 VariantAssignmentOperator.matches(Call);
168
169 if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
170 if (Call.getNumArgs() == 0 && IsVariantConstructor) {
171 handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
172 return true;
173 }
174
175 // FIXME Later this checker should be extended to handle constructors
176 // with multiple arguments.
177 if (Call.getNumArgs() != 1)
178 return false;
179
180 SVal ThisSVal;
181 if (IsVariantConstructor) {
182 const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
183 ThisSVal = AsConstructorCall.getCXXThisVal();
184 } else if (IsVariantAssignmentOperatorCall) {
185 const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
186 ThisSVal = AsMemberOpCall.getCXXThisVal();
187 } else {
188 return false;
189 }
190
191 handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
192 return true;
193 }
194 return false;
195 }
196
197 private:
198 // The default constructed std::variant must be handled separately
199 // by default the std::variant is going to hold a default constructed instance
200 // of the first type of the possible types
handleDefaultConstructor(const CXXConstructorCall * ConstructorCall,CheckerContext & C) const201 void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
202 CheckerContext &C) const {
203 SVal ThisSVal = ConstructorCall->getCXXThisVal();
204
205 const auto *const ThisMemRegion = ThisSVal.getAsRegion();
206 if (!ThisMemRegion)
207 return;
208
209 std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
210 ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0);
211 if (!DefaultType)
212 return;
213
214 ProgramStateRef State = ConstructorCall->getState();
215 State = State->set<VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
216 C.addTransition(State);
217 }
218
handleStdGetCall(const CallEvent & Call,CheckerContext & C) const219 bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
220 ProgramStateRef State = Call.getState();
221
222 const auto &ArgType = Call.getArgSVal(0)
223 .getType(C.getASTContext())
224 ->getPointeeType()
225 .getTypePtr();
226 // We have to make sure that the argument is an std::variant.
227 // There is another std::get with std::pair argument
228 if (!isStdVariant(ArgType))
229 return false;
230
231 // Get the mem region of the argument std::variant and look up the type
232 // information that we know about it.
233 const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
234 const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
235 if (!StoredType)
236 return false;
237
238 const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
239 const FunctionDecl *FD = CE->getDirectCallee();
240 if (FD->getTemplateSpecializationArgs()->size() < 1)
241 return false;
242
243 const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
244 // std::get's first template parameter can be the type we want to get
245 // out of the std::variant or a natural number which is the position of
246 // the requested type in the argument type list of the std::variant's
247 // argument.
248 QualType RetrievedType;
249 switch (TypeOut.getKind()) {
250 case TemplateArgument::ArgKind::Type:
251 RetrievedType = TypeOut.getAsType();
252 break;
253 case TemplateArgument::ArgKind::Integral:
254 // In the natural number case we look up which type corresponds to the
255 // number.
256 if (std::optional<QualType> NthTemplate =
257 getNthTemplateTypeArgFromVariant(
258 ArgType, TypeOut.getAsIntegral().getSExtValue())) {
259 RetrievedType = *NthTemplate;
260 break;
261 }
262 [[fallthrough]];
263 default:
264 return false;
265 }
266
267 QualType RetrievedCanonicalType = RetrievedType.getCanonicalType();
268 QualType StoredCanonicalType = StoredType->getCanonicalType();
269 if (RetrievedCanonicalType == StoredCanonicalType)
270 return true;
271
272 ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
273 if (!ErrNode)
274 return false;
275 llvm::SmallString<128> Str;
276 llvm::raw_svector_ostream OS(Str);
277 std::string StoredTypeName = StoredType->getAsString();
278 std::string RetrievedTypeName = RetrievedType.getAsString();
279 OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
280 << indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'"
281 << StoredTypeName << "\', not "
282 << indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'"
283 << RetrievedTypeName << "\'";
284 auto R = std::make_unique<PathSensitiveBugReport>(BadVariantType, OS.str(),
285 ErrNode);
286 C.emitReport(std::move(R));
287 return true;
288 }
289 };
290
shouldRegisterStdVariantChecker(clang::ento::CheckerManager const & mgr)291 bool clang::ento::shouldRegisterStdVariantChecker(
292 clang::ento::CheckerManager const &mgr) {
293 return true;
294 }
295
registerStdVariantChecker(clang::ento::CheckerManager & mgr)296 void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
297 mgr.registerChecker<StdVariantChecker>();
298 }
299