xref: /freebsd/contrib/llvm-project/clang/lib/Analysis/UnsafeBufferUsage.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- UnsafeBufferUsage.cpp - Replace pointers with modern C++ -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Analysis/Analyses/UnsafeBufferUsage.h"
10 #include "clang/AST/APValue.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/ASTTypeTraits.h"
13 #include "clang/AST/Attr.h"
14 #include "clang/AST/Decl.h"
15 #include "clang/AST/DeclCXX.h"
16 #include "clang/AST/DynamicRecursiveASTVisitor.h"
17 #include "clang/AST/Expr.h"
18 #include "clang/AST/FormatString.h"
19 #include "clang/AST/ParentMapContext.h"
20 #include "clang/AST/Stmt.h"
21 #include "clang/AST/StmtVisitor.h"
22 #include "clang/AST/Type.h"
23 #include "clang/ASTMatchers/LowLevelHelpers.h"
24 #include "clang/Analysis/Support/FixitUtil.h"
25 #include "clang/Basic/SourceLocation.h"
26 #include "clang/Lex/Lexer.h"
27 #include "clang/Lex/Preprocessor.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/STLFunctionalExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringRef.h"
33 #include <cstddef>
34 #include <optional>
35 #include <queue>
36 #include <set>
37 #include <sstream>
38 
39 using namespace clang;
40 
41 #ifndef NDEBUG
42 namespace {
43 class StmtDebugPrinter
44     : public ConstStmtVisitor<StmtDebugPrinter, std::string> {
45 public:
VisitStmt(const Stmt * S)46   std::string VisitStmt(const Stmt *S) { return S->getStmtClassName(); }
47 
VisitBinaryOperator(const BinaryOperator * BO)48   std::string VisitBinaryOperator(const BinaryOperator *BO) {
49     return "BinaryOperator(" + BO->getOpcodeStr().str() + ")";
50   }
51 
VisitUnaryOperator(const UnaryOperator * UO)52   std::string VisitUnaryOperator(const UnaryOperator *UO) {
53     return "UnaryOperator(" + UO->getOpcodeStr(UO->getOpcode()).str() + ")";
54   }
55 
VisitImplicitCastExpr(const ImplicitCastExpr * ICE)56   std::string VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
57     return "ImplicitCastExpr(" + std::string(ICE->getCastKindName()) + ")";
58   }
59 };
60 
61 // Returns a string of ancestor `Stmt`s of the given `DRE` in such a form:
62 // "DRE ==> parent-of-DRE ==> grandparent-of-DRE ==> ...".
getDREAncestorString(const DeclRefExpr * DRE,ASTContext & Ctx)63 static std::string getDREAncestorString(const DeclRefExpr *DRE,
64                                         ASTContext &Ctx) {
65   std::stringstream SS;
66   const Stmt *St = DRE;
67   StmtDebugPrinter StmtPriner;
68 
69   do {
70     SS << StmtPriner.Visit(St);
71 
72     DynTypedNodeList StParents = Ctx.getParents(*St);
73 
74     if (StParents.size() > 1)
75       return "unavailable due to multiple parents";
76     if (StParents.empty())
77       break;
78     St = StParents.begin()->get<Stmt>();
79     if (St)
80       SS << " ==> ";
81   } while (St);
82   return SS.str();
83 }
84 
85 } // namespace
86 #endif /* NDEBUG */
87 
88 namespace {
89 // Using a custom `FastMatcher` instead of ASTMatchers to achieve better
90 // performance. FastMatcher uses simple function `matches` to find if a node
91 // is a match, avoiding the dependency on the ASTMatchers framework which
92 // provide a nice abstraction, but incur big performance costs.
93 class FastMatcher {
94 public:
95   virtual bool matches(const DynTypedNode &DynNode, ASTContext &Ctx,
96                        const UnsafeBufferUsageHandler &Handler) = 0;
97   virtual ~FastMatcher() = default;
98 };
99 
100 class MatchResult {
101 
102 public:
getNodeAs(StringRef ID) const103   template <typename T> const T *getNodeAs(StringRef ID) const {
104     auto It = Nodes.find(ID);
105     if (It == Nodes.end()) {
106       return nullptr;
107     }
108     return It->second.get<T>();
109   }
110 
addNode(StringRef ID,const DynTypedNode & Node)111   void addNode(StringRef ID, const DynTypedNode &Node) { Nodes[ID] = Node; }
112 
113 private:
114   llvm::StringMap<DynTypedNode> Nodes;
115 };
116 } // namespace
117 
// Comma-terminated list of class names treated as sized containers or views.
// It is meant to be expanded inside braces, e.g.
// `llvm::is_contained({SIZED_CONTAINER_OR_VIEW_LIST}, Name)`; users in this
// file separately check that the class lives in namespace `std`.
#define SIZED_CONTAINER_OR_VIEW_LIST                                           \
  "span",                                                                      \
  "array",                                                                     \
  "vector",                                                                    \
  "basic_string_view",                                                         \
  "basic_string",                                                              \
  "initializer_list",
121 
122 // A `RecursiveASTVisitor` that traverses all descendants of a given node "n"
123 // except for those belonging to a different callable of "n".
124 class MatchDescendantVisitor : public DynamicRecursiveASTVisitor {
125 public:
126   // Creates an AST visitor that matches `Matcher` on all
127   // descendants of a given node "n" except for the ones
128   // belonging to a different callable of "n".
MatchDescendantVisitor(ASTContext & Context,FastMatcher & Matcher,bool FindAll,bool ignoreUnevaluatedContext,const UnsafeBufferUsageHandler & NewHandler)129   MatchDescendantVisitor(ASTContext &Context, FastMatcher &Matcher,
130                          bool FindAll, bool ignoreUnevaluatedContext,
131                          const UnsafeBufferUsageHandler &NewHandler)
132       : Matcher(&Matcher), FindAll(FindAll), Matches(false),
133         ignoreUnevaluatedContext(ignoreUnevaluatedContext),
134         ActiveASTContext(&Context), Handler(&NewHandler) {
135     ShouldVisitTemplateInstantiations = true;
136     ShouldVisitImplicitCode = false; // TODO: let's ignore implicit code for now
137   }
138 
139   // Returns true if a match is found in a subtree of `DynNode`, which belongs
140   // to the same callable of `DynNode`.
findMatch(const DynTypedNode & DynNode)141   bool findMatch(const DynTypedNode &DynNode) {
142     Matches = false;
143     if (const Stmt *StmtNode = DynNode.get<Stmt>()) {
144       TraverseStmt(const_cast<Stmt *>(StmtNode));
145       return Matches;
146     }
147     return false;
148   }
149 
150   // The following are overriding methods from the base visitor class.
151   // They are public only to allow CRTP to work. They are *not *part
152   // of the public API of this class.
153 
154   // For the matchers so far used in safe buffers, we only need to match
155   // `Stmt`s.  To override more as needed.
156 
TraverseDecl(Decl * Node)157   bool TraverseDecl(Decl *Node) override {
158     if (!Node)
159       return true;
160     if (!match(*Node))
161       return false;
162     // To skip callables:
163     if (isa<FunctionDecl, BlockDecl, ObjCMethodDecl>(Node))
164       return true;
165     // Traverse descendants
166     return DynamicRecursiveASTVisitor::TraverseDecl(Node);
167   }
168 
TraverseGenericSelectionExpr(GenericSelectionExpr * Node)169   bool TraverseGenericSelectionExpr(GenericSelectionExpr *Node) override {
170     // These are unevaluated, except the result expression.
171     if (ignoreUnevaluatedContext)
172       return TraverseStmt(Node->getResultExpr());
173     return DynamicRecursiveASTVisitor::TraverseGenericSelectionExpr(Node);
174   }
175 
176   bool
TraverseUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr * Node)177   TraverseUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr *Node) override {
178     // Unevaluated context.
179     if (ignoreUnevaluatedContext)
180       return true;
181     return DynamicRecursiveASTVisitor::TraverseUnaryExprOrTypeTraitExpr(Node);
182   }
183 
TraverseTypeOfExprTypeLoc(TypeOfExprTypeLoc Node)184   bool TraverseTypeOfExprTypeLoc(TypeOfExprTypeLoc Node) override {
185     // Unevaluated context.
186     if (ignoreUnevaluatedContext)
187       return true;
188     return DynamicRecursiveASTVisitor::TraverseTypeOfExprTypeLoc(Node);
189   }
190 
TraverseDecltypeTypeLoc(DecltypeTypeLoc Node)191   bool TraverseDecltypeTypeLoc(DecltypeTypeLoc Node) override {
192     // Unevaluated context.
193     if (ignoreUnevaluatedContext)
194       return true;
195     return DynamicRecursiveASTVisitor::TraverseDecltypeTypeLoc(Node);
196   }
197 
TraverseCXXNoexceptExpr(CXXNoexceptExpr * Node)198   bool TraverseCXXNoexceptExpr(CXXNoexceptExpr *Node) override {
199     // Unevaluated context.
200     if (ignoreUnevaluatedContext)
201       return true;
202     return DynamicRecursiveASTVisitor::TraverseCXXNoexceptExpr(Node);
203   }
204 
TraverseCXXTypeidExpr(CXXTypeidExpr * Node)205   bool TraverseCXXTypeidExpr(CXXTypeidExpr *Node) override {
206     // Unevaluated context.
207     if (ignoreUnevaluatedContext)
208       return true;
209     return DynamicRecursiveASTVisitor::TraverseCXXTypeidExpr(Node);
210   }
211 
TraverseCXXDefaultInitExpr(CXXDefaultInitExpr * Node)212   bool TraverseCXXDefaultInitExpr(CXXDefaultInitExpr *Node) override {
213     if (!TraverseStmt(Node->getExpr()))
214       return false;
215     return DynamicRecursiveASTVisitor::TraverseCXXDefaultInitExpr(Node);
216   }
217 
TraverseStmt(Stmt * Node)218   bool TraverseStmt(Stmt *Node) override {
219     if (!Node)
220       return true;
221     if (!match(*Node))
222       return false;
223     return DynamicRecursiveASTVisitor::TraverseStmt(Node);
224   }
225 
226 private:
227   // Sets 'Matched' to true if 'Matcher' matches 'Node'
228   //
229   // Returns 'true' if traversal should continue after this function
230   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
match(const T & Node)231   template <typename T> bool match(const T &Node) {
232     if (Matcher->matches(DynTypedNode::create(Node), *ActiveASTContext,
233                          *Handler)) {
234       Matches = true;
235       if (!FindAll)
236         return false; // Abort as soon as a match is found.
237     }
238     return true;
239   }
240 
241   FastMatcher *const Matcher;
242   // When true, finds all matches. When false, finds the first match and stops.
243   const bool FindAll;
244   bool Matches;
245   bool ignoreUnevaluatedContext;
246   ASTContext *ActiveASTContext;
247   const UnsafeBufferUsageHandler *Handler;
248 };
249 
250 // Because we're dealing with raw pointers, let's define what we mean by that.
hasPointerType(const Expr & E)251 static bool hasPointerType(const Expr &E) {
252   return isa<PointerType>(E.getType().getCanonicalType());
253 }
254 
hasArrayType(const Expr & E)255 static bool hasArrayType(const Expr &E) {
256   return isa<ArrayType>(E.getType().getCanonicalType());
257 }
258 
259 static void
forEachDescendantEvaluatedStmt(const Stmt * S,ASTContext & Ctx,const UnsafeBufferUsageHandler & Handler,FastMatcher & Matcher)260 forEachDescendantEvaluatedStmt(const Stmt *S, ASTContext &Ctx,
261                                const UnsafeBufferUsageHandler &Handler,
262                                FastMatcher &Matcher) {
263   MatchDescendantVisitor Visitor(Ctx, Matcher, /*FindAll=*/true,
264                                  /*ignoreUnevaluatedContext=*/true, Handler);
265   Visitor.findMatch(DynTypedNode::create(*S));
266 }
267 
forEachDescendantStmt(const Stmt * S,ASTContext & Ctx,const UnsafeBufferUsageHandler & Handler,FastMatcher & Matcher)268 static void forEachDescendantStmt(const Stmt *S, ASTContext &Ctx,
269                                   const UnsafeBufferUsageHandler &Handler,
270                                   FastMatcher &Matcher) {
271   MatchDescendantVisitor Visitor(Ctx, Matcher, /*FindAll=*/true,
272                                  /*ignoreUnevaluatedContext=*/false, Handler);
273   Visitor.findMatch(DynTypedNode::create(*S));
274 }
275 
276 // Matches a `Stmt` node iff the node is in a safe-buffer opt-out region
notInSafeBufferOptOut(const Stmt & Node,const UnsafeBufferUsageHandler * Handler)277 static bool notInSafeBufferOptOut(const Stmt &Node,
278                                   const UnsafeBufferUsageHandler *Handler) {
279   return !Handler->isSafeBufferOptOut(Node.getBeginLoc());
280 }
281 
282 static bool
ignoreUnsafeBufferInContainer(const Stmt & Node,const UnsafeBufferUsageHandler * Handler)283 ignoreUnsafeBufferInContainer(const Stmt &Node,
284                               const UnsafeBufferUsageHandler *Handler) {
285   return Handler->ignoreUnsafeBufferInContainer(Node.getBeginLoc());
286 }
287 
ignoreUnsafeLibcCall(const ASTContext & Ctx,const Stmt & Node,const UnsafeBufferUsageHandler * Handler)288 static bool ignoreUnsafeLibcCall(const ASTContext &Ctx, const Stmt &Node,
289                                  const UnsafeBufferUsageHandler *Handler) {
290   if (Ctx.getLangOpts().CPlusPlus)
291     return Handler->ignoreUnsafeBufferInLibcCall(Node.getBeginLoc());
292   return true; /* Only warn about libc calls for C++ */
293 }
294 
295 // Finds any expression 'e' such that `OnResult`
296 // matches 'e' and 'e' is in an Unspecified Lvalue Context.
findStmtsInUnspecifiedLvalueContext(const Stmt * S,const llvm::function_ref<void (const Expr *)> OnResult)297 static void findStmtsInUnspecifiedLvalueContext(
298     const Stmt *S, const llvm::function_ref<void(const Expr *)> OnResult) {
299   if (const auto *CE = dyn_cast<ImplicitCastExpr>(S);
300       CE && CE->getCastKind() == CastKind::CK_LValueToRValue)
301     OnResult(CE->getSubExpr());
302   if (const auto *BO = dyn_cast<BinaryOperator>(S);
303       BO && BO->getOpcode() == BO_Assign)
304     OnResult(BO->getLHS());
305 }
306 
307 // Finds any expression `e` such that `InnerMatcher` matches `e` and
308 // `e` is in an Unspecified Pointer Context (UPC).
findStmtsInUnspecifiedPointerContext(const Stmt * S,llvm::function_ref<void (const Stmt *)> InnerMatcher)309 static void findStmtsInUnspecifiedPointerContext(
310     const Stmt *S, llvm::function_ref<void(const Stmt *)> InnerMatcher) {
311   // A UPC can be
312   // 1. an argument of a function call (except the callee has [[unsafe_...]]
313   //    attribute), or
314   // 2. the operand of a pointer-to-(integer or bool) cast operation; or
315   // 3. the operand of a comparator operation; or
316   // 4. the operand of a pointer subtraction operation
317   //    (i.e., computing the distance between two pointers); or ...
318 
319   if (auto *CE = dyn_cast<CallExpr>(S)) {
320     if (const auto *FnDecl = CE->getDirectCallee();
321         FnDecl && FnDecl->hasAttr<UnsafeBufferUsageAttr>())
322       return;
323     ast_matchers::matchEachArgumentWithParamType(
324         *CE, [&InnerMatcher](QualType Type, const Expr *Arg) {
325           if (Type->isAnyPointerType())
326             InnerMatcher(Arg);
327         });
328   }
329 
330   if (auto *CE = dyn_cast<CastExpr>(S)) {
331     if (CE->getCastKind() != CastKind::CK_PointerToIntegral &&
332         CE->getCastKind() != CastKind::CK_PointerToBoolean)
333       return;
334     if (!hasPointerType(*CE->getSubExpr()))
335       return;
336     InnerMatcher(CE->getSubExpr());
337   }
338 
339   // Pointer comparison operator.
340   if (const auto *BO = dyn_cast<BinaryOperator>(S);
341       BO && (BO->getOpcode() == BO_EQ || BO->getOpcode() == BO_NE ||
342              BO->getOpcode() == BO_LT || BO->getOpcode() == BO_LE ||
343              BO->getOpcode() == BO_GT || BO->getOpcode() == BO_GE)) {
344     auto *LHS = BO->getLHS();
345     if (hasPointerType(*LHS))
346       InnerMatcher(LHS);
347 
348     auto *RHS = BO->getRHS();
349     if (hasPointerType(*RHS))
350       InnerMatcher(RHS);
351   }
352 
353   // Pointer subtractions.
354   if (const auto *BO = dyn_cast<BinaryOperator>(S);
355       BO && BO->getOpcode() == BO_Sub && hasPointerType(*BO->getLHS()) &&
356       hasPointerType(*BO->getRHS())) {
357     // Note that here we need both LHS and RHS to be
358     // pointer. Then the inner matcher can match any of
359     // them:
360     InnerMatcher(BO->getLHS());
361     InnerMatcher(BO->getRHS());
362   }
363   // FIXME: any more cases? (UPC excludes the RHS of an assignment.  For now
364   // we don't have to check that.)
365 }
366 
367 // Finds statements in unspecified untyped context i.e. any expression 'e' such
368 // that `InnerMatcher` matches 'e' and 'e' is in an unspecified untyped context
369 // (i.e the expression 'e' isn't evaluated to an RValue). For example, consider
370 // the following code:
371 //    int *p = new int[4];
372 //    int *q = new int[4];
373 //    if ((p = q)) {}
374 //    p = q;
375 // The expression `p = q` in the conditional of the `if` statement
376 // `if ((p = q))` is evaluated as an RValue, whereas the expression `p = q;`
377 // in the assignment statement is in an untyped context.
findStmtsInUnspecifiedUntypedContext(const Stmt * S,llvm::function_ref<void (const Stmt *)> InnerMatcher)378 static void findStmtsInUnspecifiedUntypedContext(
379     const Stmt *S, llvm::function_ref<void(const Stmt *)> InnerMatcher) {
380   // An unspecified context can be
381   // 1. A compound statement,
382   // 2. The body of an if statement
383   // 3. Body of a loop
384   if (auto *CS = dyn_cast<CompoundStmt>(S)) {
385     for (auto *Child : CS->body())
386       InnerMatcher(Child);
387   }
388   if (auto *IfS = dyn_cast<IfStmt>(S)) {
389     if (IfS->getThen())
390       InnerMatcher(IfS->getThen());
391     if (IfS->getElse())
392       InnerMatcher(IfS->getElse());
393   }
394   // FIXME: Handle loop bodies.
395 }
396 
397 // Returns true iff integer E1 is equivalent to integer E2.
398 //
399 // For now we only support such expressions:
400 //    expr := DRE | const-value | expr BO expr
401 //    BO   := '*' | '+'
402 //
403 // FIXME: We can reuse the expression comparator of the interop analysis after
404 // it has been upstreamed.
405 static bool areEqualIntegers(const Expr *E1, const Expr *E2, ASTContext &Ctx);
areEqualIntegralBinaryOperators(const BinaryOperator * E1,const Expr * E2_LHS,BinaryOperatorKind BOP,const Expr * E2_RHS,ASTContext & Ctx)406 static bool areEqualIntegralBinaryOperators(const BinaryOperator *E1,
407                                             const Expr *E2_LHS,
408                                             BinaryOperatorKind BOP,
409                                             const Expr *E2_RHS,
410                                             ASTContext &Ctx) {
411   if (E1->getOpcode() == BOP) {
412     switch (BOP) {
413       // Commutative operators:
414     case BO_Mul:
415     case BO_Add:
416       return (areEqualIntegers(E1->getLHS(), E2_LHS, Ctx) &&
417               areEqualIntegers(E1->getRHS(), E2_RHS, Ctx)) ||
418              (areEqualIntegers(E1->getLHS(), E2_RHS, Ctx) &&
419               areEqualIntegers(E1->getRHS(), E2_LHS, Ctx));
420     default:
421       return false;
422     }
423   }
424   return false;
425 }
426 
areEqualIntegers(const Expr * E1,const Expr * E2,ASTContext & Ctx)427 static bool areEqualIntegers(const Expr *E1, const Expr *E2, ASTContext &Ctx) {
428   E1 = E1->IgnoreParenImpCasts();
429   E2 = E2->IgnoreParenImpCasts();
430   if (!E1->getType()->isIntegerType() || E1->getType() != E2->getType())
431     return false;
432 
433   Expr::EvalResult ER1, ER2;
434 
435   // If both are constants:
436   if (E1->EvaluateAsInt(ER1, Ctx) && E2->EvaluateAsInt(ER2, Ctx))
437     return ER1.Val.getInt() == ER2.Val.getInt();
438 
439   // Otherwise, they should have identical stmt kind:
440   if (E1->getStmtClass() != E2->getStmtClass())
441     return false;
442   switch (E1->getStmtClass()) {
443   case Stmt::DeclRefExprClass:
444     return cast<DeclRefExpr>(E1)->getDecl() == cast<DeclRefExpr>(E2)->getDecl();
445   case Stmt::BinaryOperatorClass: {
446     auto BO2 = cast<BinaryOperator>(E2);
447     return areEqualIntegralBinaryOperators(cast<BinaryOperator>(E1),
448                                            BO2->getLHS(), BO2->getOpcode(),
449                                            BO2->getRHS(), Ctx);
450   }
451   default:
452     return false;
453   }
454 }
455 
456 // Providing that `Ptr` is a pointer and `Size` is an unsigned-integral
457 // expression, returns true iff they follow one of the following safe
458 // patterns:
459 //  1. Ptr is `DRE.data()` and Size is `DRE.size()`, where DRE is a hardened
460 //     container or view;
461 //
462 //  2. Ptr is `a` and Size is `n`, where `a` is of an array-of-T with constant
463 //     size `n`;
464 //
465 //  3. Ptr is `&var` and Size is `1`; or
466 //     Ptr is `std::addressof(...)` and Size is `1`;
467 //
468 //  4. Size is `0`;
isPtrBufferSafe(const Expr * Ptr,const Expr * Size,ASTContext & Ctx)469 static bool isPtrBufferSafe(const Expr *Ptr, const Expr *Size,
470                             ASTContext &Ctx) {
471   // Pattern 1:
472   if (auto *MCEPtr = dyn_cast<CXXMemberCallExpr>(Ptr->IgnoreParenImpCasts()))
473     if (auto *MCESize =
474             dyn_cast<CXXMemberCallExpr>(Size->IgnoreParenImpCasts())) {
475       auto *DREOfPtr = dyn_cast<DeclRefExpr>(
476           MCEPtr->getImplicitObjectArgument()->IgnoreParenImpCasts());
477       auto *DREOfSize = dyn_cast<DeclRefExpr>(
478           MCESize->getImplicitObjectArgument()->IgnoreParenImpCasts());
479 
480       if (!DREOfPtr || !DREOfSize)
481         return false; // not in safe pattern
482       // We need to make sure 'a' is identical to 'b' for 'a.data()' and
483       // 'b.size()' otherwise we do not know they match:
484       if (DREOfPtr->getDecl() != DREOfSize->getDecl())
485         return false;
486       if (MCEPtr->getMethodDecl()->getName() != "data")
487         return false;
488       // `MCEPtr->getRecordDecl()` must be non-null as `DREOfPtr` is non-null:
489       if (!MCEPtr->getRecordDecl()->isInStdNamespace())
490         return false;
491 
492       auto *ObjII = MCEPtr->getRecordDecl()->getIdentifier();
493 
494       if (!ObjII)
495         return false;
496 
497       bool AcceptSizeBytes = Ptr->getType()->getPointeeType()->isCharType();
498 
499       if (!((AcceptSizeBytes &&
500              MCESize->getMethodDecl()->getName() == "size_bytes") ||
501             // Note here the pointer must be a pointer-to-char type unless there
502             // is explicit casting.  If there is explicit casting, this branch
503             // is unreachable. Thus, at this branch "size" and "size_bytes" are
504             // equivalent as the pointer is a char pointer:
505             MCESize->getMethodDecl()->getName() == "size"))
506         return false;
507 
508       return llvm::is_contained({SIZED_CONTAINER_OR_VIEW_LIST},
509                                 ObjII->getName());
510     }
511 
512   Expr::EvalResult ER;
513 
514   // Pattern 2-4:
515   if (Size->EvaluateAsInt(ER, Ctx)) {
516     // Pattern 2:
517     if (auto *DRE = dyn_cast<DeclRefExpr>(Ptr->IgnoreParenImpCasts())) {
518       if (auto *CAT = Ctx.getAsConstantArrayType(DRE->getType())) {
519         llvm::APSInt SizeInt = ER.Val.getInt();
520 
521         return llvm::APSInt::compareValues(
522                    SizeInt, llvm::APSInt(CAT->getSize(), true)) == 0;
523       }
524       return false;
525     }
526 
527     // Pattern 3:
528     if (ER.Val.getInt().isOne()) {
529       if (auto *UO = dyn_cast<UnaryOperator>(Ptr->IgnoreParenImpCasts()))
530         return UO && UO->getOpcode() == UnaryOperator::Opcode::UO_AddrOf;
531       if (auto *CE = dyn_cast<CallExpr>(Ptr->IgnoreParenImpCasts())) {
532         auto *FnDecl = CE->getDirectCallee();
533 
534         return FnDecl && FnDecl->getNameAsString() == "addressof" &&
535                FnDecl->isInStdNamespace();
536       }
537       return false;
538     }
539     // Pattern 4:
540     if (ER.Val.getInt().isZero())
541       return true;
542   }
543   return false;
544 }
545 
546 // Given a two-param std::span construct call, matches iff the call has the
547 // following forms:
548 //   1. `std::span<T>{new T[n], n}`, where `n` is a literal or a DRE
549 //   2. `std::span<T>{new T, 1}`
550 //   3. `std::span<T>{ (char *)f(args), args[N] * arg*[M]}`, where
551 //       `f` is a function with attribute `alloc_size(N, M)`;
552 //       `args` represents the list of arguments;
553 //       `N, M` are parameter indexes to the allocating element number and size.
554 //        Sometimes, there is only one parameter index representing the total
555 //        size.
556 //   4. `std::span<T>{x.begin(), x.end()}` where `x` is an object in the
557 //      SIZED_CONTAINER_OR_VIEW_LIST.
558 //   5. `isPtrBufferSafe` returns true for the two arguments of the span
559 //      constructor
isSafeSpanTwoParamConstruct(const CXXConstructExpr & Node,ASTContext & Ctx)560 static bool isSafeSpanTwoParamConstruct(const CXXConstructExpr &Node,
561                                         ASTContext &Ctx) {
562   assert(Node.getNumArgs() == 2 &&
563          "expecting a two-parameter std::span constructor");
564   const Expr *Arg0 = Node.getArg(0)->IgnoreParenImpCasts();
565   const Expr *Arg1 = Node.getArg(1)->IgnoreParenImpCasts();
566   auto HaveEqualConstantValues = [&Ctx](const Expr *E0, const Expr *E1) {
567     if (auto E0CV = E0->getIntegerConstantExpr(Ctx))
568       if (auto E1CV = E1->getIntegerConstantExpr(Ctx)) {
569         return llvm::APSInt::compareValues(*E0CV, *E1CV) == 0;
570       }
571     return false;
572   };
573   auto AreSameDRE = [](const Expr *E0, const Expr *E1) {
574     if (auto *DRE0 = dyn_cast<DeclRefExpr>(E0))
575       if (auto *DRE1 = dyn_cast<DeclRefExpr>(E1)) {
576         return DRE0->getDecl() == DRE1->getDecl();
577       }
578     return false;
579   };
580   std::optional<llvm::APSInt> Arg1CV = Arg1->getIntegerConstantExpr(Ctx);
581 
582   if (Arg1CV && Arg1CV->isZero())
583     // Check form 5:
584     return true;
585 
586   // Check forms 1-2:
587   switch (Arg0->getStmtClass()) {
588   case Stmt::CXXNewExprClass:
589     if (auto Size = cast<CXXNewExpr>(Arg0)->getArraySize()) {
590       // Check form 1:
591       return AreSameDRE((*Size)->IgnoreImplicit(), Arg1) ||
592              HaveEqualConstantValues(*Size, Arg1);
593     }
594     // TODO: what's placeholder type? avoid it for now.
595     if (!cast<CXXNewExpr>(Arg0)->hasPlaceholderType()) {
596       // Check form 2:
597       return Arg1CV && Arg1CV->isOne();
598     }
599     break;
600   default:
601     break;
602   }
603 
604   // Check form 3:
605   if (auto CCast = dyn_cast<CStyleCastExpr>(Arg0)) {
606     if (!CCast->getType()->isPointerType())
607       return false;
608 
609     QualType PteTy = CCast->getType()->getPointeeType();
610 
611     if (!(PteTy->isConstantSizeType() && Ctx.getTypeSizeInChars(PteTy).isOne()))
612       return false;
613 
614     if (const auto *Call = dyn_cast<CallExpr>(CCast->getSubExpr())) {
615       if (const FunctionDecl *FD = Call->getDirectCallee())
616         if (auto *AllocAttr = FD->getAttr<AllocSizeAttr>()) {
617           const Expr *EleSizeExpr =
618               Call->getArg(AllocAttr->getElemSizeParam().getASTIndex());
619           // NumElemIdx is invalid if AllocSizeAttr has 1 argument:
620           ParamIdx NumElemIdx = AllocAttr->getNumElemsParam();
621 
622           if (!NumElemIdx.isValid())
623             return areEqualIntegers(Arg1, EleSizeExpr, Ctx);
624 
625           const Expr *NumElesExpr = Call->getArg(NumElemIdx.getASTIndex());
626 
627           if (auto BO = dyn_cast<BinaryOperator>(Arg1))
628             return areEqualIntegralBinaryOperators(BO, NumElesExpr, BO_Mul,
629                                                    EleSizeExpr, Ctx);
630         }
631     }
632   }
633   // Check form 4:
634   auto IsMethodCallToSizedObject = [](const Stmt *Node, StringRef MethodName) {
635     if (const auto *MC = dyn_cast<CXXMemberCallExpr>(Node)) {
636       const auto *MD = MC->getMethodDecl();
637       const auto *RD = MC->getRecordDecl();
638 
639       if (RD && MD)
640         if (auto *II = RD->getDeclName().getAsIdentifierInfo();
641             II && RD->isInStdNamespace())
642           return llvm::is_contained({SIZED_CONTAINER_OR_VIEW_LIST},
643                                     II->getName()) &&
644                  MD->getName() == MethodName;
645     }
646     return false;
647   };
648 
649   if (IsMethodCallToSizedObject(Arg0, "begin") &&
650       IsMethodCallToSizedObject(Arg1, "end"))
651     return AreSameDRE(
652         // We know Arg0 and Arg1 are `CXXMemberCallExpr`s:
653         cast<CXXMemberCallExpr>(Arg0)
654             ->getImplicitObjectArgument()
655             ->IgnoreParenImpCasts(),
656         cast<CXXMemberCallExpr>(Arg1)
657             ->getImplicitObjectArgument()
658             ->IgnoreParenImpCasts());
659 
660   // Check 5:
661   return isPtrBufferSafe(Arg0, Arg1, Ctx);
662 }
663 
isSafeArraySubscript(const ArraySubscriptExpr & Node,const ASTContext & Ctx)664 static bool isSafeArraySubscript(const ArraySubscriptExpr &Node,
665                                  const ASTContext &Ctx) {
666   // FIXME: Proper solution:
667   //  - refactor Sema::CheckArrayAccess
668   //    - split safe/OOB/unknown decision logic from diagnostics emitting code
669   //    -  e. g. "Try harder to find a NamedDecl to point at in the note."
670   //    already duplicated
671   //  - call both from Sema and from here
672 
673   uint64_t limit;
674   if (const auto *CATy =
675           dyn_cast<ConstantArrayType>(Node.getBase()
676                                           ->IgnoreParenImpCasts()
677                                           ->getType()
678                                           ->getUnqualifiedDesugaredType())) {
679     limit = CATy->getLimitedSize();
680   } else if (const auto *SLiteral = dyn_cast<clang::StringLiteral>(
681                  Node.getBase()->IgnoreParenImpCasts())) {
682     limit = SLiteral->getLength() + 1;
683   } else {
684     return false;
685   }
686 
687   Expr::EvalResult EVResult;
688   const Expr *IndexExpr = Node.getIdx();
689   if (!IndexExpr->isValueDependent() &&
690       IndexExpr->EvaluateAsInt(EVResult, Ctx)) {
691     llvm::APSInt ArrIdx = EVResult.Val.getInt();
692     // FIXME: ArrIdx.isNegative() we could immediately emit an error as that's a
693     // bug
694     if (ArrIdx.isNonNegative() && ArrIdx.getLimitedValue() < limit)
695       return true;
696   } else if (const auto *BE = dyn_cast<BinaryOperator>(IndexExpr)) {
697     // For an integer expression `e` and an integer constant `n`, `e & n` and
698     // `n & e` are bounded by `n`:
699     if (BE->getOpcode() != BO_And && BE->getOpcode() != BO_Rem)
700       return false;
701 
702     const Expr *LHS = BE->getLHS();
703     const Expr *RHS = BE->getRHS();
704 
705     if (BE->getOpcode() == BO_Rem) {
706       // If n is a negative number, then n % const can be greater than const
707       if (!LHS->getType()->isUnsignedIntegerType()) {
708         return false;
709       }
710 
711       if (!RHS->isValueDependent() && RHS->EvaluateAsInt(EVResult, Ctx)) {
712         llvm::APSInt result = EVResult.Val.getInt();
713         if (result.isNonNegative() && result.getLimitedValue() <= limit)
714           return true;
715       }
716 
717       return false;
718     }
719 
720     if ((!LHS->isValueDependent() &&
721          LHS->EvaluateAsInt(EVResult, Ctx)) || // case: `n & e`
722         (!RHS->isValueDependent() &&
723          RHS->EvaluateAsInt(EVResult, Ctx))) { // `e & n`
724       llvm::APSInt result = EVResult.Val.getInt();
725       if (result.isNonNegative() && result.getLimitedValue() < limit)
726         return true;
727     }
728     return false;
729   }
730   return false;
731 }
732 
733 namespace libc_func_matchers {
734 // Under `libc_func_matchers`, define a set of matchers that match unsafe
735 // functions in libc and unsafe calls to them.
736 
737 //  A tiny parser to strip off common prefix and suffix of libc function names
738 //  in real code.
739 //
740 //  Given a function name, `matchName` returns `CoreName` according to the
741 //  following grammar:
742 //
743 //  LibcName     := CoreName | CoreName + "_s"
744 //  MatchingName := "__builtin_" + LibcName              |
745 //                  "__builtin___" + LibcName + "_chk"   |
746 //                  "__asan_" + LibcName
747 //
748 struct LibcFunNamePrefixSuffixParser {
matchNamelibc_func_matchers::LibcFunNamePrefixSuffixParser749   StringRef matchName(StringRef FunName, bool isBuiltin) {
750     // Try to match __builtin_:
751     if (isBuiltin && FunName.starts_with("__builtin_"))
752       // Then either it is __builtin_LibcName or __builtin___LibcName_chk or
753       // no match:
754       return matchLibcNameOrBuiltinChk(
755           FunName.drop_front(10 /* truncate "__builtin_" */));
756     // Try to match __asan_:
757     if (FunName.starts_with("__asan_"))
758       return matchLibcName(FunName.drop_front(7 /* truncate of "__asan_" */));
759     return matchLibcName(FunName);
760   }
761 
762   // Parameter `Name` is the substring after stripping off the prefix
763   // "__builtin_".
matchLibcNameOrBuiltinChklibc_func_matchers::LibcFunNamePrefixSuffixParser764   StringRef matchLibcNameOrBuiltinChk(StringRef Name) {
765     if (Name.starts_with("__") && Name.ends_with("_chk"))
766       return matchLibcName(
767           Name.drop_front(2).drop_back(4) /* truncate "__" and "_chk" */);
768     return matchLibcName(Name);
769   }
770 
matchLibcNamelibc_func_matchers::LibcFunNamePrefixSuffixParser771   StringRef matchLibcName(StringRef Name) {
772     if (Name.ends_with("_s"))
773       return Name.drop_back(2 /* truncate "_s" */);
774     return Name;
775   }
776 };
777 
778 // A pointer type expression is known to be null-terminated, if it has the
779 // form: E.c_str(), for any expression E of `std::string` type.
isNullTermPointer(const Expr * Ptr)780 static bool isNullTermPointer(const Expr *Ptr) {
781   if (isa<clang::StringLiteral>(Ptr->IgnoreParenImpCasts()))
782     return true;
783   if (isa<PredefinedExpr>(Ptr->IgnoreParenImpCasts()))
784     return true;
785   if (auto *MCE = dyn_cast<CXXMemberCallExpr>(Ptr->IgnoreParenImpCasts())) {
786     const CXXMethodDecl *MD = MCE->getMethodDecl();
787     const CXXRecordDecl *RD = MCE->getRecordDecl()->getCanonicalDecl();
788 
789     if (MD && RD && RD->isInStdNamespace() && MD->getIdentifier())
790       if (MD->getName() == "c_str" && RD->getName() == "basic_string")
791         return true;
792   }
793   return false;
794 }
795 
796 // Return true iff at least one of following cases holds:
797 //  1. Format string is a literal and there is an unsafe pointer argument
798 //     corresponding to an `s` specifier;
799 //  2. Format string is not a literal and there is least an unsafe pointer
800 //     argument (including the formatter argument).
801 //
802 // `UnsafeArg` is the output argument that will be set only if this function
803 // returns true.
hasUnsafeFormatOrSArg(const CallExpr * Call,const Expr * & UnsafeArg,const unsigned FmtArgIdx,ASTContext & Ctx,bool isKprintf=false)804 static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg,
805                                   const unsigned FmtArgIdx, ASTContext &Ctx,
806                                   bool isKprintf = false) {
807   class StringFormatStringHandler
808       : public analyze_format_string::FormatStringHandler {
809     const CallExpr *Call;
810     unsigned FmtArgIdx;
811     const Expr *&UnsafeArg;
812 
813   public:
814     StringFormatStringHandler(const CallExpr *Call, unsigned FmtArgIdx,
815                               const Expr *&UnsafeArg)
816         : Call(Call), FmtArgIdx(FmtArgIdx), UnsafeArg(UnsafeArg) {}
817 
818     bool HandlePrintfSpecifier(const analyze_printf::PrintfSpecifier &FS,
819                                const char *startSpecifier,
820                                unsigned specifierLen,
821                                const TargetInfo &Target) override {
822       if (FS.getConversionSpecifier().getKind() ==
823           analyze_printf::PrintfConversionSpecifier::sArg) {
824         unsigned ArgIdx = FS.getPositionalArgIndex() + FmtArgIdx;
825 
826         if (0 < ArgIdx && ArgIdx < Call->getNumArgs())
827           if (!isNullTermPointer(Call->getArg(ArgIdx))) {
828             UnsafeArg = Call->getArg(ArgIdx); // output
829             // returning false stops parsing immediately
830             return false;
831           }
832       }
833       return true; // continue parsing
834     }
835   };
836 
837   const Expr *Fmt = Call->getArg(FmtArgIdx);
838 
839   if (auto *SL = dyn_cast<clang::StringLiteral>(Fmt->IgnoreParenImpCasts())) {
840     StringRef FmtStr;
841 
842     if (SL->getCharByteWidth() == 1)
843       FmtStr = SL->getString();
844     else if (auto EvaledFmtStr = SL->tryEvaluateString(Ctx))
845       FmtStr = *EvaledFmtStr;
846     else
847       goto CHECK_UNSAFE_PTR;
848 
849     StringFormatStringHandler Handler(Call, FmtArgIdx, UnsafeArg);
850 
851     return analyze_format_string::ParsePrintfString(
852         Handler, FmtStr.begin(), FmtStr.end(), Ctx.getLangOpts(),
853         Ctx.getTargetInfo(), isKprintf);
854   }
855 CHECK_UNSAFE_PTR:
856   // If format is not a string literal, we cannot analyze the format string.
857   // In this case, this call is considered unsafe if at least one argument
858   // (including the format argument) is unsafe pointer.
859   return llvm::any_of(
860       llvm::make_range(Call->arg_begin() + FmtArgIdx, Call->arg_end()),
861       [&UnsafeArg](const Expr *Arg) -> bool {
862         if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg)) {
863           UnsafeArg = Arg;
864           return true;
865         }
866         return false;
867       });
868 }
869 
870 // Matches a FunctionDecl node such that
871 //  1. It's name, after stripping off predefined prefix and suffix, is
872 //     `CoreName`; and
873 //  2. `CoreName` or `CoreName[str/wcs]` is one of the `PredefinedNames`, which
874 //     is a set of libc function names.
875 //
876 //  Note: For predefined prefix and suffix, see `LibcFunNamePrefixSuffixParser`.
877 //  The notation `CoreName[str/wcs]` means a new name obtained from replace
878 //  string "wcs" with "str" in `CoreName`.
isPredefinedUnsafeLibcFunc(const FunctionDecl & Node)879 static bool isPredefinedUnsafeLibcFunc(const FunctionDecl &Node) {
880   static std::unique_ptr<std::set<StringRef>> PredefinedNames = nullptr;
881   if (!PredefinedNames)
882     PredefinedNames =
883         std::make_unique<std::set<StringRef>, std::set<StringRef>>({
884             // numeric conversion:
885             "atof",
886             "atoi",
887             "atol",
888             "atoll",
889             "strtol",
890             "strtoll",
891             "strtoul",
892             "strtoull",
893             "strtof",
894             "strtod",
895             "strtold",
896             "strtoimax",
897             "strtoumax",
898             // "strfromf",  "strfromd", "strfroml", // C23?
899             // string manipulation:
900             "strcpy",
901             "strncpy",
902             "strlcpy",
903             "strcat",
904             "strncat",
905             "strlcat",
906             "strxfrm",
907             "strdup",
908             "strndup",
909             // string examination:
910             "strlen",
911             "strnlen",
912             "strcmp",
913             "strncmp",
914             "stricmp",
915             "strcasecmp",
916             "strcoll",
917             "strchr",
918             "strrchr",
919             "strspn",
920             "strcspn",
921             "strpbrk",
922             "strstr",
923             "strtok",
924             // "mem-" functions
925             "memchr",
926             "wmemchr",
927             "memcmp",
928             "wmemcmp",
929             "memcpy",
930             "memccpy",
931             "mempcpy",
932             "wmemcpy",
933             "memmove",
934             "wmemmove",
935             "memset",
936             "wmemset",
937             // IO:
938             "fread",
939             "fwrite",
940             "fgets",
941             "fgetws",
942             "gets",
943             "fputs",
944             "fputws",
945             "puts",
946             // others
947             "strerror_s",
948             "strerror_r",
949             "bcopy",
950             "bzero",
951             "bsearch",
952             "qsort",
953         });
954 
955   auto *II = Node.getIdentifier();
956 
957   if (!II)
958     return false;
959 
960   StringRef Name = LibcFunNamePrefixSuffixParser().matchName(
961       II->getName(), Node.getBuiltinID());
962 
963   // Match predefined names:
964   if (PredefinedNames->find(Name) != PredefinedNames->end())
965     return true;
966 
967   std::string NameWCS = Name.str();
968   size_t WcsPos = NameWCS.find("wcs");
969 
970   while (WcsPos != std::string::npos) {
971     NameWCS[WcsPos++] = 's';
972     NameWCS[WcsPos++] = 't';
973     NameWCS[WcsPos++] = 'r';
974     WcsPos = NameWCS.find("wcs", WcsPos);
975   }
976   if (PredefinedNames->find(NameWCS) != PredefinedNames->end())
977     return true;
978   // All `scanf` functions are unsafe (including `sscanf`, `vsscanf`, etc.. They
979   // all should end with "scanf"):
980   return Name.ends_with("scanf");
981 }
982 
983 // Match a call to one of the `v*printf` functions taking `va_list`.  We cannot
984 // check safety for these functions so they should be changed to their
985 // non-va_list versions.
isUnsafeVaListPrintfFunc(const FunctionDecl & Node)986 static bool isUnsafeVaListPrintfFunc(const FunctionDecl &Node) {
987   auto *II = Node.getIdentifier();
988 
989   if (!II)
990     return false;
991 
992   StringRef Name = LibcFunNamePrefixSuffixParser().matchName(
993       II->getName(), Node.getBuiltinID());
994 
995   if (!Name.ends_with("printf"))
996     return false; // neither printf nor scanf
997   return Name.starts_with("v");
998 }
999 
1000 // Matches a call to one of the `sprintf` functions as they are always unsafe
1001 // and should be changed to `snprintf`.
isUnsafeSprintfFunc(const FunctionDecl & Node)1002 static bool isUnsafeSprintfFunc(const FunctionDecl &Node) {
1003   auto *II = Node.getIdentifier();
1004 
1005   if (!II)
1006     return false;
1007 
1008   StringRef Name = LibcFunNamePrefixSuffixParser().matchName(
1009       II->getName(), Node.getBuiltinID());
1010 
1011   if (!Name.ends_with("printf") ||
1012       // Let `isUnsafeVaListPrintfFunc` check for cases with va-list:
1013       Name.starts_with("v"))
1014     return false;
1015 
1016   StringRef Prefix = Name.drop_back(6);
1017 
1018   if (Prefix.ends_with("w"))
1019     Prefix = Prefix.drop_back(1);
1020   return Prefix == "s";
1021 }
1022 
1023 // Match function declarations of `printf`, `fprintf`, `snprintf` and their wide
1024 // character versions.  Calls to these functions can be safe if their arguments
1025 // are carefully made safe.
isNormalPrintfFunc(const FunctionDecl & Node)1026 static bool isNormalPrintfFunc(const FunctionDecl &Node) {
1027   auto *II = Node.getIdentifier();
1028 
1029   if (!II)
1030     return false;
1031 
1032   StringRef Name = LibcFunNamePrefixSuffixParser().matchName(
1033       II->getName(), Node.getBuiltinID());
1034 
1035   if (!Name.ends_with("printf") || Name.starts_with("v"))
1036     return false;
1037 
1038   StringRef Prefix = Name.drop_back(6);
1039 
1040   if (Prefix.ends_with("w"))
1041     Prefix = Prefix.drop_back(1);
1042 
1043   return Prefix.empty() || Prefix == "k" || Prefix == "f" || Prefix == "sn";
1044 }
1045 
1046 // This matcher requires that it is known that the callee `isNormalPrintf`.
1047 // Then if the format string is a string literal, this matcher matches when at
1048 // least one string argument is unsafe. If the format is not a string literal,
1049 // this matcher matches when at least one pointer type argument is unsafe.
hasUnsafePrintfStringArg(const CallExpr & Node,ASTContext & Ctx,MatchResult & Result,llvm::StringRef Tag)1050 static bool hasUnsafePrintfStringArg(const CallExpr &Node, ASTContext &Ctx,
1051                                      MatchResult &Result, llvm::StringRef Tag) {
1052   // Determine what printf it is by examining formal parameters:
1053   const FunctionDecl *FD = Node.getDirectCallee();
1054 
1055   assert(FD && "It should have been checked that FD is non-null.");
1056 
1057   unsigned NumParms = FD->getNumParams();
1058 
1059   if (NumParms < 1)
1060     return false; // possibly some user-defined printf function
1061 
1062   QualType FirstParmTy = FD->getParamDecl(0)->getType();
1063 
1064   if (!FirstParmTy->isPointerType())
1065     return false; // possibly some user-defined printf function
1066 
1067   QualType FirstPteTy = FirstParmTy->castAs<PointerType>()->getPointeeType();
1068 
1069   if (!Ctx.getFILEType()
1070            .isNull() && //`FILE *` must be in the context if it is fprintf
1071       FirstPteTy.getCanonicalType() == Ctx.getFILEType().getCanonicalType()) {
1072     // It is a fprintf:
1073     const Expr *UnsafeArg;
1074 
1075     if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 1, Ctx, false)) {
1076       Result.addNode(Tag, DynTypedNode::create(*UnsafeArg));
1077       return true;
1078     }
1079     return false;
1080   }
1081 
1082   if (FirstPteTy.isConstQualified()) {
1083     // If the first parameter is a `const char *`, it is a printf/kprintf:
1084     bool isKprintf = false;
1085     const Expr *UnsafeArg;
1086 
1087     if (auto *II = FD->getIdentifier())
1088       isKprintf = II->getName() == "kprintf";
1089     if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 0, Ctx, isKprintf)) {
1090       Result.addNode(Tag, DynTypedNode::create(*UnsafeArg));
1091       return true;
1092     }
1093     return false;
1094   }
1095 
1096   if (NumParms > 2) {
1097     QualType SecondParmTy = FD->getParamDecl(1)->getType();
1098 
1099     if (!FirstPteTy.isConstQualified() && SecondParmTy->isIntegerType()) {
1100       // If the first parameter type is non-const qualified `char *` and the
1101       // second is an integer, it is a snprintf:
1102       const Expr *UnsafeArg;
1103 
1104       if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 2, Ctx, false)) {
1105         Result.addNode(Tag, DynTypedNode::create(*UnsafeArg));
1106         return true;
1107       }
1108       return false;
1109     }
1110   }
1111   // We don't really recognize this "normal" printf, the only thing we
1112   // can do is to require all pointers to be null-terminated:
1113   for (const auto *Arg : Node.arguments())
1114     if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg)) {
1115       Result.addNode(Tag, DynTypedNode::create(*Arg));
1116       return true;
1117     }
1118   return false;
1119 }
1120 
1121 // This function requires that it is known that the callee `isNormalPrintf`.
1122 // It returns true iff the first two arguments of the call is a pointer
1123 // `Ptr` and an unsigned integer `Size` and they are NOT safe, i.e.,
1124 // `!isPtrBufferSafe(Ptr, Size)`.
hasUnsafeSnprintfBuffer(const CallExpr & Node,ASTContext & Ctx)1125 static bool hasUnsafeSnprintfBuffer(const CallExpr &Node, ASTContext &Ctx) {
1126   const FunctionDecl *FD = Node.getDirectCallee();
1127 
1128   assert(FD && "It should have been checked that FD is non-null.");
1129 
1130   if (FD->getNumParams() < 3)
1131     return false; // Not an snprint
1132 
1133   QualType FirstParmTy = FD->getParamDecl(0)->getType();
1134 
1135   if (!FirstParmTy->isPointerType())
1136     return false; // Not an snprint
1137 
1138   QualType FirstPteTy = FirstParmTy->castAs<PointerType>()->getPointeeType();
1139   const Expr *Buf = Node.getArg(0), *Size = Node.getArg(1);
1140 
1141   if (FirstPteTy.isConstQualified() || !FirstPteTy->isAnyCharacterType() ||
1142       !Buf->getType()->isPointerType() ||
1143       !Size->getType()->isUnsignedIntegerType())
1144     return false; // not an snprintf call
1145 
1146   return !isPtrBufferSafe(Buf, Size, Ctx);
1147 }
1148 } // namespace libc_func_matchers
1149 
1150 namespace {
1151 // Because the analysis revolves around variables and their types, we'll need to
1152 // track uses of variables (aka DeclRefExprs).
1153 using DeclUseList = SmallVector<const DeclRefExpr *, 1>;
1154 
1155 // Convenience typedef.
1156 using FixItList = SmallVector<FixItHint, 4>;
1157 } // namespace
1158 
1159 namespace {
1160 /// Gadget is an individual operation in the code that may be of interest to
1161 /// this analysis. Each (non-abstract) subclass corresponds to a specific
1162 /// rigid AST structure that constitutes an operation on a pointer-type object.
1163 /// Discovery of a gadget in the code corresponds to claiming that we understand
1164 /// what this part of code is doing well enough to potentially improve it.
1165 /// Gadgets can be warning (immediately deserving a warning) or fixable (not
1166 /// always deserving a warning per se, but requires our attention to identify
1167 /// it warrants a fixit).
1168 class Gadget {
1169 public:
1170   enum class Kind {
1171 #define GADGET(x) x,
1172 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1173   };
1174 
Gadget(Kind K)1175   Gadget(Kind K) : K(K) {}
1176 
getKind() const1177   Kind getKind() const { return K; }
1178 
1179 #ifndef NDEBUG
getDebugName() const1180   StringRef getDebugName() const {
1181     switch (K) {
1182 #define GADGET(x)                                                              \
1183   case Kind::x:                                                                \
1184     return #x;
1185 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1186     }
1187     llvm_unreachable("Unhandled Gadget::Kind enum");
1188   }
1189 #endif
1190 
1191   virtual bool isWarningGadget() const = 0;
1192   // TODO remove this method from WarningGadget interface. It's only used for
1193   // debug prints in FixableGadget.
1194   virtual SourceLocation getSourceLoc() const = 0;
1195 
1196   /// Returns the list of pointer-type variables on which this gadget performs
1197   /// its operation. Typically, there's only one variable. This isn't a list
1198   /// of all DeclRefExprs in the gadget's AST!
1199   virtual DeclUseList getClaimedVarUseSites() const = 0;
1200 
1201   virtual ~Gadget() = default;
1202 
1203 private:
1204   Kind K;
1205 };
1206 
1207 /// Warning gadgets correspond to unsafe code patterns that warrants
1208 /// an immediate warning.
1209 class WarningGadget : public Gadget {
1210 public:
WarningGadget(Kind K)1211   WarningGadget(Kind K) : Gadget(K) {}
1212 
classof(const Gadget * G)1213   static bool classof(const Gadget *G) { return G->isWarningGadget(); }
isWarningGadget() const1214   bool isWarningGadget() const final { return true; }
1215 
1216   virtual void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1217                                      bool IsRelatedToDecl,
1218                                      ASTContext &Ctx) const = 0;
1219 
1220   virtual SmallVector<const Expr *, 1> getUnsafePtrs() const = 0;
1221 };
1222 
1223 /// Fixable gadgets correspond to code patterns that aren't always unsafe but
1224 /// need to be properly recognized in order to emit fixes. For example, if a raw
1225 /// pointer-type variable is replaced by a safe C++ container, every use of such
1226 /// variable must be carefully considered and possibly updated.
1227 class FixableGadget : public Gadget {
1228 public:
FixableGadget(Kind K)1229   FixableGadget(Kind K) : Gadget(K) {}
1230 
classof(const Gadget * G)1231   static bool classof(const Gadget *G) { return !G->isWarningGadget(); }
isWarningGadget() const1232   bool isWarningGadget() const final { return false; }
1233 
1234   /// Returns a fixit that would fix the current gadget according to
1235   /// the current strategy. Returns std::nullopt if the fix cannot be produced;
1236   /// returns an empty list if no fixes are necessary.
getFixits(const FixitStrategy &) const1237   virtual std::optional<FixItList> getFixits(const FixitStrategy &) const {
1238     return std::nullopt;
1239   }
1240 
1241   /// Returns a list of two elements where the first element is the LHS of a
1242   /// pointer assignment statement and the second element is the RHS. This
1243   /// two-element list represents the fact that the LHS buffer gets its bounds
1244   /// information from the RHS buffer. This information will be used later to
1245   /// group all those variables whose types must be modified together to prevent
1246   /// type mismatches.
1247   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const1248   getStrategyImplications() const {
1249     return std::nullopt;
1250   }
1251 };
1252 
isSupportedVariable(const DeclRefExpr & Node)1253 static bool isSupportedVariable(const DeclRefExpr &Node) {
1254   const Decl *D = Node.getDecl();
1255   return D != nullptr && isa<VarDecl>(D);
1256 }
1257 
1258 using FixableGadgetList = std::vector<std::unique_ptr<FixableGadget>>;
1259 using WarningGadgetList = std::vector<std::unique_ptr<WarningGadget>>;
1260 
1261 /// An increment of a pointer-type value is unsafe as it may run the pointer
1262 /// out of bounds.
1263 class IncrementGadget : public WarningGadget {
1264   static constexpr const char *const OpTag = "op";
1265   const UnaryOperator *Op;
1266 
1267 public:
IncrementGadget(const MatchResult & Result)1268   IncrementGadget(const MatchResult &Result)
1269       : WarningGadget(Kind::Increment),
1270         Op(Result.getNodeAs<UnaryOperator>(OpTag)) {}
1271 
classof(const Gadget * G)1272   static bool classof(const Gadget *G) {
1273     return G->getKind() == Kind::Increment;
1274   }
1275 
matches(const Stmt * S,const ASTContext & Ctx,MatchResult & Result)1276   static bool matches(const Stmt *S, const ASTContext &Ctx,
1277                       MatchResult &Result) {
1278     const auto *UO = dyn_cast<UnaryOperator>(S);
1279     if (!UO || !UO->isIncrementOp())
1280       return false;
1281     if (!hasPointerType(*UO->getSubExpr()->IgnoreParenImpCasts()))
1282       return false;
1283     Result.addNode(OpTag, DynTypedNode::create(*UO));
1284     return true;
1285   }
1286 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1287   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1288                              bool IsRelatedToDecl,
1289                              ASTContext &Ctx) const override {
1290     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1291   }
getSourceLoc() const1292   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1293 
getClaimedVarUseSites() const1294   DeclUseList getClaimedVarUseSites() const override {
1295     SmallVector<const DeclRefExpr *, 2> Uses;
1296     if (const auto *DRE =
1297             dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
1298       Uses.push_back(DRE);
1299     }
1300 
1301     return std::move(Uses);
1302   }
1303 
getUnsafePtrs() const1304   SmallVector<const Expr *, 1> getUnsafePtrs() const override {
1305     return {Op->getSubExpr()->IgnoreParenImpCasts()};
1306   }
1307 };
1308 
1309 /// A decrement of a pointer-type value is unsafe as it may run the pointer
1310 /// out of bounds.
1311 class DecrementGadget : public WarningGadget {
1312   static constexpr const char *const OpTag = "op";
1313   const UnaryOperator *Op;
1314 
1315 public:
DecrementGadget(const MatchResult & Result)1316   DecrementGadget(const MatchResult &Result)
1317       : WarningGadget(Kind::Decrement),
1318         Op(Result.getNodeAs<UnaryOperator>(OpTag)) {}
1319 
classof(const Gadget * G)1320   static bool classof(const Gadget *G) {
1321     return G->getKind() == Kind::Decrement;
1322   }
1323 
matches(const Stmt * S,const ASTContext & Ctx,MatchResult & Result)1324   static bool matches(const Stmt *S, const ASTContext &Ctx,
1325                       MatchResult &Result) {
1326     const auto *UO = dyn_cast<UnaryOperator>(S);
1327     if (!UO || !UO->isDecrementOp())
1328       return false;
1329     if (!hasPointerType(*UO->getSubExpr()->IgnoreParenImpCasts()))
1330       return false;
1331     Result.addNode(OpTag, DynTypedNode::create(*UO));
1332     return true;
1333   }
1334 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1335   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1336                              bool IsRelatedToDecl,
1337                              ASTContext &Ctx) const override {
1338     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1339   }
getSourceLoc() const1340   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1341 
getClaimedVarUseSites() const1342   DeclUseList getClaimedVarUseSites() const override {
1343     if (const auto *DRE =
1344             dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
1345       return {DRE};
1346     }
1347 
1348     return {};
1349   }
1350 
getUnsafePtrs() const1351   SmallVector<const Expr *, 1> getUnsafePtrs() const override {
1352     return {Op->getSubExpr()->IgnoreParenImpCasts()};
1353   }
1354 };
1355 
1356 /// Array subscript expressions on raw pointers as if they're arrays. Unsafe as
1357 /// it doesn't have any bounds checks for the array.
1358 class ArraySubscriptGadget : public WarningGadget {
1359   static constexpr const char *const ArraySubscrTag = "ArraySubscript";
1360   const ArraySubscriptExpr *ASE;
1361 
1362 public:
ArraySubscriptGadget(const MatchResult & Result)1363   ArraySubscriptGadget(const MatchResult &Result)
1364       : WarningGadget(Kind::ArraySubscript),
1365         ASE(Result.getNodeAs<ArraySubscriptExpr>(ArraySubscrTag)) {}
1366 
classof(const Gadget * G)1367   static bool classof(const Gadget *G) {
1368     return G->getKind() == Kind::ArraySubscript;
1369   }
1370 
matches(const Stmt * S,const ASTContext & Ctx,MatchResult & Result)1371   static bool matches(const Stmt *S, const ASTContext &Ctx,
1372                       MatchResult &Result) {
1373     const auto *ASE = dyn_cast<ArraySubscriptExpr>(S);
1374     if (!ASE)
1375       return false;
1376     const auto *const Base = ASE->getBase()->IgnoreParenImpCasts();
1377     if (!hasPointerType(*Base) && !hasArrayType(*Base))
1378       return false;
1379     const auto *Idx = dyn_cast<IntegerLiteral>(ASE->getIdx());
1380     bool IsSafeIndex = (Idx && Idx->getValue().isZero()) ||
1381                        isa<ArrayInitIndexExpr>(ASE->getIdx());
1382     if (IsSafeIndex || isSafeArraySubscript(*ASE, Ctx))
1383       return false;
1384     Result.addNode(ArraySubscrTag, DynTypedNode::create(*ASE));
1385     return true;
1386   }
1387 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1388   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1389                              bool IsRelatedToDecl,
1390                              ASTContext &Ctx) const override {
1391     Handler.handleUnsafeOperation(ASE, IsRelatedToDecl, Ctx);
1392   }
getSourceLoc() const1393   SourceLocation getSourceLoc() const override { return ASE->getBeginLoc(); }
1394 
getClaimedVarUseSites() const1395   DeclUseList getClaimedVarUseSites() const override {
1396     if (const auto *DRE =
1397             dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts())) {
1398       return {DRE};
1399     }
1400 
1401     return {};
1402   }
1403 
getUnsafePtrs() const1404   SmallVector<const Expr *, 1> getUnsafePtrs() const override {
1405     return {ASE->getBase()->IgnoreParenImpCasts()};
1406   }
1407 };
1408 
1409 /// A pointer arithmetic expression of one of the forms:
1410 ///  \code
1411 ///  ptr + n | n + ptr | ptr - n | ptr += n | ptr -= n
1412 ///  \endcode
1413 class PointerArithmeticGadget : public WarningGadget {
1414   static constexpr const char *const PointerArithmeticTag = "ptrAdd";
1415   static constexpr const char *const PointerArithmeticPointerTag = "ptrAddPtr";
1416   const BinaryOperator *PA; // pointer arithmetic expression
1417   const Expr *Ptr;          // the pointer expression in `PA`
1418 
1419 public:
PointerArithmeticGadget(const MatchResult & Result)1420   PointerArithmeticGadget(const MatchResult &Result)
1421       : WarningGadget(Kind::PointerArithmetic),
1422         PA(Result.getNodeAs<BinaryOperator>(PointerArithmeticTag)),
1423         Ptr(Result.getNodeAs<Expr>(PointerArithmeticPointerTag)) {}
1424 
classof(const Gadget * G)1425   static bool classof(const Gadget *G) {
1426     return G->getKind() == Kind::PointerArithmetic;
1427   }
1428 
matches(const Stmt * S,const ASTContext & Ctx,MatchResult & Result)1429   static bool matches(const Stmt *S, const ASTContext &Ctx,
1430                       MatchResult &Result) {
1431     const auto *BO = dyn_cast<BinaryOperator>(S);
1432     if (!BO)
1433       return false;
1434     const auto *LHS = BO->getLHS();
1435     const auto *RHS = BO->getRHS();
1436     // ptr at left
1437     if (BO->getOpcode() == BO_Add || BO->getOpcode() == BO_Sub ||
1438         BO->getOpcode() == BO_AddAssign || BO->getOpcode() == BO_SubAssign) {
1439       if (hasPointerType(*LHS) && (RHS->getType()->isIntegerType() ||
1440                                    RHS->getType()->isEnumeralType())) {
1441         Result.addNode(PointerArithmeticPointerTag, DynTypedNode::create(*LHS));
1442         Result.addNode(PointerArithmeticTag, DynTypedNode::create(*BO));
1443         return true;
1444       }
1445     }
1446     // ptr at right
1447     if (BO->getOpcode() == BO_Add && hasPointerType(*RHS) &&
1448         (LHS->getType()->isIntegerType() || LHS->getType()->isEnumeralType())) {
1449       Result.addNode(PointerArithmeticPointerTag, DynTypedNode::create(*RHS));
1450       Result.addNode(PointerArithmeticTag, DynTypedNode::create(*BO));
1451       return true;
1452     }
1453     return false;
1454   }
1455 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1456   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1457                              bool IsRelatedToDecl,
1458                              ASTContext &Ctx) const override {
1459     Handler.handleUnsafeOperation(PA, IsRelatedToDecl, Ctx);
1460   }
getSourceLoc() const1461   SourceLocation getSourceLoc() const override { return PA->getBeginLoc(); }
1462 
getClaimedVarUseSites() const1463   DeclUseList getClaimedVarUseSites() const override {
1464     if (const auto *DRE = dyn_cast<DeclRefExpr>(Ptr->IgnoreParenImpCasts())) {
1465       return {DRE};
1466     }
1467 
1468     return {};
1469   }
1470 
getUnsafePtrs() const1471   SmallVector<const Expr *, 1> getUnsafePtrs() const override {
1472     return {Ptr->IgnoreParenImpCasts()};
1473   }
1474 
1475   // FIXME: pointer adding zero should be fine
1476   // FIXME: this gadge will need a fix-it
1477 };
1478 
1479 class SpanTwoParamConstructorGadget : public WarningGadget {
1480   static constexpr const char *const SpanTwoParamConstructorTag =
1481       "spanTwoParamConstructor";
1482   const CXXConstructExpr *Ctor; // the span constructor expression
1483 
1484 public:
SpanTwoParamConstructorGadget(const MatchResult & Result)1485   SpanTwoParamConstructorGadget(const MatchResult &Result)
1486       : WarningGadget(Kind::SpanTwoParamConstructor),
1487         Ctor(Result.getNodeAs<CXXConstructExpr>(SpanTwoParamConstructorTag)) {}
1488 
classof(const Gadget * G)1489   static bool classof(const Gadget *G) {
1490     return G->getKind() == Kind::SpanTwoParamConstructor;
1491   }
1492 
matches(const Stmt * S,ASTContext & Ctx,MatchResult & Result)1493   static bool matches(const Stmt *S, ASTContext &Ctx, MatchResult &Result) {
1494     const auto *CE = dyn_cast<CXXConstructExpr>(S);
1495     if (!CE)
1496       return false;
1497     const auto *CDecl = CE->getConstructor();
1498     const auto *CRecordDecl = CDecl->getParent();
1499     auto HasTwoParamSpanCtorDecl =
1500         CRecordDecl->isInStdNamespace() &&
1501         CDecl->getDeclName().getAsString() == "span" && CE->getNumArgs() == 2;
1502     if (!HasTwoParamSpanCtorDecl || isSafeSpanTwoParamConstruct(*CE, Ctx))
1503       return false;
1504     Result.addNode(SpanTwoParamConstructorTag, DynTypedNode::create(*CE));
1505     return true;
1506   }
1507 
matches(const Stmt * S,ASTContext & Ctx,const UnsafeBufferUsageHandler * Handler,MatchResult & Result)1508   static bool matches(const Stmt *S, ASTContext &Ctx,
1509                       const UnsafeBufferUsageHandler *Handler,
1510                       MatchResult &Result) {
1511     if (ignoreUnsafeBufferInContainer(*S, Handler))
1512       return false;
1513     return matches(S, Ctx, Result);
1514   }
1515 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1516   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1517                              bool IsRelatedToDecl,
1518                              ASTContext &Ctx) const override {
1519     Handler.handleUnsafeOperationInContainer(Ctor, IsRelatedToDecl, Ctx);
1520   }
getSourceLoc() const1521   SourceLocation getSourceLoc() const override { return Ctor->getBeginLoc(); }
1522 
getClaimedVarUseSites() const1523   DeclUseList getClaimedVarUseSites() const override {
1524     // If the constructor call is of the form `std::span{var, n}`, `var` is
1525     // considered an unsafe variable.
1526     if (auto *DRE = dyn_cast<DeclRefExpr>(Ctor->getArg(0))) {
1527       if (isa<VarDecl>(DRE->getDecl()))
1528         return {DRE};
1529     }
1530     return {};
1531   }
1532 
getUnsafePtrs() const1533   SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
1534 };
1535 
1536 /// A pointer initialization expression of the form:
1537 ///  \code
1538 ///  int *p = q;
1539 ///  \endcode
1540 class PointerInitGadget : public FixableGadget {
1541 private:
1542   static constexpr const char *const PointerInitLHSTag = "ptrInitLHS";
1543   static constexpr const char *const PointerInitRHSTag = "ptrInitRHS";
1544   const VarDecl *PtrInitLHS;     // the LHS pointer expression in `PI`
1545   const DeclRefExpr *PtrInitRHS; // the RHS pointer expression in `PI`
1546 
1547 public:
PointerInitGadget(const MatchResult & Result)1548   PointerInitGadget(const MatchResult &Result)
1549       : FixableGadget(Kind::PointerInit),
1550         PtrInitLHS(Result.getNodeAs<VarDecl>(PointerInitLHSTag)),
1551         PtrInitRHS(Result.getNodeAs<DeclRefExpr>(PointerInitRHSTag)) {}
1552 
classof(const Gadget * G)1553   static bool classof(const Gadget *G) {
1554     return G->getKind() == Kind::PointerInit;
1555   }
1556 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)1557   static bool matches(const Stmt *S,
1558                       llvm::SmallVectorImpl<MatchResult> &Results) {
1559     const DeclStmt *DS = dyn_cast<DeclStmt>(S);
1560     if (!DS || !DS->isSingleDecl())
1561       return false;
1562     const VarDecl *VD = dyn_cast<VarDecl>(DS->getSingleDecl());
1563     if (!VD)
1564       return false;
1565     const Expr *Init = VD->getAnyInitializer();
1566     if (!Init)
1567       return false;
1568     const auto *DRE = dyn_cast<DeclRefExpr>(Init->IgnoreImpCasts());
1569     if (!DRE || !hasPointerType(*DRE) || !isSupportedVariable(*DRE)) {
1570       return false;
1571     }
1572     MatchResult R;
1573     R.addNode(PointerInitLHSTag, DynTypedNode::create(*VD));
1574     R.addNode(PointerInitRHSTag, DynTypedNode::create(*DRE));
1575     Results.emplace_back(std::move(R));
1576     return true;
1577   }
1578 
1579   virtual std::optional<FixItList>
1580   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1581   SourceLocation getSourceLoc() const override {
1582     return PtrInitRHS->getBeginLoc();
1583   }
1584 
getClaimedVarUseSites() const1585   virtual DeclUseList getClaimedVarUseSites() const override {
1586     return DeclUseList{PtrInitRHS};
1587   }
1588 
1589   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const1590   getStrategyImplications() const override {
1591     return std::make_pair(PtrInitLHS, cast<VarDecl>(PtrInitRHS->getDecl()));
1592   }
1593 };
1594 
1595 /// A pointer assignment expression of the form:
1596 ///  \code
1597 ///  p = q;
1598 ///  \endcode
1599 /// where both `p` and `q` are pointers.
1600 class PtrToPtrAssignmentGadget : public FixableGadget {
1601 private:
1602   static constexpr const char *const PointerAssignLHSTag = "ptrLHS";
1603   static constexpr const char *const PointerAssignRHSTag = "ptrRHS";
1604   const DeclRefExpr *PtrLHS; // the LHS pointer expression in `PA`
1605   const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA`
1606 
1607 public:
PtrToPtrAssignmentGadget(const MatchResult & Result)1608   PtrToPtrAssignmentGadget(const MatchResult &Result)
1609       : FixableGadget(Kind::PtrToPtrAssignment),
1610         PtrLHS(Result.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)),
1611         PtrRHS(Result.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {}
1612 
classof(const Gadget * G)1613   static bool classof(const Gadget *G) {
1614     return G->getKind() == Kind::PtrToPtrAssignment;
1615   }
1616 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)1617   static bool matches(const Stmt *S,
1618                       llvm::SmallVectorImpl<MatchResult> &Results) {
1619     size_t SizeBefore = Results.size();
1620     findStmtsInUnspecifiedUntypedContext(S, [&Results](const Stmt *S) {
1621       const auto *BO = dyn_cast<BinaryOperator>(S);
1622       if (!BO || BO->getOpcode() != BO_Assign)
1623         return;
1624       const auto *RHS = BO->getRHS()->IgnoreParenImpCasts();
1625       if (const auto *RHSRef = dyn_cast<DeclRefExpr>(RHS);
1626           !RHSRef || !hasPointerType(*RHSRef) ||
1627           !isSupportedVariable(*RHSRef)) {
1628         return;
1629       }
1630       const auto *LHS = BO->getLHS();
1631       if (const auto *LHSRef = dyn_cast<DeclRefExpr>(LHS);
1632           !LHSRef || !hasPointerType(*LHSRef) ||
1633           !isSupportedVariable(*LHSRef)) {
1634         return;
1635       }
1636       MatchResult R;
1637       R.addNode(PointerAssignLHSTag, DynTypedNode::create(*LHS));
1638       R.addNode(PointerAssignRHSTag, DynTypedNode::create(*RHS));
1639       Results.emplace_back(std::move(R));
1640     });
1641     return SizeBefore != Results.size();
1642   }
1643 
1644   virtual std::optional<FixItList>
1645   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1646   SourceLocation getSourceLoc() const override { return PtrLHS->getBeginLoc(); }
1647 
getClaimedVarUseSites() const1648   virtual DeclUseList getClaimedVarUseSites() const override {
1649     return DeclUseList{PtrLHS, PtrRHS};
1650   }
1651 
1652   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const1653   getStrategyImplications() const override {
1654     return std::make_pair(cast<VarDecl>(PtrLHS->getDecl()),
1655                           cast<VarDecl>(PtrRHS->getDecl()));
1656   }
1657 };
1658 
1659 /// An assignment expression of the form:
1660 ///  \code
1661 ///  ptr = array;
1662 ///  \endcode
1663 /// where `p` is a pointer and `array` is a constant size array.
1664 class CArrayToPtrAssignmentGadget : public FixableGadget {
1665 private:
1666   static constexpr const char *const PointerAssignLHSTag = "ptrLHS";
1667   static constexpr const char *const PointerAssignRHSTag = "ptrRHS";
1668   const DeclRefExpr *PtrLHS; // the LHS pointer expression in `PA`
1669   const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA`
1670 
1671 public:
CArrayToPtrAssignmentGadget(const MatchResult & Result)1672   CArrayToPtrAssignmentGadget(const MatchResult &Result)
1673       : FixableGadget(Kind::CArrayToPtrAssignment),
1674         PtrLHS(Result.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)),
1675         PtrRHS(Result.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {}
1676 
classof(const Gadget * G)1677   static bool classof(const Gadget *G) {
1678     return G->getKind() == Kind::CArrayToPtrAssignment;
1679   }
1680 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)1681   static bool matches(const Stmt *S,
1682                       llvm::SmallVectorImpl<MatchResult> &Results) {
1683     size_t SizeBefore = Results.size();
1684     findStmtsInUnspecifiedUntypedContext(S, [&Results](const Stmt *S) {
1685       const auto *BO = dyn_cast<BinaryOperator>(S);
1686       if (!BO || BO->getOpcode() != BO_Assign)
1687         return;
1688       const auto *RHS = BO->getRHS()->IgnoreParenImpCasts();
1689       if (const auto *RHSRef = dyn_cast<DeclRefExpr>(RHS);
1690           !RHSRef ||
1691           !isa<ConstantArrayType>(RHSRef->getType().getCanonicalType()) ||
1692           !isSupportedVariable(*RHSRef)) {
1693         return;
1694       }
1695       const auto *LHS = BO->getLHS();
1696       if (const auto *LHSRef = dyn_cast<DeclRefExpr>(LHS);
1697           !LHSRef || !hasPointerType(*LHSRef) ||
1698           !isSupportedVariable(*LHSRef)) {
1699         return;
1700       }
1701       MatchResult R;
1702       R.addNode(PointerAssignLHSTag, DynTypedNode::create(*LHS));
1703       R.addNode(PointerAssignRHSTag, DynTypedNode::create(*RHS));
1704       Results.emplace_back(std::move(R));
1705     });
1706     return SizeBefore != Results.size();
1707   }
1708 
1709   virtual std::optional<FixItList>
1710   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1711   SourceLocation getSourceLoc() const override { return PtrLHS->getBeginLoc(); }
1712 
getClaimedVarUseSites() const1713   virtual DeclUseList getClaimedVarUseSites() const override {
1714     return DeclUseList{PtrLHS, PtrRHS};
1715   }
1716 
1717   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const1718   getStrategyImplications() const override {
1719     return {};
1720   }
1721 };
1722 
1723 /// A call of a function or method that performs unchecked buffer operations
1724 /// over one of its pointer parameters.
1725 class UnsafeBufferUsageAttrGadget : public WarningGadget {
1726   constexpr static const char *const OpTag = "attr_expr";
1727   const Expr *Op;
1728 
1729 public:
UnsafeBufferUsageAttrGadget(const MatchResult & Result)1730   UnsafeBufferUsageAttrGadget(const MatchResult &Result)
1731       : WarningGadget(Kind::UnsafeBufferUsageAttr),
1732         Op(Result.getNodeAs<Expr>(OpTag)) {}
1733 
classof(const Gadget * G)1734   static bool classof(const Gadget *G) {
1735     return G->getKind() == Kind::UnsafeBufferUsageAttr;
1736   }
1737 
matches(const Stmt * S,const ASTContext & Ctx,MatchResult & Result)1738   static bool matches(const Stmt *S, const ASTContext &Ctx,
1739                       MatchResult &Result) {
1740     if (auto *CE = dyn_cast<CallExpr>(S)) {
1741       if (CE->getDirectCallee() &&
1742           CE->getDirectCallee()->hasAttr<UnsafeBufferUsageAttr>()) {
1743         Result.addNode(OpTag, DynTypedNode::create(*CE));
1744         return true;
1745       }
1746     }
1747     if (auto *ME = dyn_cast<MemberExpr>(S)) {
1748       if (!isa<FieldDecl>(ME->getMemberDecl()))
1749         return false;
1750       if (ME->getMemberDecl()->hasAttr<UnsafeBufferUsageAttr>()) {
1751         Result.addNode(OpTag, DynTypedNode::create(*ME));
1752         return true;
1753       }
1754     }
1755     return false;
1756   }
1757 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1758   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1759                              bool IsRelatedToDecl,
1760                              ASTContext &Ctx) const override {
1761     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1762   }
getSourceLoc() const1763   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1764 
getClaimedVarUseSites() const1765   DeclUseList getClaimedVarUseSites() const override { return {}; }
1766 
getUnsafePtrs() const1767   SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
1768 };
1769 
1770 /// A call of a constructor that performs unchecked buffer operations
1771 /// over one of its pointer parameters, or constructs a class object that will
1772 /// perform buffer operations that depend on the correctness of the parameters.
1773 class UnsafeBufferUsageCtorAttrGadget : public WarningGadget {
1774   constexpr static const char *const OpTag = "cxx_construct_expr";
1775   const CXXConstructExpr *Op;
1776 
1777 public:
UnsafeBufferUsageCtorAttrGadget(const MatchResult & Result)1778   UnsafeBufferUsageCtorAttrGadget(const MatchResult &Result)
1779       : WarningGadget(Kind::UnsafeBufferUsageCtorAttr),
1780         Op(Result.getNodeAs<CXXConstructExpr>(OpTag)) {}
1781 
classof(const Gadget * G)1782   static bool classof(const Gadget *G) {
1783     return G->getKind() == Kind::UnsafeBufferUsageCtorAttr;
1784   }
1785 
matches(const Stmt * S,ASTContext & Ctx,MatchResult & Result)1786   static bool matches(const Stmt *S, ASTContext &Ctx, MatchResult &Result) {
1787     const auto *CE = dyn_cast<CXXConstructExpr>(S);
1788     if (!CE || !CE->getConstructor()->hasAttr<UnsafeBufferUsageAttr>())
1789       return false;
1790     // std::span(ptr, size) ctor is handled by SpanTwoParamConstructorGadget.
1791     MatchResult Tmp;
1792     if (SpanTwoParamConstructorGadget::matches(CE, Ctx, Tmp))
1793       return false;
1794     Result.addNode(OpTag, DynTypedNode::create(*CE));
1795     return true;
1796   }
1797 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1798   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1799                              bool IsRelatedToDecl,
1800                              ASTContext &Ctx) const override {
1801     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1802   }
getSourceLoc() const1803   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1804 
getClaimedVarUseSites() const1805   DeclUseList getClaimedVarUseSites() const override { return {}; }
1806 
getUnsafePtrs() const1807   SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
1808 };
1809 
1810 // Warning gadget for unsafe invocation of span::data method.
1811 // Triggers when the pointer returned by the invocation is immediately
1812 // cast to a larger type.
1813 
1814 class DataInvocationGadget : public WarningGadget {
1815   constexpr static const char *const OpTag = "data_invocation_expr";
1816   const ExplicitCastExpr *Op;
1817 
1818 public:
DataInvocationGadget(const MatchResult & Result)1819   DataInvocationGadget(const MatchResult &Result)
1820       : WarningGadget(Kind::DataInvocation),
1821         Op(Result.getNodeAs<ExplicitCastExpr>(OpTag)) {}
1822 
classof(const Gadget * G)1823   static bool classof(const Gadget *G) {
1824     return G->getKind() == Kind::DataInvocation;
1825   }
1826 
matches(const Stmt * S,const ASTContext & Ctx,MatchResult & Result)1827   static bool matches(const Stmt *S, const ASTContext &Ctx,
1828                       MatchResult &Result) {
1829     auto *CE = dyn_cast<ExplicitCastExpr>(S);
1830     if (!CE)
1831       return false;
1832     for (auto *Child : CE->children()) {
1833       if (auto *MCE = dyn_cast<CXXMemberCallExpr>(Child);
1834           MCE && isDataFunction(MCE)) {
1835         Result.addNode(OpTag, DynTypedNode::create(*CE));
1836         return true;
1837       }
1838       if (auto *Paren = dyn_cast<ParenExpr>(Child)) {
1839         if (auto *MCE = dyn_cast<CXXMemberCallExpr>(Paren->getSubExpr());
1840             MCE && isDataFunction(MCE)) {
1841           Result.addNode(OpTag, DynTypedNode::create(*CE));
1842           return true;
1843         }
1844       }
1845     }
1846     return false;
1847   }
1848 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1849   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1850                              bool IsRelatedToDecl,
1851                              ASTContext &Ctx) const override {
1852     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1853   }
getSourceLoc() const1854   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1855 
getClaimedVarUseSites() const1856   DeclUseList getClaimedVarUseSites() const override { return {}; }
1857 
1858 private:
isDataFunction(const CXXMemberCallExpr * call)1859   static bool isDataFunction(const CXXMemberCallExpr *call) {
1860     if (!call)
1861       return false;
1862     auto *callee = call->getDirectCallee();
1863     if (!callee || !isa<CXXMethodDecl>(callee))
1864       return false;
1865     auto *method = cast<CXXMethodDecl>(callee);
1866     if (method->getNameAsString() == "data" &&
1867         method->getParent()->isInStdNamespace() &&
1868         llvm::is_contained({SIZED_CONTAINER_OR_VIEW_LIST},
1869                            method->getParent()->getName()))
1870       return true;
1871     return false;
1872   }
1873 
getUnsafePtrs() const1874   SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
1875 };
1876 
1877 class UnsafeLibcFunctionCallGadget : public WarningGadget {
1878   const CallExpr *const Call;
1879   const Expr *UnsafeArg = nullptr;
1880   constexpr static const char *const Tag = "UnsafeLibcFunctionCall";
1881   // Extra tags for additional information:
1882   constexpr static const char *const UnsafeSprintfTag =
1883       "UnsafeLibcFunctionCall_sprintf";
1884   constexpr static const char *const UnsafeSizedByTag =
1885       "UnsafeLibcFunctionCall_sized_by";
1886   constexpr static const char *const UnsafeStringTag =
1887       "UnsafeLibcFunctionCall_string";
1888   constexpr static const char *const UnsafeVaListTag =
1889       "UnsafeLibcFunctionCall_va_list";
1890 
1891   enum UnsafeKind {
1892     OTHERS = 0,  // no specific information, the callee function is unsafe
1893     SPRINTF = 1, // never call `-sprintf`s, call `-snprintf`s instead.
1894     SIZED_BY =
1895         2, // the first two arguments of `snprintf` function have
1896            // "__sized_by" relation but they do not conform to safe patterns
1897     STRING = 3,  // an argument is a pointer-to-char-as-string but does not
1898                  // guarantee null-termination
1899     VA_LIST = 4, // one of the `-printf`s function that take va_list, which is
1900                  // considered unsafe as it is not compile-time check
1901   } WarnedFunKind = OTHERS;
1902 
1903 public:
UnsafeLibcFunctionCallGadget(const MatchResult & Result)1904   UnsafeLibcFunctionCallGadget(const MatchResult &Result)
1905       : WarningGadget(Kind::UnsafeLibcFunctionCall),
1906         Call(Result.getNodeAs<CallExpr>(Tag)) {
1907     if (Result.getNodeAs<Decl>(UnsafeSprintfTag))
1908       WarnedFunKind = SPRINTF;
1909     else if (auto *E = Result.getNodeAs<Expr>(UnsafeStringTag)) {
1910       WarnedFunKind = STRING;
1911       UnsafeArg = E;
1912     } else if (Result.getNodeAs<CallExpr>(UnsafeSizedByTag)) {
1913       WarnedFunKind = SIZED_BY;
1914       UnsafeArg = Call->getArg(0);
1915     } else if (Result.getNodeAs<Decl>(UnsafeVaListTag))
1916       WarnedFunKind = VA_LIST;
1917   }
1918 
matches(const Stmt * S,ASTContext & Ctx,const UnsafeBufferUsageHandler * Handler,MatchResult & Result)1919   static bool matches(const Stmt *S, ASTContext &Ctx,
1920                       const UnsafeBufferUsageHandler *Handler,
1921                       MatchResult &Result) {
1922     if (ignoreUnsafeLibcCall(Ctx, *S, Handler))
1923       return false;
1924     auto *CE = dyn_cast<CallExpr>(S);
1925     if (!CE || !CE->getDirectCallee())
1926       return false;
1927     const auto *FD = dyn_cast<FunctionDecl>(CE->getDirectCallee());
1928     if (!FD)
1929       return false;
1930     auto isSingleStringLiteralArg = false;
1931     if (CE->getNumArgs() == 1) {
1932       isSingleStringLiteralArg =
1933           isa<clang::StringLiteral>(CE->getArg(0)->IgnoreParenImpCasts());
1934     }
1935     if (!isSingleStringLiteralArg) {
1936       // (unless the call has a sole string literal argument):
1937       if (libc_func_matchers::isPredefinedUnsafeLibcFunc(*FD)) {
1938         Result.addNode(Tag, DynTypedNode::create(*CE));
1939         return true;
1940       }
1941       if (libc_func_matchers::isUnsafeVaListPrintfFunc(*FD)) {
1942         Result.addNode(Tag, DynTypedNode::create(*CE));
1943         Result.addNode(UnsafeVaListTag, DynTypedNode::create(*FD));
1944         return true;
1945       }
1946       if (libc_func_matchers::isUnsafeSprintfFunc(*FD)) {
1947         Result.addNode(Tag, DynTypedNode::create(*CE));
1948         Result.addNode(UnsafeSprintfTag, DynTypedNode::create(*FD));
1949         return true;
1950       }
1951     }
1952     if (libc_func_matchers::isNormalPrintfFunc(*FD)) {
1953       if (libc_func_matchers::hasUnsafeSnprintfBuffer(*CE, Ctx)) {
1954         Result.addNode(Tag, DynTypedNode::create(*CE));
1955         Result.addNode(UnsafeSizedByTag, DynTypedNode::create(*CE));
1956         return true;
1957       }
1958       if (libc_func_matchers::hasUnsafePrintfStringArg(*CE, Ctx, Result,
1959                                                        UnsafeStringTag)) {
1960         Result.addNode(Tag, DynTypedNode::create(*CE));
1961         return true;
1962       }
1963     }
1964     return false;
1965   }
1966 
getBaseStmt() const1967   const Stmt *getBaseStmt() const { return Call; }
1968 
getSourceLoc() const1969   SourceLocation getSourceLoc() const override { return Call->getBeginLoc(); }
1970 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1971   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1972                              bool IsRelatedToDecl,
1973                              ASTContext &Ctx) const override {
1974     Handler.handleUnsafeLibcCall(Call, WarnedFunKind, Ctx, UnsafeArg);
1975   }
1976 
getClaimedVarUseSites() const1977   DeclUseList getClaimedVarUseSites() const override { return {}; }
1978 
getUnsafePtrs() const1979   SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
1980 };
1981 
1982 // Represents expressions of the form `DRE[*]` in the Unspecified Lvalue
1983 // Context (see `findStmtsInUnspecifiedLvalueContext`).
1984 // Note here `[]` is the built-in subscript operator.
1985 class ULCArraySubscriptGadget : public FixableGadget {
1986 private:
1987   static constexpr const char *const ULCArraySubscriptTag =
1988       "ArraySubscriptUnderULC";
1989   const ArraySubscriptExpr *Node;
1990 
1991 public:
ULCArraySubscriptGadget(const MatchResult & Result)1992   ULCArraySubscriptGadget(const MatchResult &Result)
1993       : FixableGadget(Kind::ULCArraySubscript),
1994         Node(Result.getNodeAs<ArraySubscriptExpr>(ULCArraySubscriptTag)) {
1995     assert(Node != nullptr && "Expecting a non-null matching result");
1996   }
1997 
classof(const Gadget * G)1998   static bool classof(const Gadget *G) {
1999     return G->getKind() == Kind::ULCArraySubscript;
2000   }
2001 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2002   static bool matches(const Stmt *S,
2003                       llvm::SmallVectorImpl<MatchResult> &Results) {
2004     size_t SizeBefore = Results.size();
2005     findStmtsInUnspecifiedLvalueContext(S, [&Results](const Expr *E) {
2006       const auto *ASE = dyn_cast<ArraySubscriptExpr>(E);
2007       if (!ASE)
2008         return;
2009       const auto *DRE =
2010           dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts());
2011       if (!DRE || !(hasPointerType(*DRE) || hasArrayType(*DRE)) ||
2012           !isSupportedVariable(*DRE))
2013         return;
2014       MatchResult R;
2015       R.addNode(ULCArraySubscriptTag, DynTypedNode::create(*ASE));
2016       Results.emplace_back(std::move(R));
2017     });
2018     return SizeBefore != Results.size();
2019   }
2020 
2021   virtual std::optional<FixItList>
2022   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const2023   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
2024 
getClaimedVarUseSites() const2025   virtual DeclUseList getClaimedVarUseSites() const override {
2026     if (const auto *DRE =
2027             dyn_cast<DeclRefExpr>(Node->getBase()->IgnoreImpCasts())) {
2028       return {DRE};
2029     }
2030     return {};
2031   }
2032 };
2033 
2034 // Fixable gadget to handle stand alone pointers of the form `UPC(DRE)` in the
2035 // unspecified pointer context (findStmtsInUnspecifiedPointerContext). The
2036 // gadget emits fixit of the form `UPC(DRE.data())`.
2037 class UPCStandalonePointerGadget : public FixableGadget {
2038 private:
2039   static constexpr const char *const DeclRefExprTag = "StandalonePointer";
2040   const DeclRefExpr *Node;
2041 
2042 public:
UPCStandalonePointerGadget(const MatchResult & Result)2043   UPCStandalonePointerGadget(const MatchResult &Result)
2044       : FixableGadget(Kind::UPCStandalonePointer),
2045         Node(Result.getNodeAs<DeclRefExpr>(DeclRefExprTag)) {
2046     assert(Node != nullptr && "Expecting a non-null matching result");
2047   }
2048 
classof(const Gadget * G)2049   static bool classof(const Gadget *G) {
2050     return G->getKind() == Kind::UPCStandalonePointer;
2051   }
2052 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2053   static bool matches(const Stmt *S,
2054                       llvm::SmallVectorImpl<MatchResult> &Results) {
2055     size_t SizeBefore = Results.size();
2056     findStmtsInUnspecifiedPointerContext(S, [&Results](const Stmt *S) {
2057       auto *E = dyn_cast<Expr>(S);
2058       if (!E)
2059         return;
2060       const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenImpCasts());
2061       if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)) ||
2062           !isSupportedVariable(*DRE))
2063         return;
2064       MatchResult R;
2065       R.addNode(DeclRefExprTag, DynTypedNode::create(*DRE));
2066       Results.emplace_back(std::move(R));
2067     });
2068     return SizeBefore != Results.size();
2069   }
2070 
2071   virtual std::optional<FixItList>
2072   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const2073   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
2074 
getClaimedVarUseSites() const2075   virtual DeclUseList getClaimedVarUseSites() const override { return {Node}; }
2076 };
2077 
2078 class PointerDereferenceGadget : public FixableGadget {
2079   static constexpr const char *const BaseDeclRefExprTag = "BaseDRE";
2080   static constexpr const char *const OperatorTag = "op";
2081 
2082   const DeclRefExpr *BaseDeclRefExpr = nullptr;
2083   const UnaryOperator *Op = nullptr;
2084 
2085 public:
PointerDereferenceGadget(const MatchResult & Result)2086   PointerDereferenceGadget(const MatchResult &Result)
2087       : FixableGadget(Kind::PointerDereference),
2088         BaseDeclRefExpr(Result.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)),
2089         Op(Result.getNodeAs<UnaryOperator>(OperatorTag)) {}
2090 
classof(const Gadget * G)2091   static bool classof(const Gadget *G) {
2092     return G->getKind() == Kind::PointerDereference;
2093   }
2094 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2095   static bool matches(const Stmt *S,
2096                       llvm::SmallVectorImpl<MatchResult> &Results) {
2097     size_t SizeBefore = Results.size();
2098     findStmtsInUnspecifiedLvalueContext(S, [&Results](const Stmt *S) {
2099       const auto *UO = dyn_cast<UnaryOperator>(S);
2100       if (!UO || UO->getOpcode() != UO_Deref)
2101         return;
2102       const auto *CE = dyn_cast<Expr>(UO->getSubExpr());
2103       if (!CE)
2104         return;
2105       CE = CE->IgnoreParenImpCasts();
2106       const auto *DRE = dyn_cast<DeclRefExpr>(CE);
2107       if (!DRE || !isSupportedVariable(*DRE))
2108         return;
2109       MatchResult R;
2110       R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE));
2111       R.addNode(OperatorTag, DynTypedNode::create(*UO));
2112       Results.emplace_back(std::move(R));
2113     });
2114     return SizeBefore != Results.size();
2115   }
2116 
getClaimedVarUseSites() const2117   DeclUseList getClaimedVarUseSites() const override {
2118     return {BaseDeclRefExpr};
2119   }
2120 
2121   virtual std::optional<FixItList>
2122   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const2123   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
2124 };
2125 
2126 // Represents expressions of the form `&DRE[any]` in the Unspecified Pointer
2127 // Context (see `findStmtsInUnspecifiedPointerContext`).
2128 // Note here `[]` is the built-in subscript operator.
2129 class UPCAddressofArraySubscriptGadget : public FixableGadget {
2130 private:
2131   static constexpr const char *const UPCAddressofArraySubscriptTag =
2132       "AddressofArraySubscriptUnderUPC";
2133   const UnaryOperator *Node; // the `&DRE[any]` node
2134 
2135 public:
UPCAddressofArraySubscriptGadget(const MatchResult & Result)2136   UPCAddressofArraySubscriptGadget(const MatchResult &Result)
2137       : FixableGadget(Kind::ULCArraySubscript),
2138         Node(Result.getNodeAs<UnaryOperator>(UPCAddressofArraySubscriptTag)) {
2139     assert(Node != nullptr && "Expecting a non-null matching result");
2140   }
2141 
classof(const Gadget * G)2142   static bool classof(const Gadget *G) {
2143     return G->getKind() == Kind::UPCAddressofArraySubscript;
2144   }
2145 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2146   static bool matches(const Stmt *S,
2147                       llvm::SmallVectorImpl<MatchResult> &Results) {
2148     size_t SizeBefore = Results.size();
2149     findStmtsInUnspecifiedPointerContext(S, [&Results](const Stmt *S) {
2150       auto *E = dyn_cast<Expr>(S);
2151       if (!E)
2152         return;
2153       const auto *UO = dyn_cast<UnaryOperator>(E->IgnoreImpCasts());
2154       if (!UO || UO->getOpcode() != UO_AddrOf)
2155         return;
2156       const auto *ASE = dyn_cast<ArraySubscriptExpr>(UO->getSubExpr());
2157       if (!ASE)
2158         return;
2159       const auto *DRE =
2160           dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts());
2161       if (!DRE || !isSupportedVariable(*DRE))
2162         return;
2163       MatchResult R;
2164       R.addNode(UPCAddressofArraySubscriptTag, DynTypedNode::create(*UO));
2165       Results.emplace_back(std::move(R));
2166     });
2167     return SizeBefore != Results.size();
2168   }
2169 
2170   virtual std::optional<FixItList>
2171   getFixits(const FixitStrategy &) const override;
getSourceLoc() const2172   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
2173 
getClaimedVarUseSites() const2174   virtual DeclUseList getClaimedVarUseSites() const override {
2175     const auto *ArraySubst = cast<ArraySubscriptExpr>(Node->getSubExpr());
2176     const auto *DRE =
2177         cast<DeclRefExpr>(ArraySubst->getBase()->IgnoreParenImpCasts());
2178     return {DRE};
2179   }
2180 };
2181 } // namespace
2182 
2183 namespace {
2184 // An auxiliary tracking facility for the fixit analysis. It helps connect
2185 // declarations to its uses and make sure we've covered all uses with our
2186 // analysis before we try to fix the declaration.
2187 class DeclUseTracker {
2188   using UseSetTy = llvm::SmallSet<const DeclRefExpr *, 16>;
2189   using DefMapTy = llvm::DenseMap<const VarDecl *, const DeclStmt *>;
2190 
2191   // Allocate on the heap for easier move.
2192   std::unique_ptr<UseSetTy> Uses{std::make_unique<UseSetTy>()};
2193   DefMapTy Defs{};
2194 
2195 public:
2196   DeclUseTracker() = default;
2197   DeclUseTracker(const DeclUseTracker &) = delete; // Let's avoid copies.
2198   DeclUseTracker &operator=(const DeclUseTracker &) = delete;
2199   DeclUseTracker(DeclUseTracker &&) = default;
2200   DeclUseTracker &operator=(DeclUseTracker &&) = default;
2201 
2202   // Start tracking a freshly discovered DRE.
discoverUse(const DeclRefExpr * DRE)2203   void discoverUse(const DeclRefExpr *DRE) { Uses->insert(DRE); }
2204 
2205   // Stop tracking the DRE as it's been fully figured out.
claimUse(const DeclRefExpr * DRE)2206   void claimUse(const DeclRefExpr *DRE) {
2207     assert(Uses->count(DRE) &&
2208            "DRE not found or claimed by multiple matchers!");
2209     Uses->erase(DRE);
2210   }
2211 
2212   // A variable is unclaimed if at least one use is unclaimed.
hasUnclaimedUses(const VarDecl * VD) const2213   bool hasUnclaimedUses(const VarDecl *VD) const {
2214     // FIXME: Can this be less linear? Maybe maintain a map from VDs to DREs?
2215     return any_of(*Uses, [VD](const DeclRefExpr *DRE) {
2216       return DRE->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl();
2217     });
2218   }
2219 
getUnclaimedUses(const VarDecl * VD) const2220   UseSetTy getUnclaimedUses(const VarDecl *VD) const {
2221     UseSetTy ReturnSet;
2222     for (auto use : *Uses) {
2223       if (use->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl()) {
2224         ReturnSet.insert(use);
2225       }
2226     }
2227     return ReturnSet;
2228   }
2229 
discoverDecl(const DeclStmt * DS)2230   void discoverDecl(const DeclStmt *DS) {
2231     for (const Decl *D : DS->decls()) {
2232       if (const auto *VD = dyn_cast<VarDecl>(D)) {
2233         // FIXME: Assertion temporarily disabled due to a bug in
2234         // ASTMatcher internal behavior in presence of GNU
2235         // statement-expressions. We need to properly investigate this
2236         // because it can screw up our algorithm in other ways.
2237         // assert(Defs.count(VD) == 0 && "Definition already discovered!");
2238         Defs[VD] = DS;
2239       }
2240     }
2241   }
2242 
lookupDecl(const VarDecl * VD) const2243   const DeclStmt *lookupDecl(const VarDecl *VD) const {
2244     return Defs.lookup(VD);
2245   }
2246 };
2247 } // namespace
2248 
2249 // Representing a pointer type expression of the form `++Ptr` in an Unspecified
2250 // Pointer Context (UPC):
2251 class UPCPreIncrementGadget : public FixableGadget {
2252 private:
2253   static constexpr const char *const UPCPreIncrementTag =
2254       "PointerPreIncrementUnderUPC";
2255   const UnaryOperator *Node; // the `++Ptr` node
2256 
2257 public:
UPCPreIncrementGadget(const MatchResult & Result)2258   UPCPreIncrementGadget(const MatchResult &Result)
2259       : FixableGadget(Kind::UPCPreIncrement),
2260         Node(Result.getNodeAs<UnaryOperator>(UPCPreIncrementTag)) {
2261     assert(Node != nullptr && "Expecting a non-null matching result");
2262   }
2263 
classof(const Gadget * G)2264   static bool classof(const Gadget *G) {
2265     return G->getKind() == Kind::UPCPreIncrement;
2266   }
2267 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2268   static bool matches(const Stmt *S,
2269                       llvm::SmallVectorImpl<MatchResult> &Results) {
2270     // Note here we match `++Ptr` for any expression `Ptr` of pointer type.
2271     // Although currently we can only provide fix-its when `Ptr` is a DRE, we
2272     // can have the matcher be general, so long as `getClaimedVarUseSites` does
2273     // things right.
2274     size_t SizeBefore = Results.size();
2275     findStmtsInUnspecifiedPointerContext(S, [&Results](const Stmt *S) {
2276       auto *E = dyn_cast<Expr>(S);
2277       if (!E)
2278         return;
2279       const auto *UO = dyn_cast<UnaryOperator>(E->IgnoreImpCasts());
2280       if (!UO || UO->getOpcode() != UO_PreInc)
2281         return;
2282       const auto *DRE = dyn_cast<DeclRefExpr>(UO->getSubExpr());
2283       if (!DRE || !isSupportedVariable(*DRE))
2284         return;
2285       MatchResult R;
2286       R.addNode(UPCPreIncrementTag, DynTypedNode::create(*UO));
2287       Results.emplace_back(std::move(R));
2288     });
2289     return SizeBefore != Results.size();
2290   }
2291 
2292   virtual std::optional<FixItList>
2293   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const2294   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
2295 
getClaimedVarUseSites() const2296   virtual DeclUseList getClaimedVarUseSites() const override {
2297     return {dyn_cast<DeclRefExpr>(Node->getSubExpr())};
2298   }
2299 };
2300 
2301 // Representing a pointer type expression of the form `Ptr += n` in an
2302 // Unspecified Untyped Context (UUC):
2303 class UUCAddAssignGadget : public FixableGadget {
2304 private:
2305   static constexpr const char *const UUCAddAssignTag =
2306       "PointerAddAssignUnderUUC";
2307   static constexpr const char *const OffsetTag = "Offset";
2308 
2309   const BinaryOperator *Node; // the `Ptr += n` node
2310   const Expr *Offset = nullptr;
2311 
2312 public:
UUCAddAssignGadget(const MatchResult & Result)2313   UUCAddAssignGadget(const MatchResult &Result)
2314       : FixableGadget(Kind::UUCAddAssign),
2315         Node(Result.getNodeAs<BinaryOperator>(UUCAddAssignTag)),
2316         Offset(Result.getNodeAs<Expr>(OffsetTag)) {
2317     assert(Node != nullptr && "Expecting a non-null matching result");
2318   }
2319 
classof(const Gadget * G)2320   static bool classof(const Gadget *G) {
2321     return G->getKind() == Kind::UUCAddAssign;
2322   }
2323 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2324   static bool matches(const Stmt *S,
2325                       llvm::SmallVectorImpl<MatchResult> &Results) {
2326     size_t SizeBefore = Results.size();
2327     findStmtsInUnspecifiedUntypedContext(S, [&Results](const Stmt *S) {
2328       const auto *E = dyn_cast<Expr>(S);
2329       if (!E)
2330         return;
2331       const auto *BO = dyn_cast<BinaryOperator>(E->IgnoreImpCasts());
2332       if (!BO || BO->getOpcode() != BO_AddAssign)
2333         return;
2334       const auto *DRE = dyn_cast<DeclRefExpr>(BO->getLHS());
2335       if (!DRE || !hasPointerType(*DRE) || !isSupportedVariable(*DRE))
2336         return;
2337       MatchResult R;
2338       R.addNode(UUCAddAssignTag, DynTypedNode::create(*BO));
2339       R.addNode(OffsetTag, DynTypedNode::create(*BO->getRHS()));
2340       Results.emplace_back(std::move(R));
2341     });
2342     return SizeBefore != Results.size();
2343   }
2344 
2345   virtual std::optional<FixItList>
2346   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const2347   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
2348 
getClaimedVarUseSites() const2349   virtual DeclUseList getClaimedVarUseSites() const override {
2350     return {dyn_cast<DeclRefExpr>(Node->getLHS())};
2351   }
2352 };
2353 
2354 // Representing a fixable expression of the form `*(ptr + 123)` or `*(123 +
2355 // ptr)`:
2356 class DerefSimplePtrArithFixableGadget : public FixableGadget {
2357   static constexpr const char *const BaseDeclRefExprTag = "BaseDRE";
2358   static constexpr const char *const DerefOpTag = "DerefOp";
2359   static constexpr const char *const AddOpTag = "AddOp";
2360   static constexpr const char *const OffsetTag = "Offset";
2361 
2362   const DeclRefExpr *BaseDeclRefExpr = nullptr;
2363   const UnaryOperator *DerefOp = nullptr;
2364   const BinaryOperator *AddOp = nullptr;
2365   const IntegerLiteral *Offset = nullptr;
2366 
2367 public:
DerefSimplePtrArithFixableGadget(const MatchResult & Result)2368   DerefSimplePtrArithFixableGadget(const MatchResult &Result)
2369       : FixableGadget(Kind::DerefSimplePtrArithFixable),
2370         BaseDeclRefExpr(Result.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)),
2371         DerefOp(Result.getNodeAs<UnaryOperator>(DerefOpTag)),
2372         AddOp(Result.getNodeAs<BinaryOperator>(AddOpTag)),
2373         Offset(Result.getNodeAs<IntegerLiteral>(OffsetTag)) {}
2374 
matches(const Stmt * S,llvm::SmallVectorImpl<MatchResult> & Results)2375   static bool matches(const Stmt *S,
2376                       llvm::SmallVectorImpl<MatchResult> &Results) {
2377     auto IsPtr = [](const Expr *E, MatchResult &R) {
2378       if (!E || !hasPointerType(*E))
2379         return false;
2380       const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreImpCasts());
2381       if (!DRE || !isSupportedVariable(*DRE))
2382         return false;
2383       R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE));
2384       return true;
2385     };
2386     const auto IsPlusOverPtrAndInteger = [&IsPtr](const Expr *E,
2387                                                   MatchResult &R) {
2388       const auto *BO = dyn_cast<BinaryOperator>(E);
2389       if (!BO || BO->getOpcode() != BO_Add)
2390         return false;
2391 
2392       const auto *LHS = BO->getLHS();
2393       const auto *RHS = BO->getRHS();
2394       if (isa<IntegerLiteral>(RHS) && IsPtr(LHS, R)) {
2395         R.addNode(OffsetTag, DynTypedNode::create(*RHS));
2396         R.addNode(AddOpTag, DynTypedNode::create(*BO));
2397         return true;
2398       }
2399       if (isa<IntegerLiteral>(LHS) && IsPtr(RHS, R)) {
2400         R.addNode(OffsetTag, DynTypedNode::create(*LHS));
2401         R.addNode(AddOpTag, DynTypedNode::create(*BO));
2402         return true;
2403       }
2404       return false;
2405     };
2406     size_t SizeBefore = Results.size();
2407     const auto InnerMatcher = [&IsPlusOverPtrAndInteger,
2408                                &Results](const Expr *E) {
2409       const auto *UO = dyn_cast<UnaryOperator>(E);
2410       if (!UO || UO->getOpcode() != UO_Deref)
2411         return;
2412 
2413       const auto *Operand = UO->getSubExpr()->IgnoreParens();
2414       MatchResult R;
2415       if (IsPlusOverPtrAndInteger(Operand, R)) {
2416         R.addNode(DerefOpTag, DynTypedNode::create(*UO));
2417         Results.emplace_back(std::move(R));
2418       }
2419     };
2420     findStmtsInUnspecifiedLvalueContext(S, InnerMatcher);
2421     return SizeBefore != Results.size();
2422   }
2423 
2424   virtual std::optional<FixItList>
2425   getFixits(const FixitStrategy &s) const final;
getSourceLoc() const2426   SourceLocation getSourceLoc() const override {
2427     return DerefOp->getBeginLoc();
2428   }
2429 
getClaimedVarUseSites() const2430   virtual DeclUseList getClaimedVarUseSites() const final {
2431     return {BaseDeclRefExpr};
2432   }
2433 };
2434 
2435 class WarningGadgetMatcher : public FastMatcher {
2436 
2437 public:
WarningGadgetMatcher(WarningGadgetList & WarningGadgets)2438   WarningGadgetMatcher(WarningGadgetList &WarningGadgets)
2439       : WarningGadgets(WarningGadgets) {}
2440 
matches(const DynTypedNode & DynNode,ASTContext & Ctx,const UnsafeBufferUsageHandler & Handler)2441   bool matches(const DynTypedNode &DynNode, ASTContext &Ctx,
2442                const UnsafeBufferUsageHandler &Handler) override {
2443     const Stmt *S = DynNode.get<Stmt>();
2444     if (!S)
2445       return false;
2446 
2447     MatchResult Result;
2448 #define WARNING_GADGET(name)                                                   \
2449   if (name##Gadget::matches(S, Ctx, Result) &&                                 \
2450       notInSafeBufferOptOut(*S, &Handler)) {                                   \
2451     WarningGadgets.push_back(std::make_unique<name##Gadget>(Result));          \
2452     return true;                                                               \
2453   }
2454 #define WARNING_OPTIONAL_GADGET(name)                                          \
2455   if (name##Gadget::matches(S, Ctx, &Handler, Result) &&                       \
2456       notInSafeBufferOptOut(*S, &Handler)) {                                   \
2457     WarningGadgets.push_back(std::make_unique<name##Gadget>(Result));          \
2458     return true;                                                               \
2459   }
2460 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
2461     return false;
2462   }
2463 
2464 private:
2465   WarningGadgetList &WarningGadgets;
2466 };
2467 
2468 class FixableGadgetMatcher : public FastMatcher {
2469 
2470 public:
FixableGadgetMatcher(FixableGadgetList & FixableGadgets,DeclUseTracker & Tracker)2471   FixableGadgetMatcher(FixableGadgetList &FixableGadgets,
2472                        DeclUseTracker &Tracker)
2473       : FixableGadgets(FixableGadgets), Tracker(Tracker) {}
2474 
matches(const DynTypedNode & DynNode,ASTContext & Ctx,const UnsafeBufferUsageHandler & Handler)2475   bool matches(const DynTypedNode &DynNode, ASTContext &Ctx,
2476                const UnsafeBufferUsageHandler &Handler) override {
2477     bool matchFound = false;
2478     const Stmt *S = DynNode.get<Stmt>();
2479     if (!S) {
2480       return matchFound;
2481     }
2482 
2483     llvm::SmallVector<MatchResult> Results;
2484 #define FIXABLE_GADGET(name)                                                   \
2485   if (name##Gadget::matches(S, Results)) {                                     \
2486     for (const auto &R : Results) {                                            \
2487       FixableGadgets.push_back(std::make_unique<name##Gadget>(R));             \
2488       matchFound = true;                                                       \
2489     }                                                                          \
2490     Results = {};                                                              \
2491   }
2492 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
2493     // In parallel, match all DeclRefExprs so that to find out
2494     // whether there are any uncovered by gadgets.
2495     if (auto *DRE = findDeclRefExpr(S); DRE) {
2496       Tracker.discoverUse(DRE);
2497       matchFound = true;
2498     }
2499     // Also match DeclStmts because we'll need them when fixing
2500     // their underlying VarDecls that otherwise don't have
2501     // any backreferences to DeclStmts.
2502     if (auto *DS = findDeclStmt(S); DS) {
2503       Tracker.discoverDecl(DS);
2504       matchFound = true;
2505     }
2506     return matchFound;
2507   }
2508 
2509 private:
findDeclRefExpr(const Stmt * S)2510   const DeclRefExpr *findDeclRefExpr(const Stmt *S) {
2511     const auto *DRE = dyn_cast<DeclRefExpr>(S);
2512     if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)))
2513       return nullptr;
2514     const Decl *D = DRE->getDecl();
2515     if (!D || (!isa<VarDecl>(D) && !isa<BindingDecl>(D)))
2516       return nullptr;
2517     return DRE;
2518   }
findDeclStmt(const Stmt * S)2519   const DeclStmt *findDeclStmt(const Stmt *S) {
2520     const auto *DS = dyn_cast<DeclStmt>(S);
2521     if (!DS)
2522       return nullptr;
2523     return DS;
2524   }
2525   FixableGadgetList &FixableGadgets;
2526   DeclUseTracker &Tracker;
2527 };
2528 
2529 // Scan the function and return a list of gadgets found with provided kits.
findGadgets(const Stmt * S,ASTContext & Ctx,const UnsafeBufferUsageHandler & Handler,bool EmitSuggestions,FixableGadgetList & FixableGadgets,WarningGadgetList & WarningGadgets,DeclUseTracker & Tracker)2530 static void findGadgets(const Stmt *S, ASTContext &Ctx,
2531                         const UnsafeBufferUsageHandler &Handler,
2532                         bool EmitSuggestions, FixableGadgetList &FixableGadgets,
2533                         WarningGadgetList &WarningGadgets,
2534                         DeclUseTracker &Tracker) {
2535   WarningGadgetMatcher WMatcher{WarningGadgets};
2536   forEachDescendantEvaluatedStmt(S, Ctx, Handler, WMatcher);
2537   if (EmitSuggestions) {
2538     FixableGadgetMatcher FMatcher{FixableGadgets, Tracker};
2539     forEachDescendantStmt(S, Ctx, Handler, FMatcher);
2540   }
2541 }
2542 
// Strict ordering of AST nodes by the raw encoding of their begin locations.
template <typename NodeTy> struct CompareNode {
  bool operator()(const NodeTy *N1, const NodeTy *N2) const {
    const auto FirstRaw = N1->getBeginLoc().getRawEncoding();
    const auto SecondRaw = N2->getBeginLoc().getRawEncoding();
    return FirstRaw < SecondRaw;
  }
};
2550 
findUnsafePointers(const FunctionDecl * FD)2551 std::set<const Expr *> clang::findUnsafePointers(const FunctionDecl *FD) {
2552   class MockReporter : public UnsafeBufferUsageHandler {
2553   public:
2554     MockReporter() {}
2555     void handleUnsafeOperation(const Stmt *, bool, ASTContext &) override {}
2556     void handleUnsafeLibcCall(const CallExpr *, unsigned, ASTContext &,
2557                               const Expr *UnsafeArg = nullptr) override {}
2558     void handleUnsafeOperationInContainer(const Stmt *, bool,
2559                                           ASTContext &) override {}
2560     void handleUnsafeVariableGroup(const VarDecl *,
2561                                    const VariableGroupsManager &, FixItList &&,
2562                                    const Decl *,
2563                                    const FixitStrategy &) override {}
2564     bool isSafeBufferOptOut(const SourceLocation &) const override {
2565       return false;
2566     }
2567     bool ignoreUnsafeBufferInContainer(const SourceLocation &) const override {
2568       return false;
2569     }
2570     bool ignoreUnsafeBufferInLibcCall(const SourceLocation &) const override {
2571       return false;
2572     }
2573     std::string getUnsafeBufferUsageAttributeTextAt(
2574         SourceLocation, StringRef WSSuffix = "") const override {
2575       return "";
2576     }
2577   };
2578 
2579   FixableGadgetList FixableGadgets;
2580   WarningGadgetList WarningGadgets;
2581   DeclUseTracker Tracker;
2582   MockReporter IgnoreHandler;
2583 
2584   findGadgets(FD->getBody(), FD->getASTContext(), IgnoreHandler, false,
2585               FixableGadgets, WarningGadgets, Tracker);
2586 
2587   std::set<const Expr *> Result;
2588   for (auto &G : WarningGadgets) {
2589     for (const Expr *E : G->getUnsafePtrs()) {
2590       Result.insert(E);
2591     }
2592   }
2593 
2594   return Result;
2595 }
2596 
2597 struct WarningGadgetSets {
2598   std::map<const VarDecl *, std::set<const WarningGadget *>,
2599            // To keep keys sorted by their locations in the map so that the
2600            // order is deterministic:
2601            CompareNode<VarDecl>>
2602       byVar;
2603   // These Gadgets are not related to pointer variables (e. g. temporaries).
2604   llvm::SmallVector<const WarningGadget *, 16> noVar;
2605 };
2606 
2607 static WarningGadgetSets
groupWarningGadgetsByVar(const WarningGadgetList & AllUnsafeOperations)2608 groupWarningGadgetsByVar(const WarningGadgetList &AllUnsafeOperations) {
2609   WarningGadgetSets result;
2610   // If some gadgets cover more than one
2611   // variable, they'll appear more than once in the map.
2612   for (auto &G : AllUnsafeOperations) {
2613     DeclUseList ClaimedVarUseSites = G->getClaimedVarUseSites();
2614 
2615     bool AssociatedWithVarDecl = false;
2616     for (const DeclRefExpr *DRE : ClaimedVarUseSites) {
2617       if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
2618         result.byVar[VD].insert(G.get());
2619         AssociatedWithVarDecl = true;
2620       }
2621     }
2622 
2623     if (!AssociatedWithVarDecl) {
2624       result.noVar.push_back(G.get());
2625       continue;
2626     }
2627   }
2628   return result;
2629 }
2630 
2631 struct FixableGadgetSets {
2632   std::map<const VarDecl *, std::set<const FixableGadget *>,
2633            // To keep keys sorted by their locations in the map so that the
2634            // order is deterministic:
2635            CompareNode<VarDecl>>
2636       byVar;
2637 };
2638 
2639 static FixableGadgetSets
groupFixablesByVar(FixableGadgetList && AllFixableOperations)2640 groupFixablesByVar(FixableGadgetList &&AllFixableOperations) {
2641   FixableGadgetSets FixablesForUnsafeVars;
2642   for (auto &F : AllFixableOperations) {
2643     DeclUseList DREs = F->getClaimedVarUseSites();
2644 
2645     for (const DeclRefExpr *DRE : DREs) {
2646       if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
2647         FixablesForUnsafeVars.byVar[VD].insert(F.get());
2648       }
2649     }
2650   }
2651   return FixablesForUnsafeVars;
2652 }
2653 
anyConflict(const SmallVectorImpl<FixItHint> & FixIts,const SourceManager & SM)2654 bool clang::internal::anyConflict(const SmallVectorImpl<FixItHint> &FixIts,
2655                                   const SourceManager &SM) {
2656   // A simple interval overlap detection algorithm.  Sorts all ranges by their
2657   // begin location then finds the first overlap in one pass.
2658   std::vector<const FixItHint *> All; // a copy of `FixIts`
2659 
2660   for (const FixItHint &H : FixIts)
2661     All.push_back(&H);
2662   std::sort(All.begin(), All.end(),
2663             [&SM](const FixItHint *H1, const FixItHint *H2) {
2664               return SM.isBeforeInTranslationUnit(H1->RemoveRange.getBegin(),
2665                                                   H2->RemoveRange.getBegin());
2666             });
2667 
2668   const FixItHint *CurrHint = nullptr;
2669 
2670   for (const FixItHint *Hint : All) {
2671     if (!CurrHint ||
2672         SM.isBeforeInTranslationUnit(CurrHint->RemoveRange.getEnd(),
2673                                      Hint->RemoveRange.getBegin())) {
2674       // Either to initialize `CurrHint` or `CurrHint` does not
2675       // overlap with `Hint`:
2676       CurrHint = Hint;
2677     } else
2678       // In case `Hint` overlaps the `CurrHint`, we found at least one
2679       // conflict:
2680       return true;
2681   }
2682   return false;
2683 }
2684 
2685 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2686 PtrToPtrAssignmentGadget::getFixits(const FixitStrategy &S) const {
2687   const auto *LeftVD = cast<VarDecl>(PtrLHS->getDecl());
2688   const auto *RightVD = cast<VarDecl>(PtrRHS->getDecl());
2689   switch (S.lookup(LeftVD)) {
2690   case FixitStrategy::Kind::Span:
2691     if (S.lookup(RightVD) == FixitStrategy::Kind::Span)
2692       return FixItList{};
2693     return std::nullopt;
2694   case FixitStrategy::Kind::Wontfix:
2695     return std::nullopt;
2696   case FixitStrategy::Kind::Iterator:
2697   case FixitStrategy::Kind::Array:
2698     return std::nullopt;
2699   case FixitStrategy::Kind::Vector:
2700     llvm_unreachable("unsupported strategies for FixableGadgets");
2701   }
2702   return std::nullopt;
2703 }
2704 
/// \returns a fix-it that appends a `.data()` call after \p DRE.
2706 static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx,
2707                                                        const DeclRefExpr *DRE);
2708 
2709 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2710 CArrayToPtrAssignmentGadget::getFixits(const FixitStrategy &S) const {
2711   const auto *LeftVD = cast<VarDecl>(PtrLHS->getDecl());
2712   const auto *RightVD = cast<VarDecl>(PtrRHS->getDecl());
2713   // TLDR: Implementing fixits for non-Wontfix strategy on both LHS and RHS is
2714   // non-trivial.
2715   //
2716   // CArrayToPtrAssignmentGadget doesn't have strategy implications because
2717   // constant size array propagates its bounds. Because of that LHS and RHS are
2718   // addressed by two different fixits.
2719   //
2720   // At the same time FixitStrategy S doesn't reflect what group a fixit belongs
2721   // to and can't be generally relied on in multi-variable Fixables!
2722   //
2723   // E. g. If an instance of this gadget is fixing variable on LHS then the
2724   // variable on RHS is fixed by a different fixit and its strategy for LHS
2725   // fixit is as if Wontfix.
2726   //
2727   // The only exception is Wontfix strategy for a given variable as that is
2728   // valid for any fixit produced for the given input source code.
2729   if (S.lookup(LeftVD) == FixitStrategy::Kind::Span) {
2730     if (S.lookup(RightVD) == FixitStrategy::Kind::Wontfix) {
2731       return FixItList{};
2732     }
2733   } else if (S.lookup(LeftVD) == FixitStrategy::Kind::Wontfix) {
2734     if (S.lookup(RightVD) == FixitStrategy::Kind::Array) {
2735       return createDataFixit(RightVD->getASTContext(), PtrRHS);
2736     }
2737   }
2738   return std::nullopt;
2739 }
2740 
2741 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2742 PointerInitGadget::getFixits(const FixitStrategy &S) const {
2743   const auto *LeftVD = PtrInitLHS;
2744   const auto *RightVD = cast<VarDecl>(PtrInitRHS->getDecl());
2745   switch (S.lookup(LeftVD)) {
2746   case FixitStrategy::Kind::Span:
2747     if (S.lookup(RightVD) == FixitStrategy::Kind::Span)
2748       return FixItList{};
2749     return std::nullopt;
2750   case FixitStrategy::Kind::Wontfix:
2751     return std::nullopt;
2752   case FixitStrategy::Kind::Iterator:
2753   case FixitStrategy::Kind::Array:
2754     return std::nullopt;
2755   case FixitStrategy::Kind::Vector:
2756     llvm_unreachable("unsupported strategies for FixableGadgets");
2757   }
2758   return std::nullopt;
2759 }
2760 
isNonNegativeIntegerExpr(const Expr * Expr,const VarDecl * VD,const ASTContext & Ctx)2761 static bool isNonNegativeIntegerExpr(const Expr *Expr, const VarDecl *VD,
2762                                      const ASTContext &Ctx) {
2763   if (auto ConstVal = Expr->getIntegerConstantExpr(Ctx)) {
2764     if (ConstVal->isNegative())
2765       return false;
2766   } else if (!Expr->getType()->isUnsignedIntegerType())
2767     return false;
2768   return true;
2769 }
2770 
2771 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2772 ULCArraySubscriptGadget::getFixits(const FixitStrategy &S) const {
2773   if (const auto *DRE =
2774           dyn_cast<DeclRefExpr>(Node->getBase()->IgnoreImpCasts()))
2775     if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
2776       switch (S.lookup(VD)) {
2777       case FixitStrategy::Kind::Span: {
2778 
2779         // If the index has a negative constant value, we give up as no valid
2780         // fix-it can be generated:
2781         const ASTContext &Ctx = // FIXME: we need ASTContext to be passed in!
2782             VD->getASTContext();
2783         if (!isNonNegativeIntegerExpr(Node->getIdx(), VD, Ctx))
2784           return std::nullopt;
2785         // no-op is a good fix-it, otherwise
2786         return FixItList{};
2787       }
2788       case FixitStrategy::Kind::Array:
2789         return FixItList{};
2790       case FixitStrategy::Kind::Wontfix:
2791       case FixitStrategy::Kind::Iterator:
2792       case FixitStrategy::Kind::Vector:
2793         llvm_unreachable("unsupported strategies for FixableGadgets");
2794       }
2795     }
2796   return std::nullopt;
2797 }
2798 
2799 static std::optional<FixItList> // forward declaration
2800 fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator *Node);
2801 
2802 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2803 UPCAddressofArraySubscriptGadget::getFixits(const FixitStrategy &S) const {
2804   auto DREs = getClaimedVarUseSites();
2805   const auto *VD = cast<VarDecl>(DREs.front()->getDecl());
2806 
2807   switch (S.lookup(VD)) {
2808   case FixitStrategy::Kind::Span:
2809     return fixUPCAddressofArraySubscriptWithSpan(Node);
2810   case FixitStrategy::Kind::Wontfix:
2811   case FixitStrategy::Kind::Iterator:
2812   case FixitStrategy::Kind::Array:
2813     return std::nullopt;
2814   case FixitStrategy::Kind::Vector:
2815     llvm_unreachable("unsupported strategies for FixableGadgets");
2816   }
2817   return std::nullopt; // something went wrong, no fix-it
2818 }
2819 
2820 // FIXME: this function should be customizable through format
getEndOfLine()2821 static StringRef getEndOfLine() {
2822   static const char *const EOL = "\n";
2823   return EOL;
2824 }
2825 
2826 // Returns the text indicating that the user needs to provide input there:
2827 static std::string
getUserFillPlaceHolder(StringRef HintTextToUser="placeholder")2828 getUserFillPlaceHolder(StringRef HintTextToUser = "placeholder") {
2829   std::string s = std::string("<# ");
2830   s += HintTextToUser;
2831   s += " #>";
2832   return s;
2833 }
2834 
2835 // Return the source location of the last character of the AST `Node`.
2836 template <typename NodeTy>
2837 static std::optional<SourceLocation>
getEndCharLoc(const NodeTy * Node,const SourceManager & SM,const LangOptions & LangOpts)2838 getEndCharLoc(const NodeTy *Node, const SourceManager &SM,
2839               const LangOptions &LangOpts) {
2840   if (unsigned TkLen =
2841           Lexer::MeasureTokenLength(Node->getEndLoc(), SM, LangOpts)) {
2842     SourceLocation Loc = Node->getEndLoc().getLocWithOffset(TkLen - 1);
2843 
2844     if (Loc.isValid())
2845       return Loc;
2846   }
2847   return std::nullopt;
2848 }
2849 
2850 // We cannot fix a variable declaration if it has some other specifiers than the
2851 // type specifier.  Because the source ranges of those specifiers could overlap
2852 // with the source range that is being replaced using fix-its.  Especially when
2853 // we often cannot obtain accurate source ranges of cv-qualified type
2854 // specifiers.
2855 // FIXME: also deal with type attributes
hasUnsupportedSpecifiers(const VarDecl * VD,const SourceManager & SM)2856 static bool hasUnsupportedSpecifiers(const VarDecl *VD,
2857                                      const SourceManager &SM) {
2858   // AttrRangeOverlapping: true if at least one attribute of `VD` overlaps the
2859   // source range of `VD`:
2860   bool AttrRangeOverlapping = llvm::any_of(VD->attrs(), [&](Attr *At) -> bool {
2861     return !(SM.isBeforeInTranslationUnit(At->getRange().getEnd(),
2862                                           VD->getBeginLoc())) &&
2863            !(SM.isBeforeInTranslationUnit(VD->getEndLoc(),
2864                                           At->getRange().getBegin()));
2865   });
2866   return VD->isInlineSpecified() || VD->isConstexpr() ||
2867          VD->hasConstantInitialization() || !VD->hasLocalStorage() ||
2868          AttrRangeOverlapping;
2869 }
2870 
2871 // Returns the `SourceRange` of `D`.  The reason why this function exists is
2872 // that `D->getSourceRange()` may return a range where the end location is the
2873 // starting location of the last token.  The end location of the source range
2874 // returned by this function is the last location of the last token.
getSourceRangeToTokenEnd(const Decl * D,const SourceManager & SM,const LangOptions & LangOpts)2875 static SourceRange getSourceRangeToTokenEnd(const Decl *D,
2876                                             const SourceManager &SM,
2877                                             const LangOptions &LangOpts) {
2878   SourceLocation Begin = D->getBeginLoc();
2879   SourceLocation
2880       End = // `D->getEndLoc` should always return the starting location of the
2881       // last token, so we should get the end of the token
2882       Lexer::getLocForEndOfToken(D->getEndLoc(), 0, SM, LangOpts);
2883 
2884   return SourceRange(Begin, End);
2885 }
2886 
2887 // Returns the text of the name (with qualifiers) of a `FunctionDecl`.
getFunNameText(const FunctionDecl * FD,const SourceManager & SM,const LangOptions & LangOpts)2888 static std::optional<StringRef> getFunNameText(const FunctionDecl *FD,
2889                                                const SourceManager &SM,
2890                                                const LangOptions &LangOpts) {
2891   SourceLocation BeginLoc = FD->getQualifier()
2892                                 ? FD->getQualifierLoc().getBeginLoc()
2893                                 : FD->getNameInfo().getBeginLoc();
2894   // Note that `FD->getNameInfo().getEndLoc()` returns the begin location of the
2895   // last token:
2896   SourceLocation EndLoc = Lexer::getLocForEndOfToken(
2897       FD->getNameInfo().getEndLoc(), 0, SM, LangOpts);
2898   SourceRange NameRange{BeginLoc, EndLoc};
2899 
2900   return getRangeText(NameRange, SM, LangOpts);
2901 }
2902 
2903 // Returns the text representing a `std::span` type where the element type is
2904 // represented by `EltTyText`.
2905 //
2906 // Note the optional parameter `Qualifiers`: one needs to pass qualifiers
2907 // explicitly if the element type needs to be qualified.
2908 static std::string
getSpanTypeText(StringRef EltTyText,std::optional<Qualifiers> Quals=std::nullopt)2909 getSpanTypeText(StringRef EltTyText,
2910                 std::optional<Qualifiers> Quals = std::nullopt) {
2911   const char *const SpanOpen = "std::span<";
2912 
2913   if (Quals)
2914     return SpanOpen + EltTyText.str() + ' ' + Quals->getAsString() + '>';
2915   return SpanOpen + EltTyText.str() + '>';
2916 }
2917 
2918 std::optional<FixItList>
getFixits(const FixitStrategy & s) const2919 DerefSimplePtrArithFixableGadget::getFixits(const FixitStrategy &s) const {
2920   const VarDecl *VD = dyn_cast<VarDecl>(BaseDeclRefExpr->getDecl());
2921 
2922   if (VD && s.lookup(VD) == FixitStrategy::Kind::Span) {
2923     ASTContext &Ctx = VD->getASTContext();
2924     // std::span can't represent elements before its begin()
2925     if (auto ConstVal = Offset->getIntegerConstantExpr(Ctx))
2926       if (ConstVal->isNegative())
2927         return std::nullopt;
2928 
2929     // note that the expr may (oddly) has multiple layers of parens
2930     // example:
2931     //   *((..(pointer + 123)..))
2932     // goal:
2933     //   pointer[123]
2934     // Fix-It:
2935     //   remove '*('
2936     //   replace ' + ' with '['
2937     //   replace ')' with ']'
2938 
2939     // example:
2940     //   *((..(123 + pointer)..))
2941     // goal:
2942     //   123[pointer]
2943     // Fix-It:
2944     //   remove '*('
2945     //   replace ' + ' with '['
2946     //   replace ')' with ']'
2947 
2948     const Expr *LHS = AddOp->getLHS(), *RHS = AddOp->getRHS();
2949     const SourceManager &SM = Ctx.getSourceManager();
2950     const LangOptions &LangOpts = Ctx.getLangOpts();
2951     CharSourceRange StarWithTrailWhitespace =
2952         clang::CharSourceRange::getCharRange(DerefOp->getOperatorLoc(),
2953                                              LHS->getBeginLoc());
2954 
2955     std::optional<SourceLocation> LHSLocation = getPastLoc(LHS, SM, LangOpts);
2956     if (!LHSLocation)
2957       return std::nullopt;
2958 
2959     CharSourceRange PlusWithSurroundingWhitespace =
2960         clang::CharSourceRange::getCharRange(*LHSLocation, RHS->getBeginLoc());
2961 
2962     std::optional<SourceLocation> AddOpLocation =
2963         getPastLoc(AddOp, SM, LangOpts);
2964     std::optional<SourceLocation> DerefOpLocation =
2965         getPastLoc(DerefOp, SM, LangOpts);
2966 
2967     if (!AddOpLocation || !DerefOpLocation)
2968       return std::nullopt;
2969 
2970     CharSourceRange ClosingParenWithPrecWhitespace =
2971         clang::CharSourceRange::getCharRange(*AddOpLocation, *DerefOpLocation);
2972 
2973     return FixItList{
2974         {FixItHint::CreateRemoval(StarWithTrailWhitespace),
2975          FixItHint::CreateReplacement(PlusWithSurroundingWhitespace, "["),
2976          FixItHint::CreateReplacement(ClosingParenWithPrecWhitespace, "]")}};
2977   }
2978   return std::nullopt; // something wrong or unsupported, give up
2979 }
2980 
2981 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2982 PointerDereferenceGadget::getFixits(const FixitStrategy &S) const {
2983   const VarDecl *VD = cast<VarDecl>(BaseDeclRefExpr->getDecl());
2984   switch (S.lookup(VD)) {
2985   case FixitStrategy::Kind::Span: {
2986     ASTContext &Ctx = VD->getASTContext();
2987     SourceManager &SM = Ctx.getSourceManager();
2988     // Required changes: *(ptr); => (ptr[0]); and *ptr; => ptr[0]
2989     // Deletes the *operand
2990     CharSourceRange derefRange = clang::CharSourceRange::getCharRange(
2991         Op->getBeginLoc(), Op->getBeginLoc().getLocWithOffset(1));
2992     // Inserts the [0]
2993     if (auto LocPastOperand =
2994             getPastLoc(BaseDeclRefExpr, SM, Ctx.getLangOpts())) {
2995       return FixItList{{FixItHint::CreateRemoval(derefRange),
2996                         FixItHint::CreateInsertion(*LocPastOperand, "[0]")}};
2997     }
2998     break;
2999   }
3000   case FixitStrategy::Kind::Iterator:
3001   case FixitStrategy::Kind::Array:
3002     return std::nullopt;
3003   case FixitStrategy::Kind::Vector:
3004     llvm_unreachable("FixitStrategy not implemented yet!");
3005   case FixitStrategy::Kind::Wontfix:
3006     llvm_unreachable("Invalid strategy!");
3007   }
3008 
3009   return std::nullopt;
3010 }
3011 
createDataFixit(const ASTContext & Ctx,const DeclRefExpr * DRE)3012 static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx,
3013                                                        const DeclRefExpr *DRE) {
3014   const SourceManager &SM = Ctx.getSourceManager();
3015   // Inserts the .data() after the DRE
3016   std::optional<SourceLocation> EndOfOperand =
3017       getPastLoc(DRE, SM, Ctx.getLangOpts());
3018 
3019   if (EndOfOperand)
3020     return FixItList{{FixItHint::CreateInsertion(*EndOfOperand, ".data()")}};
3021 
3022   return std::nullopt;
3023 }
3024 
3025 // Generates fix-its replacing an expression of the form UPC(DRE) with
3026 // `DRE.data()`
3027 std::optional<FixItList>
getFixits(const FixitStrategy & S) const3028 UPCStandalonePointerGadget::getFixits(const FixitStrategy &S) const {
3029   const auto VD = cast<VarDecl>(Node->getDecl());
3030   switch (S.lookup(VD)) {
3031   case FixitStrategy::Kind::Array:
3032   case FixitStrategy::Kind::Span: {
3033     return createDataFixit(VD->getASTContext(), Node);
3034     // FIXME: Points inside a macro expansion.
3035     break;
3036   }
3037   case FixitStrategy::Kind::Wontfix:
3038   case FixitStrategy::Kind::Iterator:
3039     return std::nullopt;
3040   case FixitStrategy::Kind::Vector:
3041     llvm_unreachable("unsupported strategies for FixableGadgets");
3042   }
3043 
3044   return std::nullopt;
3045 }
3046 
3047 // Generates fix-its replacing an expression of the form `&DRE[e]` with
3048 // `&DRE.data()[e]`:
3049 static std::optional<FixItList>
fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator * Node)3050 fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator *Node) {
3051   const auto *ArraySub = cast<ArraySubscriptExpr>(Node->getSubExpr());
3052   const auto *DRE = cast<DeclRefExpr>(ArraySub->getBase()->IgnoreImpCasts());
3053   // FIXME: this `getASTContext` call is costly, we should pass the
3054   // ASTContext in:
3055   const ASTContext &Ctx = DRE->getDecl()->getASTContext();
3056   const Expr *Idx = ArraySub->getIdx();
3057   const SourceManager &SM = Ctx.getSourceManager();
3058   const LangOptions &LangOpts = Ctx.getLangOpts();
3059   std::stringstream SS;
3060   bool IdxIsLitZero = false;
3061 
3062   if (auto ICE = Idx->getIntegerConstantExpr(Ctx))
3063     if ((*ICE).isZero())
3064       IdxIsLitZero = true;
3065   std::optional<StringRef> DreString = getExprText(DRE, SM, LangOpts);
3066   if (!DreString)
3067     return std::nullopt;
3068 
3069   if (IdxIsLitZero) {
3070     // If the index is literal zero, we produce the most concise fix-it:
3071     SS << (*DreString).str() << ".data()";
3072   } else {
3073     std::optional<StringRef> IndexString = getExprText(Idx, SM, LangOpts);
3074     if (!IndexString)
3075       return std::nullopt;
3076 
3077     SS << "&" << (*DreString).str() << ".data()"
3078        << "[" << (*IndexString).str() << "]";
3079   }
3080   return FixItList{
3081       FixItHint::CreateReplacement(Node->getSourceRange(), SS.str())};
3082 }
3083 
3084 std::optional<FixItList>
getFixits(const FixitStrategy & S) const3085 UUCAddAssignGadget::getFixits(const FixitStrategy &S) const {
3086   DeclUseList DREs = getClaimedVarUseSites();
3087 
3088   if (DREs.size() != 1)
3089     return std::nullopt; // In cases of `Ptr += n` where `Ptr` is not a DRE, we
3090                          // give up
3091   if (const VarDecl *VD = dyn_cast<VarDecl>(DREs.front()->getDecl())) {
3092     if (S.lookup(VD) == FixitStrategy::Kind::Span) {
3093       FixItList Fixes;
3094 
3095       const Stmt *AddAssignNode = Node;
3096       StringRef varName = VD->getName();
3097       const ASTContext &Ctx = VD->getASTContext();
3098 
3099       if (!isNonNegativeIntegerExpr(Offset, VD, Ctx))
3100         return std::nullopt;
3101 
3102       // To transform UUC(p += n) to UUC(p = p.subspan(..)):
3103       bool NotParenExpr =
3104           (Offset->IgnoreParens()->getBeginLoc() == Offset->getBeginLoc());
3105       std::string SS = varName.str() + " = " + varName.str() + ".subspan";
3106       if (NotParenExpr)
3107         SS += "(";
3108 
3109       std::optional<SourceLocation> AddAssignLocation = getEndCharLoc(
3110           AddAssignNode, Ctx.getSourceManager(), Ctx.getLangOpts());
3111       if (!AddAssignLocation)
3112         return std::nullopt;
3113 
3114       Fixes.push_back(FixItHint::CreateReplacement(
3115           SourceRange(AddAssignNode->getBeginLoc(), Node->getOperatorLoc()),
3116           SS));
3117       if (NotParenExpr)
3118         Fixes.push_back(FixItHint::CreateInsertion(
3119             Offset->getEndLoc().getLocWithOffset(1), ")"));
3120       return Fixes;
3121     }
3122   }
3123   return std::nullopt; // Not in the cases that we can handle for now, give up.
3124 }
3125 
3126 std::optional<FixItList>
getFixits(const FixitStrategy & S) const3127 UPCPreIncrementGadget::getFixits(const FixitStrategy &S) const {
3128   DeclUseList DREs = getClaimedVarUseSites();
3129 
3130   if (DREs.size() != 1)
3131     return std::nullopt; // In cases of `++Ptr` where `Ptr` is not a DRE, we
3132                          // give up
3133   if (const VarDecl *VD = dyn_cast<VarDecl>(DREs.front()->getDecl())) {
3134     if (S.lookup(VD) == FixitStrategy::Kind::Span) {
3135       FixItList Fixes;
3136       std::stringstream SS;
3137       StringRef varName = VD->getName();
3138       const ASTContext &Ctx = VD->getASTContext();
3139 
3140       // To transform UPC(++p) to UPC((p = p.subspan(1)).data()):
3141       SS << "(" << varName.data() << " = " << varName.data()
3142          << ".subspan(1)).data()";
3143       std::optional<SourceLocation> PreIncLocation =
3144           getEndCharLoc(Node, Ctx.getSourceManager(), Ctx.getLangOpts());
3145       if (!PreIncLocation)
3146         return std::nullopt;
3147 
3148       Fixes.push_back(FixItHint::CreateReplacement(
3149           SourceRange(Node->getBeginLoc(), *PreIncLocation), SS.str()));
3150       return Fixes;
3151     }
3152   }
3153   return std::nullopt; // Not in the cases that we can handle for now, give up.
3154 }
3155 
3156 // For a non-null initializer `Init` of `T *` type, this function returns
3157 // `FixItHint`s producing a list initializer `{Init,  S}` as a part of a fix-it
3158 // to output stream.
3159 // In many cases, this function cannot figure out the actual extent `S`.  It
3160 // then will use a place holder to replace `S` to ask users to fill `S` in.  The
3161 // initializer shall be used to initialize a variable of type `std::span<T>`.
3162 // In some cases (e. g. constant size array) the initializer should remain
3163 // unchanged and the function returns empty list. In case the function can't
3164 // provide the right fixit it will return nullopt.
3165 //
3166 // FIXME: Support multi-level pointers
3167 //
3168 // Parameters:
3169 //   `Init` a pointer to the initializer expression
3170 //   `Ctx` a reference to the ASTContext
3171 static std::optional<FixItList>
FixVarInitializerWithSpan(const Expr * Init,ASTContext & Ctx,const StringRef UserFillPlaceHolder)3172 FixVarInitializerWithSpan(const Expr *Init, ASTContext &Ctx,
3173                           const StringRef UserFillPlaceHolder) {
3174   const SourceManager &SM = Ctx.getSourceManager();
3175   const LangOptions &LangOpts = Ctx.getLangOpts();
3176 
3177   // If `Init` has a constant value that is (or equivalent to) a
3178   // NULL pointer, we use the default constructor to initialize the span
3179   // object, i.e., a `std:span` variable declaration with no initializer.
3180   // So the fix-it is just to remove the initializer.
3181   if (Init->isNullPointerConstant(
3182           Ctx,
3183           // FIXME: Why does this function not ask for `const ASTContext
3184           // &`? It should. Maybe worth an NFC patch later.
3185           Expr::NullPointerConstantValueDependence::
3186               NPC_ValueDependentIsNotNull)) {
3187     std::optional<SourceLocation> InitLocation =
3188         getEndCharLoc(Init, SM, LangOpts);
3189     if (!InitLocation)
3190       return std::nullopt;
3191 
3192     SourceRange SR(Init->getBeginLoc(), *InitLocation);
3193 
3194     return FixItList{FixItHint::CreateRemoval(SR)};
3195   }
3196 
3197   FixItList FixIts{};
3198   std::string ExtentText = UserFillPlaceHolder.data();
3199   StringRef One = "1";
3200 
3201   // Insert `{` before `Init`:
3202   FixIts.push_back(FixItHint::CreateInsertion(Init->getBeginLoc(), "{"));
3203   // Try to get the data extent. Break into different cases:
3204   if (auto CxxNew = dyn_cast<CXXNewExpr>(Init->IgnoreImpCasts())) {
3205     // In cases `Init` is `new T[n]` and there is no explicit cast over
3206     // `Init`, we know that `Init` must evaluates to a pointer to `n` objects
3207     // of `T`. So the extent is `n` unless `n` has side effects.  Similar but
3208     // simpler for the case where `Init` is `new T`.
3209     if (const Expr *Ext = CxxNew->getArraySize().value_or(nullptr)) {
3210       if (!Ext->HasSideEffects(Ctx)) {
3211         std::optional<StringRef> ExtentString = getExprText(Ext, SM, LangOpts);
3212         if (!ExtentString)
3213           return std::nullopt;
3214         ExtentText = *ExtentString;
3215       }
3216     } else if (!CxxNew->isArray())
3217       // Although the initializer is not allocating a buffer, the pointer
3218       // variable could still be used in buffer access operations.
3219       ExtentText = One;
3220   } else if (Ctx.getAsConstantArrayType(Init->IgnoreImpCasts()->getType())) {
3221     // std::span has a single parameter constructor for initialization with
3222     // constant size array. The size is auto-deduced as the constructor is a
3223     // function template. The correct fixit is empty - no changes should happen.
3224     return FixItList{};
3225   } else {
3226     // In cases `Init` is of the form `&Var` after stripping of implicit
3227     // casts, where `&` is the built-in operator, the extent is 1.
3228     if (auto AddrOfExpr = dyn_cast<UnaryOperator>(Init->IgnoreImpCasts()))
3229       if (AddrOfExpr->getOpcode() == UnaryOperatorKind::UO_AddrOf &&
3230           isa_and_present<DeclRefExpr>(AddrOfExpr->getSubExpr()))
3231         ExtentText = One;
3232     // TODO: we can handle more cases, e.g., `&a[0]`, `&a`, `std::addressof`,
3233     // and explicit casting, etc. etc.
3234   }
3235 
3236   SmallString<32> StrBuffer{};
3237   std::optional<SourceLocation> LocPassInit = getPastLoc(Init, SM, LangOpts);
3238 
3239   if (!LocPassInit)
3240     return std::nullopt;
3241 
3242   StrBuffer.append(", ");
3243   StrBuffer.append(ExtentText);
3244   StrBuffer.append("}");
3245   FixIts.push_back(FixItHint::CreateInsertion(*LocPassInit, StrBuffer.str()));
3246   return FixIts;
3247 }
3248 
// DEBUG_NOTE_DECL_FAIL(D, Msg): in builds without NDEBUG, reports through
// `Handler.addDebugNoteForVar` that no fix-it could be produced for the
// declaration `D`, with `Msg` appended to the note text.  The macro expects a
// variable named `Handler` to be in scope at the point of use.  With NDEBUG
// defined it expands to nothing.
#ifdef NDEBUG
#define DEBUG_NOTE_DECL_FAIL(D, Msg)
#else
#define DEBUG_NOTE_DECL_FAIL(D, Msg)                                           \
  Handler.addDebugNoteForVar((D), (D)->getBeginLoc(),                          \
                             "failed to produce fixit for declaration '" +     \
                                 (D)->getNameAsString() + "'" + (Msg))
#endif
3257 
3258 // For the given variable declaration with a pointer-to-T type, returns the text
3259 // `std::span<T>`.  If it is unable to generate the text, returns
3260 // `std::nullopt`.
3261 static std::optional<std::string>
createSpanTypeForVarDecl(const VarDecl * VD,const ASTContext & Ctx)3262 createSpanTypeForVarDecl(const VarDecl *VD, const ASTContext &Ctx) {
3263   assert(VD->getType()->isPointerType());
3264 
3265   std::optional<Qualifiers> PteTyQualifiers = std::nullopt;
3266   std::optional<std::string> PteTyText = getPointeeTypeText(
3267       VD, Ctx.getSourceManager(), Ctx.getLangOpts(), &PteTyQualifiers);
3268 
3269   if (!PteTyText)
3270     return std::nullopt;
3271 
3272   std::string SpanTyText = "std::span<";
3273 
3274   SpanTyText.append(*PteTyText);
3275   // Append qualifiers to span element type if any:
3276   if (PteTyQualifiers) {
3277     SpanTyText.append(" ");
3278     SpanTyText.append(PteTyQualifiers->getAsString());
3279   }
3280   SpanTyText.append(">");
3281   return SpanTyText;
3282 }
3283 
3284 // For a `VarDecl` of the form `T  * var (= Init)?`, this
3285 // function generates fix-its that
3286 //  1) replace `T * var` with `std::span<T> var`; and
3287 //  2) change `Init` accordingly to a span constructor, if it exists.
3288 //
3289 // FIXME: support Multi-level pointers
3290 //
3291 // Parameters:
3292 //   `D` a pointer the variable declaration node
3293 //   `Ctx` a reference to the ASTContext
3294 //   `UserFillPlaceHolder` the user-input placeholder text
3295 // Returns:
3296 //    the non-empty fix-it list, if fix-its are successfuly generated; empty
3297 //    list otherwise.
fixLocalVarDeclWithSpan(const VarDecl * D,ASTContext & Ctx,const StringRef UserFillPlaceHolder,UnsafeBufferUsageHandler & Handler)3298 static FixItList fixLocalVarDeclWithSpan(const VarDecl *D, ASTContext &Ctx,
3299                                          const StringRef UserFillPlaceHolder,
3300                                          UnsafeBufferUsageHandler &Handler) {
3301   if (hasUnsupportedSpecifiers(D, Ctx.getSourceManager()))
3302     return {};
3303 
3304   FixItList FixIts{};
3305   std::optional<std::string> SpanTyText = createSpanTypeForVarDecl(D, Ctx);
3306 
3307   if (!SpanTyText) {
3308     DEBUG_NOTE_DECL_FAIL(D, " : failed to generate 'std::span' type");
3309     return {};
3310   }
3311 
3312   // Will hold the text for `std::span<T> Ident`:
3313   std::stringstream SS;
3314 
3315   SS << *SpanTyText;
3316   // Fix the initializer if it exists:
3317   if (const Expr *Init = D->getInit()) {
3318     std::optional<FixItList> InitFixIts =
3319         FixVarInitializerWithSpan(Init, Ctx, UserFillPlaceHolder);
3320     if (!InitFixIts)
3321       return {};
3322     FixIts.insert(FixIts.end(), std::make_move_iterator(InitFixIts->begin()),
3323                   std::make_move_iterator(InitFixIts->end()));
3324   }
3325   // For declaration of the form `T * ident = init;`, we want to replace
3326   // `T * ` with `std::span<T>`.
3327   // We ignore CV-qualifiers so for `T * const ident;` we also want to replace
3328   // just `T *` with `std::span<T>`.
3329   const SourceLocation EndLocForReplacement = D->getTypeSpecEndLoc();
3330   if (!EndLocForReplacement.isValid()) {
3331     DEBUG_NOTE_DECL_FAIL(D, " : failed to locate the end of the declaration");
3332     return {};
3333   }
3334   // The only exception is that for `T *ident` we'll add a single space between
3335   // "std::span<T>" and "ident".
3336   // FIXME: The condition is false for identifiers expended from macros.
3337   if (EndLocForReplacement.getLocWithOffset(1) == getVarDeclIdentifierLoc(D))
3338     SS << " ";
3339 
3340   FixIts.push_back(FixItHint::CreateReplacement(
3341       SourceRange(D->getBeginLoc(), EndLocForReplacement), SS.str()));
3342   return FixIts;
3343 }
3344 
hasConflictingOverload(const FunctionDecl * FD)3345 static bool hasConflictingOverload(const FunctionDecl *FD) {
3346   return !FD->getDeclContext()->lookup(FD->getDeclName()).isSingleResult();
3347 }
3348 
3349 // For a `FunctionDecl`, whose `ParmVarDecl`s are being changed to have new
3350 // types, this function produces fix-its to make the change self-contained.  Let
3351 // 'F' be the entity defined by the original `FunctionDecl` and "NewF" be the
3352 // entity defined by the `FunctionDecl` after the change to the parameters.
3353 // Fix-its produced by this function are
3354 //   1. Add the `[[clang::unsafe_buffer_usage]]` attribute to each declaration
3355 //   of 'F';
3356 //   2. Create a declaration of "NewF" next to each declaration of `F`;
3357 //   3. Create a definition of "F" (as its' original definition is now belongs
3358 //      to "NewF") next to its original definition.  The body of the creating
3359 //      definition calls to "NewF".
3360 //
3361 // Example:
3362 //
3363 // void f(int *p);  // original declaration
3364 // void f(int *p) { // original definition
3365 //    p[5];
3366 // }
3367 //
3368 // To change the parameter `p` to be of `std::span<int>` type, we
3369 // also add overloads:
3370 //
3371 // [[clang::unsafe_buffer_usage]] void f(int *p); // original decl
3372 // void f(std::span<int> p);                      // added overload decl
3373 // void f(std::span<int> p) {     // original def where param is changed
3374 //    p[5];
3375 // }
3376 // [[clang::unsafe_buffer_usage]] void f(int *p) {  // added def
3377 //   return f(std::span(p, <# size #>));
3378 // }
3379 //
3380 static std::optional<FixItList>
createOverloadsForFixedParams(const FixitStrategy & S,const FunctionDecl * FD,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3381 createOverloadsForFixedParams(const FixitStrategy &S, const FunctionDecl *FD,
3382                               const ASTContext &Ctx,
3383                               UnsafeBufferUsageHandler &Handler) {
3384   // FIXME: need to make this conflict checking better:
3385   if (hasConflictingOverload(FD))
3386     return std::nullopt;
3387 
3388   const SourceManager &SM = Ctx.getSourceManager();
3389   const LangOptions &LangOpts = Ctx.getLangOpts();
3390   const unsigned NumParms = FD->getNumParams();
3391   std::vector<std::string> NewTysTexts(NumParms);
3392   std::vector<bool> ParmsMask(NumParms, false);
3393   bool AtLeastOneParmToFix = false;
3394 
3395   for (unsigned i = 0; i < NumParms; i++) {
3396     const ParmVarDecl *PVD = FD->getParamDecl(i);
3397 
3398     if (S.lookup(PVD) == FixitStrategy::Kind::Wontfix)
3399       continue;
3400     if (S.lookup(PVD) != FixitStrategy::Kind::Span)
3401       // Not supported, not suppose to happen:
3402       return std::nullopt;
3403 
3404     std::optional<Qualifiers> PteTyQuals = std::nullopt;
3405     std::optional<std::string> PteTyText =
3406         getPointeeTypeText(PVD, SM, LangOpts, &PteTyQuals);
3407 
3408     if (!PteTyText)
3409       // something wrong in obtaining the text of the pointee type, give up
3410       return std::nullopt;
3411     // FIXME: whether we should create std::span type depends on the
3412     // FixitStrategy.
3413     NewTysTexts[i] = getSpanTypeText(*PteTyText, PteTyQuals);
3414     ParmsMask[i] = true;
3415     AtLeastOneParmToFix = true;
3416   }
3417   if (!AtLeastOneParmToFix)
3418     // No need to create function overloads:
3419     return {};
3420   // FIXME Respect indentation of the original code.
3421 
3422   // A lambda that creates the text representation of a function declaration
3423   // with the new type signatures:
3424   const auto NewOverloadSignatureCreator =
3425       [&SM, &LangOpts, &NewTysTexts,
3426        &ParmsMask](const FunctionDecl *FD) -> std::optional<std::string> {
3427     std::stringstream SS;
3428 
3429     SS << ";";
3430     SS << getEndOfLine().str();
3431     // Append: ret-type func-name "("
3432     if (auto Prefix = getRangeText(
3433             SourceRange(FD->getBeginLoc(), (*FD->param_begin())->getBeginLoc()),
3434             SM, LangOpts))
3435       SS << Prefix->str();
3436     else
3437       return std::nullopt; // give up
3438     // Append: parameter-type-list
3439     const unsigned NumParms = FD->getNumParams();
3440 
3441     for (unsigned i = 0; i < NumParms; i++) {
3442       const ParmVarDecl *Parm = FD->getParamDecl(i);
3443 
3444       if (Parm->isImplicit())
3445         continue;
3446       if (ParmsMask[i]) {
3447         // This `i`-th parameter will be fixed with `NewTysTexts[i]` being its
3448         // new type:
3449         SS << NewTysTexts[i];
3450         // print parameter name if provided:
3451         if (IdentifierInfo *II = Parm->getIdentifier())
3452           SS << ' ' << II->getName().str();
3453       } else if (auto ParmTypeText =
3454                      getRangeText(getSourceRangeToTokenEnd(Parm, SM, LangOpts),
3455                                   SM, LangOpts)) {
3456         // print the whole `Parm` without modification:
3457         SS << ParmTypeText->str();
3458       } else
3459         return std::nullopt; // something wrong, give up
3460       if (i != NumParms - 1)
3461         SS << ", ";
3462     }
3463     SS << ")";
3464     return SS.str();
3465   };
3466 
3467   // A lambda that creates the text representation of a function definition with
3468   // the original signature:
3469   const auto OldOverloadDefCreator =
3470       [&Handler, &SM, &LangOpts, &NewTysTexts,
3471        &ParmsMask](const FunctionDecl *FD) -> std::optional<std::string> {
3472     std::stringstream SS;
3473 
3474     SS << getEndOfLine().str();
3475     // Append: attr-name ret-type func-name "(" param-list ")" "{"
3476     if (auto FDPrefix = getRangeText(
3477             SourceRange(FD->getBeginLoc(), FD->getBody()->getBeginLoc()), SM,
3478             LangOpts))
3479       SS << Handler.getUnsafeBufferUsageAttributeTextAt(FD->getBeginLoc(), " ")
3480          << FDPrefix->str() << "{";
3481     else
3482       return std::nullopt;
3483     // Append: "return" func-name "("
3484     if (auto FunQualName = getFunNameText(FD, SM, LangOpts))
3485       SS << "return " << FunQualName->str() << "(";
3486     else
3487       return std::nullopt;
3488 
3489     // Append: arg-list
3490     const unsigned NumParms = FD->getNumParams();
3491     for (unsigned i = 0; i < NumParms; i++) {
3492       const ParmVarDecl *Parm = FD->getParamDecl(i);
3493 
3494       if (Parm->isImplicit())
3495         continue;
3496       // FIXME: If a parameter has no name, it is unused in the
3497       // definition. So we could just leave it as it is.
3498       if (!Parm->getIdentifier())
3499         // If a parameter of a function definition has no name:
3500         return std::nullopt;
3501       if (ParmsMask[i])
3502         // This is our spanified paramter!
3503         SS << NewTysTexts[i] << "(" << Parm->getIdentifier()->getName().str()
3504            << ", " << getUserFillPlaceHolder("size") << ")";
3505       else
3506         SS << Parm->getIdentifier()->getName().str();
3507       if (i != NumParms - 1)
3508         SS << ", ";
3509     }
3510     // finish call and the body
3511     SS << ");}" << getEndOfLine().str();
3512     // FIXME: 80-char line formatting?
3513     return SS.str();
3514   };
3515 
3516   FixItList FixIts{};
3517   for (FunctionDecl *FReDecl : FD->redecls()) {
3518     std::optional<SourceLocation> Loc = getPastLoc(FReDecl, SM, LangOpts);
3519 
3520     if (!Loc)
3521       return {};
3522     if (FReDecl->isThisDeclarationADefinition()) {
3523       assert(FReDecl == FD && "inconsistent function definition");
3524       // Inserts a definition with the old signature to the end of
3525       // `FReDecl`:
3526       if (auto OldOverloadDef = OldOverloadDefCreator(FReDecl))
3527         FixIts.emplace_back(FixItHint::CreateInsertion(*Loc, *OldOverloadDef));
3528       else
3529         return {}; // give up
3530     } else {
3531       // Adds the unsafe-buffer attribute (if not already there) to `FReDecl`:
3532       if (!FReDecl->hasAttr<UnsafeBufferUsageAttr>()) {
3533         FixIts.emplace_back(FixItHint::CreateInsertion(
3534             FReDecl->getBeginLoc(), Handler.getUnsafeBufferUsageAttributeTextAt(
3535                                         FReDecl->getBeginLoc(), " ")));
3536       }
3537       // Inserts a declaration with the new signature to the end of `FReDecl`:
3538       if (auto NewOverloadDecl = NewOverloadSignatureCreator(FReDecl))
3539         FixIts.emplace_back(FixItHint::CreateInsertion(*Loc, *NewOverloadDecl));
3540       else
3541         return {};
3542     }
3543   }
3544   return FixIts;
3545 }
3546 
3547 // To fix a `ParmVarDecl` to be of `std::span` type.
fixParamWithSpan(const ParmVarDecl * PVD,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3548 static FixItList fixParamWithSpan(const ParmVarDecl *PVD, const ASTContext &Ctx,
3549                                   UnsafeBufferUsageHandler &Handler) {
3550   if (hasUnsupportedSpecifiers(PVD, Ctx.getSourceManager())) {
3551     DEBUG_NOTE_DECL_FAIL(PVD, " : has unsupport specifier(s)");
3552     return {};
3553   }
3554   if (PVD->hasDefaultArg()) {
3555     // FIXME: generate fix-its for default values:
3556     DEBUG_NOTE_DECL_FAIL(PVD, " : has default arg");
3557     return {};
3558   }
3559 
3560   std::optional<Qualifiers> PteTyQualifiers = std::nullopt;
3561   std::optional<std::string> PteTyText = getPointeeTypeText(
3562       PVD, Ctx.getSourceManager(), Ctx.getLangOpts(), &PteTyQualifiers);
3563 
3564   if (!PteTyText) {
3565     DEBUG_NOTE_DECL_FAIL(PVD, " : invalid pointee type");
3566     return {};
3567   }
3568 
3569   std::optional<StringRef> PVDNameText = PVD->getIdentifier()->getName();
3570 
3571   if (!PVDNameText) {
3572     DEBUG_NOTE_DECL_FAIL(PVD, " : invalid identifier name");
3573     return {};
3574   }
3575 
3576   std::stringstream SS;
3577   std::optional<std::string> SpanTyText = createSpanTypeForVarDecl(PVD, Ctx);
3578 
3579   if (PteTyQualifiers)
3580     // Append qualifiers if they exist:
3581     SS << getSpanTypeText(*PteTyText, PteTyQualifiers);
3582   else
3583     SS << getSpanTypeText(*PteTyText);
3584   // Append qualifiers to the type of the parameter:
3585   if (PVD->getType().hasQualifiers())
3586     SS << ' ' << PVD->getType().getQualifiers().getAsString();
3587   // Append parameter's name:
3588   SS << ' ' << PVDNameText->str();
3589   // Add replacement fix-it:
3590   return {FixItHint::CreateReplacement(PVD->getSourceRange(), SS.str())};
3591 }
3592 
fixVariableWithSpan(const VarDecl * VD,const DeclUseTracker & Tracker,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3593 static FixItList fixVariableWithSpan(const VarDecl *VD,
3594                                      const DeclUseTracker &Tracker,
3595                                      ASTContext &Ctx,
3596                                      UnsafeBufferUsageHandler &Handler) {
3597   const DeclStmt *DS = Tracker.lookupDecl(VD);
3598   if (!DS) {
3599     DEBUG_NOTE_DECL_FAIL(VD,
3600                          " : variables declared this way not implemented yet");
3601     return {};
3602   }
3603   if (!DS->isSingleDecl()) {
3604     // FIXME: to support handling multiple `VarDecl`s in a single `DeclStmt`
3605     DEBUG_NOTE_DECL_FAIL(VD, " : multiple VarDecls");
3606     return {};
3607   }
3608   // Currently DS is an unused variable but we'll need it when
3609   // non-single decls are implemented, where the pointee type name
3610   // and the '*' are spread around the place.
3611   (void)DS;
3612 
3613   // FIXME: handle cases where DS has multiple declarations
3614   return fixLocalVarDeclWithSpan(VD, Ctx, getUserFillPlaceHolder(), Handler);
3615 }
3616 
fixVarDeclWithArray(const VarDecl * D,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3617 static FixItList fixVarDeclWithArray(const VarDecl *D, const ASTContext &Ctx,
3618                                      UnsafeBufferUsageHandler &Handler) {
3619   FixItList FixIts{};
3620 
3621   // Note: the code below expects the declaration to not use any type sugar like
3622   // typedef.
3623   if (auto CAT = Ctx.getAsConstantArrayType(D->getType())) {
3624     const QualType &ArrayEltT = CAT->getElementType();
3625     assert(!ArrayEltT.isNull() && "Trying to fix a non-array type variable!");
3626     // FIXME: support multi-dimensional arrays
3627     if (isa<clang::ArrayType>(ArrayEltT.getCanonicalType()))
3628       return {};
3629 
3630     const SourceLocation IdentifierLoc = getVarDeclIdentifierLoc(D);
3631 
3632     // Get the spelling of the element type as written in the source file
3633     // (including macros, etc.).
3634     auto MaybeElemTypeTxt =
3635         getRangeText({D->getBeginLoc(), IdentifierLoc}, Ctx.getSourceManager(),
3636                      Ctx.getLangOpts());
3637     if (!MaybeElemTypeTxt)
3638       return {};
3639     const llvm::StringRef ElemTypeTxt = MaybeElemTypeTxt->trim();
3640 
3641     // Find the '[' token.
3642     std::optional<Token> NextTok = Lexer::findNextToken(
3643         IdentifierLoc, Ctx.getSourceManager(), Ctx.getLangOpts());
3644     while (NextTok && !NextTok->is(tok::l_square) &&
3645            NextTok->getLocation() <= D->getSourceRange().getEnd())
3646       NextTok = Lexer::findNextToken(NextTok->getLocation(),
3647                                      Ctx.getSourceManager(), Ctx.getLangOpts());
3648     if (!NextTok)
3649       return {};
3650     const SourceLocation LSqBracketLoc = NextTok->getLocation();
3651 
3652     // Get the spelling of the array size as written in the source file
3653     // (including macros, etc.).
3654     auto MaybeArraySizeTxt = getRangeText(
3655         {LSqBracketLoc.getLocWithOffset(1), D->getTypeSpecEndLoc()},
3656         Ctx.getSourceManager(), Ctx.getLangOpts());
3657     if (!MaybeArraySizeTxt)
3658       return {};
3659     const llvm::StringRef ArraySizeTxt = MaybeArraySizeTxt->trim();
3660     if (ArraySizeTxt.empty()) {
3661       // FIXME: Support array size getting determined from the initializer.
3662       // Examples:
3663       //    int arr1[] = {0, 1, 2};
3664       //    int arr2{3, 4, 5};
3665       // We might be able to preserve the non-specified size with `auto` and
3666       // `std::to_array`:
3667       //    auto arr1 = std::to_array<int>({0, 1, 2});
3668       return {};
3669     }
3670 
3671     std::optional<StringRef> IdentText =
3672         getVarDeclIdentifierText(D, Ctx.getSourceManager(), Ctx.getLangOpts());
3673 
3674     if (!IdentText) {
3675       DEBUG_NOTE_DECL_FAIL(D, " : failed to locate the identifier");
3676       return {};
3677     }
3678 
3679     SmallString<32> Replacement;
3680     llvm::raw_svector_ostream OS(Replacement);
3681     OS << "std::array<" << ElemTypeTxt << ", " << ArraySizeTxt << "> "
3682        << IdentText->str();
3683 
3684     FixIts.push_back(FixItHint::CreateReplacement(
3685         SourceRange{D->getBeginLoc(), D->getTypeSpecEndLoc()}, OS.str()));
3686   }
3687 
3688   return FixIts;
3689 }
3690 
fixVariableWithArray(const VarDecl * VD,const DeclUseTracker & Tracker,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3691 static FixItList fixVariableWithArray(const VarDecl *VD,
3692                                       const DeclUseTracker &Tracker,
3693                                       const ASTContext &Ctx,
3694                                       UnsafeBufferUsageHandler &Handler) {
3695   const DeclStmt *DS = Tracker.lookupDecl(VD);
3696   assert(DS && "Fixing non-local variables not implemented yet!");
3697   if (!DS->isSingleDecl()) {
3698     // FIXME: to support handling multiple `VarDecl`s in a single `DeclStmt`
3699     return {};
3700   }
3701   // Currently DS is an unused variable but we'll need it when
3702   // non-single decls are implemented, where the pointee type name
3703   // and the '*' are spread around the place.
3704   (void)DS;
3705 
3706   // FIXME: handle cases where DS has multiple declarations
3707   return fixVarDeclWithArray(VD, Ctx, Handler);
3708 }
3709 
3710 // TODO: we should be consistent to use `std::nullopt` to represent no-fix due
3711 // to any unexpected problem.
3712 static FixItList
fixVariable(const VarDecl * VD,FixitStrategy::Kind K,const Decl * D,const DeclUseTracker & Tracker,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3713 fixVariable(const VarDecl *VD, FixitStrategy::Kind K,
3714             /* The function decl under analysis */ const Decl *D,
3715             const DeclUseTracker &Tracker, ASTContext &Ctx,
3716             UnsafeBufferUsageHandler &Handler) {
3717   if (const auto *PVD = dyn_cast<ParmVarDecl>(VD)) {
3718     auto *FD = dyn_cast<clang::FunctionDecl>(PVD->getDeclContext());
3719     if (!FD || FD != D) {
3720       // `FD != D` means that `PVD` belongs to a function that is not being
3721       // analyzed currently.  Thus `FD` may not be complete.
3722       DEBUG_NOTE_DECL_FAIL(VD, " : function not currently analyzed");
3723       return {};
3724     }
3725 
3726     // TODO If function has a try block we can't change params unless we check
3727     // also its catch block for their use.
3728     // FIXME We might support static class methods, some select methods,
3729     // operators and possibly lamdas.
3730     if (FD->isMain() || FD->isConstexpr() ||
3731         FD->getTemplatedKind() != FunctionDecl::TemplatedKind::TK_NonTemplate ||
3732         FD->isVariadic() ||
3733         // also covers call-operator of lamdas
3734         isa<CXXMethodDecl>(FD) ||
3735         // skip when the function body is a try-block
3736         (FD->hasBody() && isa<CXXTryStmt>(FD->getBody())) ||
3737         FD->isOverloadedOperator()) {
3738       DEBUG_NOTE_DECL_FAIL(VD, " : unsupported function decl");
3739       return {}; // TODO test all these cases
3740     }
3741   }
3742 
3743   switch (K) {
3744   case FixitStrategy::Kind::Span: {
3745     if (VD->getType()->isPointerType()) {
3746       if (const auto *PVD = dyn_cast<ParmVarDecl>(VD))
3747         return fixParamWithSpan(PVD, Ctx, Handler);
3748 
3749       if (VD->isLocalVarDecl())
3750         return fixVariableWithSpan(VD, Tracker, Ctx, Handler);
3751     }
3752     DEBUG_NOTE_DECL_FAIL(VD, " : not a pointer");
3753     return {};
3754   }
3755   case FixitStrategy::Kind::Array: {
3756     if (VD->isLocalVarDecl() && Ctx.getAsConstantArrayType(VD->getType()))
3757       return fixVariableWithArray(VD, Tracker, Ctx, Handler);
3758 
3759     DEBUG_NOTE_DECL_FAIL(VD, " : not a local const-size array");
3760     return {};
3761   }
3762   case FixitStrategy::Kind::Iterator:
3763   case FixitStrategy::Kind::Vector:
3764     llvm_unreachable("FixitStrategy not implemented yet!");
3765   case FixitStrategy::Kind::Wontfix:
3766     llvm_unreachable("Invalid strategy!");
3767   }
3768   llvm_unreachable("Unknown strategy!");
3769 }
3770 
3771 // Returns true iff there exists a `FixItHint` 'h' in `FixIts` such that the
3772 // `RemoveRange` of 'h' overlaps with a macro use.
overlapWithMacro(const FixItList & FixIts)3773 static bool overlapWithMacro(const FixItList &FixIts) {
3774   // FIXME: For now we only check if the range (or the first token) is (part of)
3775   // a macro expansion.  Ideally, we want to check for all tokens in the range.
3776   return llvm::any_of(FixIts, [](const FixItHint &Hint) {
3777     auto Range = Hint.RemoveRange;
3778     if (Range.getBegin().isMacroID() || Range.getEnd().isMacroID())
3779       // If the range (or the first token) is (part of) a macro expansion:
3780       return true;
3781     return false;
3782   });
3783 }
3784 
3785 // Returns true iff `VD` is a parameter of the declaration `D`:
isParameterOf(const VarDecl * VD,const Decl * D)3786 static bool isParameterOf(const VarDecl *VD, const Decl *D) {
3787   return isa<ParmVarDecl>(VD) &&
3788          VD->getDeclContext() == dyn_cast<DeclContext>(D);
3789 }
3790 
3791 // Erases variables in `FixItsForVariable`, if such a variable has an unfixable
3792 // group mate.  A variable `v` is unfixable iff `FixItsForVariable` does not
3793 // contain `v`.
eraseVarsForUnfixableGroupMates(std::map<const VarDecl *,FixItList> & FixItsForVariable,const VariableGroupsManager & VarGrpMgr)3794 static void eraseVarsForUnfixableGroupMates(
3795     std::map<const VarDecl *, FixItList> &FixItsForVariable,
3796     const VariableGroupsManager &VarGrpMgr) {
3797   // Variables will be removed from `FixItsForVariable`:
3798   SmallVector<const VarDecl *, 8> ToErase;
3799 
3800   for (const auto &[VD, Ignore] : FixItsForVariable) {
3801     VarGrpRef Grp = VarGrpMgr.getGroupOfVar(VD);
3802     if (llvm::any_of(Grp,
3803                      [&FixItsForVariable](const VarDecl *GrpMember) -> bool {
3804                        return !FixItsForVariable.count(GrpMember);
3805                      })) {
3806       // At least one group member cannot be fixed, so we have to erase the
3807       // whole group:
3808       for (const VarDecl *Member : Grp)
3809         ToErase.push_back(Member);
3810     }
3811   }
3812   for (auto *VarToErase : ToErase)
3813     FixItsForVariable.erase(VarToErase);
3814 }
3815 
3816 // Returns the fix-its that create bounds-safe function overloads for the
3817 // function `D`, if `D`'s parameters will be changed to safe-types through
3818 // fix-its in `FixItsForVariable`.
3819 //
3820 // NOTE: In case `D`'s parameters will be changed but bounds-safe function
3821 // overloads cannot created, the whole group that contains the parameters will
3822 // be erased from `FixItsForVariable`.
createFunctionOverloadsForParms(std::map<const VarDecl *,FixItList> & FixItsForVariable,const VariableGroupsManager & VarGrpMgr,const FunctionDecl * FD,const FixitStrategy & S,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)3823 static FixItList createFunctionOverloadsForParms(
3824     std::map<const VarDecl *, FixItList> &FixItsForVariable /* mutable */,
3825     const VariableGroupsManager &VarGrpMgr, const FunctionDecl *FD,
3826     const FixitStrategy &S, ASTContext &Ctx,
3827     UnsafeBufferUsageHandler &Handler) {
3828   FixItList FixItsSharedByParms{};
3829 
3830   std::optional<FixItList> OverloadFixes =
3831       createOverloadsForFixedParams(S, FD, Ctx, Handler);
3832 
3833   if (OverloadFixes) {
3834     FixItsSharedByParms.append(*OverloadFixes);
3835   } else {
3836     // Something wrong in generating `OverloadFixes`, need to remove the
3837     // whole group, where parameters are in, from `FixItsForVariable` (Note
3838     // that all parameters should be in the same group):
3839     for (auto *Member : VarGrpMgr.getGroupOfParms())
3840       FixItsForVariable.erase(Member);
3841   }
3842   return FixItsSharedByParms;
3843 }
3844 
3845 // Constructs self-contained fix-its for each variable in `FixablesForAllVars`.
3846 static std::map<const VarDecl *, FixItList>
getFixIts(FixableGadgetSets & FixablesForAllVars,const FixitStrategy & S,ASTContext & Ctx,const Decl * D,const DeclUseTracker & Tracker,UnsafeBufferUsageHandler & Handler,const VariableGroupsManager & VarGrpMgr)3847 getFixIts(FixableGadgetSets &FixablesForAllVars, const FixitStrategy &S,
3848           ASTContext &Ctx,
3849           /* The function decl under analysis */ const Decl *D,
3850           const DeclUseTracker &Tracker, UnsafeBufferUsageHandler &Handler,
3851           const VariableGroupsManager &VarGrpMgr) {
3852   // `FixItsForVariable` will map each variable to a set of fix-its directly
3853   // associated to the variable itself.  Fix-its of distinct variables in
3854   // `FixItsForVariable` are disjoint.
3855   std::map<const VarDecl *, FixItList> FixItsForVariable;
3856 
3857   // Populate `FixItsForVariable` with fix-its directly associated with each
3858   // variable.  Fix-its directly associated to a variable 'v' are the ones
3859   // produced by the `FixableGadget`s whose claimed variable is 'v'.
3860   for (const auto &[VD, Fixables] : FixablesForAllVars.byVar) {
3861     FixItsForVariable[VD] =
3862         fixVariable(VD, S.lookup(VD), D, Tracker, Ctx, Handler);
3863     // If we fail to produce Fix-It for the declaration we have to skip the
3864     // variable entirely.
3865     if (FixItsForVariable[VD].empty()) {
3866       FixItsForVariable.erase(VD);
3867       continue;
3868     }
3869     for (const auto &F : Fixables) {
3870       std::optional<FixItList> Fixits = F->getFixits(S);
3871 
3872       if (Fixits) {
3873         FixItsForVariable[VD].insert(FixItsForVariable[VD].end(),
3874                                      Fixits->begin(), Fixits->end());
3875         continue;
3876       }
3877 #ifndef NDEBUG
3878       Handler.addDebugNoteForVar(
3879           VD, F->getSourceLoc(),
3880           ("gadget '" + F->getDebugName() + "' refused to produce a fix")
3881               .str());
3882 #endif
3883       FixItsForVariable.erase(VD);
3884       break;
3885     }
3886   }
3887 
3888   // `FixItsForVariable` now contains only variables that can be
3889   // fixed. A variable can be fixed if its' declaration and all Fixables
3890   // associated to it can all be fixed.
3891 
3892   // To further remove from `FixItsForVariable` variables whose group mates
3893   // cannot be fixed...
3894   eraseVarsForUnfixableGroupMates(FixItsForVariable, VarGrpMgr);
3895   // Now `FixItsForVariable` gets further reduced: a variable is in
3896   // `FixItsForVariable` iff it can be fixed and all its group mates can be
3897   // fixed.
3898 
3899   // Fix-its of bounds-safe overloads of `D` are shared by parameters of `D`.
3900   // That is,  when fixing multiple parameters in one step,  these fix-its will
3901   // be applied only once (instead of being applied per parameter).
3902   FixItList FixItsSharedByParms{};
3903 
3904   if (auto *FD = dyn_cast<FunctionDecl>(D))
3905     FixItsSharedByParms = createFunctionOverloadsForParms(
3906         FixItsForVariable, VarGrpMgr, FD, S, Ctx, Handler);
3907 
3908   // The map that maps each variable `v` to fix-its for the whole group where
3909   // `v` is in:
3910   std::map<const VarDecl *, FixItList> FinalFixItsForVariable{
3911       FixItsForVariable};
3912 
3913   for (auto &[Var, Ignore] : FixItsForVariable) {
3914     bool AnyParm = false;
3915     const auto VarGroupForVD = VarGrpMgr.getGroupOfVar(Var, &AnyParm);
3916 
3917     for (const VarDecl *GrpMate : VarGroupForVD) {
3918       if (Var == GrpMate)
3919         continue;
3920       if (FixItsForVariable.count(GrpMate))
3921         FinalFixItsForVariable[Var].append(FixItsForVariable[GrpMate]);
3922     }
3923     if (AnyParm) {
3924       // This assertion should never fail.  Otherwise we have a bug.
3925       assert(!FixItsSharedByParms.empty() &&
3926              "Should not try to fix a parameter that does not belong to a "
3927              "FunctionDecl");
3928       FinalFixItsForVariable[Var].append(FixItsSharedByParms);
3929     }
3930   }
3931   // Fix-its that will be applied in one step shall NOT:
3932   // 1. overlap with macros or/and templates; or
3933   // 2. conflict with each other.
3934   // Otherwise, the fix-its will be dropped.
3935   for (auto Iter = FinalFixItsForVariable.begin();
3936        Iter != FinalFixItsForVariable.end();)
3937     if (overlapWithMacro(Iter->second) ||
3938         clang::internal::anyConflict(Iter->second, Ctx.getSourceManager())) {
3939       Iter = FinalFixItsForVariable.erase(Iter);
3940     } else
3941       Iter++;
3942   return FinalFixItsForVariable;
3943 }
3944 
3945 template <typename VarDeclIterTy>
3946 static FixitStrategy
getNaiveStrategy(llvm::iterator_range<VarDeclIterTy> UnsafeVars)3947 getNaiveStrategy(llvm::iterator_range<VarDeclIterTy> UnsafeVars) {
3948   FixitStrategy S;
3949   for (const VarDecl *VD : UnsafeVars) {
3950     if (isa<ConstantArrayType>(VD->getType().getCanonicalType()))
3951       S.set(VD, FixitStrategy::Kind::Array);
3952     else
3953       S.set(VD, FixitStrategy::Kind::Span);
3954   }
3955   return S;
3956 }
3957 
3958 //  Manages variable groups:
3959 class VariableGroupsManagerImpl : public VariableGroupsManager {
3960   const std::vector<VarGrpTy> Groups;
3961   const std::map<const VarDecl *, unsigned> &VarGrpMap;
3962   const llvm::SetVector<const VarDecl *> &GrpsUnionForParms;
3963 
3964 public:
VariableGroupsManagerImpl(const std::vector<VarGrpTy> & Groups,const std::map<const VarDecl *,unsigned> & VarGrpMap,const llvm::SetVector<const VarDecl * > & GrpsUnionForParms)3965   VariableGroupsManagerImpl(
3966       const std::vector<VarGrpTy> &Groups,
3967       const std::map<const VarDecl *, unsigned> &VarGrpMap,
3968       const llvm::SetVector<const VarDecl *> &GrpsUnionForParms)
3969       : Groups(Groups), VarGrpMap(VarGrpMap),
3970         GrpsUnionForParms(GrpsUnionForParms) {}
3971 
getGroupOfVar(const VarDecl * Var,bool * HasParm) const3972   VarGrpRef getGroupOfVar(const VarDecl *Var, bool *HasParm) const override {
3973     if (GrpsUnionForParms.contains(Var)) {
3974       if (HasParm)
3975         *HasParm = true;
3976       return GrpsUnionForParms.getArrayRef();
3977     }
3978     if (HasParm)
3979       *HasParm = false;
3980 
3981     auto It = VarGrpMap.find(Var);
3982 
3983     if (It == VarGrpMap.end())
3984       return {};
3985     return Groups[It->second];
3986   }
3987 
getGroupOfParms() const3988   VarGrpRef getGroupOfParms() const override {
3989     return GrpsUnionForParms.getArrayRef();
3990   }
3991 };
3992 
applyGadgets(const Decl * D,FixableGadgetList FixableGadgets,WarningGadgetList WarningGadgets,DeclUseTracker Tracker,UnsafeBufferUsageHandler & Handler,bool EmitSuggestions)3993 static void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets,
3994                          WarningGadgetList WarningGadgets,
3995                          DeclUseTracker Tracker,
3996                          UnsafeBufferUsageHandler &Handler,
3997                          bool EmitSuggestions) {
3998   if (!EmitSuggestions) {
3999     // Our job is very easy without suggestions. Just warn about
4000     // every problematic operation and consider it done. No need to deal
4001     // with fixable gadgets, no need to group operations by variable.
4002     for (const auto &G : WarningGadgets) {
4003       G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/false,
4004                                D->getASTContext());
4005     }
4006 
4007     // This return guarantees that most of the machine doesn't run when
4008     // suggestions aren't requested.
4009     assert(FixableGadgets.empty() &&
4010            "Fixable gadgets found but suggestions not requested!");
4011     return;
4012   }
4013 
4014   // If no `WarningGadget`s ever matched, there is no unsafe operations in the
4015   //  function under the analysis. No need to fix any Fixables.
4016   if (!WarningGadgets.empty()) {
4017     // Gadgets "claim" variables they're responsible for. Once this loop
4018     // finishes, the tracker will only track DREs that weren't claimed by any
4019     // gadgets, i.e. not understood by the analysis.
4020     for (const auto &G : FixableGadgets) {
4021       for (const auto *DRE : G->getClaimedVarUseSites()) {
4022         Tracker.claimUse(DRE);
4023       }
4024     }
4025   }
4026 
4027   // If no `WarningGadget`s ever matched, there is no unsafe operations in the
4028   // function under the analysis.  Thus, it early returns here as there is
4029   // nothing needs to be fixed.
4030   //
4031   // Note this claim is based on the assumption that there is no unsafe
4032   // variable whose declaration is invisible from the analyzing function.
4033   // Otherwise, we need to consider if the uses of those unsafe varuables needs
4034   // fix.
4035   // So far, we are not fixing any global variables or class members. And,
4036   // lambdas will be analyzed along with the enclosing function. So this early
4037   // return is correct for now.
4038   if (WarningGadgets.empty())
4039     return;
4040 
4041   WarningGadgetSets UnsafeOps =
4042       groupWarningGadgetsByVar(std::move(WarningGadgets));
4043   FixableGadgetSets FixablesForAllVars =
4044       groupFixablesByVar(std::move(FixableGadgets));
4045 
4046   std::map<const VarDecl *, FixItList> FixItsForVariableGroup;
4047 
4048   // Filter out non-local vars and vars with unclaimed DeclRefExpr-s.
4049   for (auto it = FixablesForAllVars.byVar.cbegin();
4050        it != FixablesForAllVars.byVar.cend();) {
4051     // FIXME: need to deal with global variables later
4052     if ((!it->first->isLocalVarDecl() && !isa<ParmVarDecl>(it->first))) {
4053 #ifndef NDEBUG
4054       Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
4055                                  ("failed to produce fixit for '" +
4056                                   it->first->getNameAsString() +
4057                                   "' : neither local nor a parameter"));
4058 #endif
4059       it = FixablesForAllVars.byVar.erase(it);
4060     } else if (it->first->getType().getCanonicalType()->isReferenceType()) {
4061 #ifndef NDEBUG
4062       Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
4063                                  ("failed to produce fixit for '" +
4064                                   it->first->getNameAsString() +
4065                                   "' : has a reference type"));
4066 #endif
4067       it = FixablesForAllVars.byVar.erase(it);
4068     } else if (Tracker.hasUnclaimedUses(it->first)) {
4069       it = FixablesForAllVars.byVar.erase(it);
4070     } else if (it->first->isInitCapture()) {
4071 #ifndef NDEBUG
4072       Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
4073                                  ("failed to produce fixit for '" +
4074                                   it->first->getNameAsString() +
4075                                   "' : init capture"));
4076 #endif
4077       it = FixablesForAllVars.byVar.erase(it);
4078     } else {
4079       ++it;
4080     }
4081   }
4082 
4083 #ifndef NDEBUG
4084   for (const auto &it : UnsafeOps.byVar) {
4085     const VarDecl *const UnsafeVD = it.first;
4086     auto UnclaimedDREs = Tracker.getUnclaimedUses(UnsafeVD);
4087     if (UnclaimedDREs.empty())
4088       continue;
4089     const auto UnfixedVDName = UnsafeVD->getNameAsString();
4090     for (const clang::DeclRefExpr *UnclaimedDRE : UnclaimedDREs) {
4091       std::string UnclaimedUseTrace =
4092           getDREAncestorString(UnclaimedDRE, D->getASTContext());
4093 
4094       Handler.addDebugNoteForVar(
4095           UnsafeVD, UnclaimedDRE->getBeginLoc(),
4096           ("failed to produce fixit for '" + UnfixedVDName +
4097            "' : has an unclaimed use\nThe unclaimed DRE trace: " +
4098            UnclaimedUseTrace));
4099     }
4100   }
4101 #endif
4102 
4103   // Fixpoint iteration for pointer assignments
4104   using DepMapTy =
4105       llvm::DenseMap<const VarDecl *, llvm::SetVector<const VarDecl *>>;
4106   DepMapTy DependenciesMap{};
4107   DepMapTy PtrAssignmentGraph{};
4108 
4109   for (const auto &it : FixablesForAllVars.byVar) {
4110     for (const FixableGadget *fixable : it.second) {
4111       std::optional<std::pair<const VarDecl *, const VarDecl *>> ImplPair =
4112           fixable->getStrategyImplications();
4113       if (ImplPair) {
4114         std::pair<const VarDecl *, const VarDecl *> Impl = std::move(*ImplPair);
4115         PtrAssignmentGraph[Impl.first].insert(Impl.second);
4116       }
4117     }
4118   }
4119 
4120   /*
4121    The following code does a BFS traversal of the `PtrAssignmentGraph`
4122    considering all unsafe vars as starting nodes and constructs an undirected
4123    graph `DependenciesMap`. Constructing the `DependenciesMap` in this manner
4124    elimiates all variables that are unreachable from any unsafe var. In other
4125    words, this removes all dependencies that don't include any unsafe variable
4126    and consequently don't need any fixit generation.
4127    Note: A careful reader would observe that the code traverses
4128    `PtrAssignmentGraph` using `CurrentVar` but adds edges between `Var` and
4129    `Adj` and not between `CurrentVar` and `Adj`. Both approaches would
4130    achieve the same result but the one used here dramatically cuts the
4131    amount of hoops the second part of the algorithm needs to jump, given that
4132    a lot of these connections become "direct". The reader is advised not to
4133    imagine how the graph is transformed because of using `Var` instead of
4134    `CurrentVar`. The reader can continue reading as if `CurrentVar` was used,
4135    and think about why it's equivalent later.
4136    */
4137   std::set<const VarDecl *> VisitedVarsDirected{};
4138   for (const auto &[Var, ignore] : UnsafeOps.byVar) {
4139     if (VisitedVarsDirected.find(Var) == VisitedVarsDirected.end()) {
4140 
4141       std::queue<const VarDecl *> QueueDirected{};
4142       QueueDirected.push(Var);
4143       while (!QueueDirected.empty()) {
4144         const VarDecl *CurrentVar = QueueDirected.front();
4145         QueueDirected.pop();
4146         VisitedVarsDirected.insert(CurrentVar);
4147         auto AdjacentNodes = PtrAssignmentGraph[CurrentVar];
4148         for (const VarDecl *Adj : AdjacentNodes) {
4149           if (VisitedVarsDirected.find(Adj) == VisitedVarsDirected.end()) {
4150             QueueDirected.push(Adj);
4151           }
4152           DependenciesMap[Var].insert(Adj);
4153           DependenciesMap[Adj].insert(Var);
4154         }
4155       }
4156     }
4157   }
4158 
4159   // `Groups` stores the set of Connected Components in the graph.
4160   std::vector<VarGrpTy> Groups;
4161   // `VarGrpMap` maps variables that need fix to the groups (indexes) that the
4162   // variables belong to.  Group indexes refer to the elements in `Groups`.
4163   // `VarGrpMap` is complete in that every variable that needs fix is in it.
4164   std::map<const VarDecl *, unsigned> VarGrpMap;
4165   // The union group over the ones in "Groups" that contain parameters of `D`:
4166   llvm::SetVector<const VarDecl *>
4167       GrpsUnionForParms; // these variables need to be fixed in one step
4168 
4169   // Group Connected Components for Unsafe Vars
4170   // (Dependencies based on pointer assignments)
4171   std::set<const VarDecl *> VisitedVars{};
4172   for (const auto &[Var, ignore] : UnsafeOps.byVar) {
4173     if (VisitedVars.find(Var) == VisitedVars.end()) {
4174       VarGrpTy &VarGroup = Groups.emplace_back();
4175       std::queue<const VarDecl *> Queue{};
4176 
4177       Queue.push(Var);
4178       while (!Queue.empty()) {
4179         const VarDecl *CurrentVar = Queue.front();
4180         Queue.pop();
4181         VisitedVars.insert(CurrentVar);
4182         VarGroup.push_back(CurrentVar);
4183         auto AdjacentNodes = DependenciesMap[CurrentVar];
4184         for (const VarDecl *Adj : AdjacentNodes) {
4185           if (VisitedVars.find(Adj) == VisitedVars.end()) {
4186             Queue.push(Adj);
4187           }
4188         }
4189       }
4190 
4191       bool HasParm = false;
4192       unsigned GrpIdx = Groups.size() - 1;
4193 
4194       for (const VarDecl *V : VarGroup) {
4195         VarGrpMap[V] = GrpIdx;
4196         if (!HasParm && isParameterOf(V, D))
4197           HasParm = true;
4198       }
4199       if (HasParm)
4200         GrpsUnionForParms.insert_range(VarGroup);
4201     }
4202   }
4203 
4204   // Remove a `FixableGadget` if the associated variable is not in the graph
4205   // computed above.  We do not want to generate fix-its for such variables,
4206   // since they are neither warned nor reachable from a warned one.
4207   //
4208   // Note a variable is not warned if it is not directly used in any unsafe
4209   // operation. A variable `v` is NOT reachable from an unsafe variable, if it
4210   // does not exist another variable `u` such that `u` is warned and fixing `u`
4211   // (transitively) implicates fixing `v`.
4212   //
4213   // For example,
4214   // ```
4215   // void f(int * p) {
4216   //   int * a = p; *p = 0;
4217   // }
4218   // ```
4219   // `*p = 0` is a fixable gadget associated with a variable `p` that is neither
4220   // warned nor reachable from a warned one.  If we add `a[5] = 0` to the end of
4221   // the function above, `p` becomes reachable from a warned variable.
4222   for (auto I = FixablesForAllVars.byVar.begin();
4223        I != FixablesForAllVars.byVar.end();) {
4224     // Note `VisitedVars` contain all the variables in the graph:
4225     if (!VisitedVars.count((*I).first)) {
4226       // no such var in graph:
4227       I = FixablesForAllVars.byVar.erase(I);
4228     } else
4229       ++I;
4230   }
4231 
4232   // We assign strategies to variables that are 1) in the graph and 2) can be
4233   // fixed. Other variables have the default "Won't fix" strategy.
4234   FixitStrategy NaiveStrategy = getNaiveStrategy(llvm::make_filter_range(
4235       VisitedVars, [&FixablesForAllVars](const VarDecl *V) {
4236         // If a warned variable has no "Fixable", it is considered unfixable:
4237         return FixablesForAllVars.byVar.count(V);
4238       }));
4239   VariableGroupsManagerImpl VarGrpMgr(Groups, VarGrpMap, GrpsUnionForParms);
4240 
4241   if (isa<NamedDecl>(D))
4242     // The only case where `D` is not a `NamedDecl` is when `D` is a
4243     // `BlockDecl`. Let's not fix variables in blocks for now
4244     FixItsForVariableGroup =
4245         getFixIts(FixablesForAllVars, NaiveStrategy, D->getASTContext(), D,
4246                   Tracker, Handler, VarGrpMgr);
4247 
4248   for (const auto &G : UnsafeOps.noVar) {
4249     G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/false,
4250                              D->getASTContext());
4251   }
4252 
4253   for (const auto &[VD, WarningGadgets] : UnsafeOps.byVar) {
4254     auto FixItsIt = FixItsForVariableGroup.find(VD);
4255     Handler.handleUnsafeVariableGroup(VD, VarGrpMgr,
4256                                       FixItsIt != FixItsForVariableGroup.end()
4257                                           ? std::move(FixItsIt->second)
4258                                           : FixItList{},
4259                                       D, NaiveStrategy);
4260     for (const auto &G : WarningGadgets) {
4261       G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/true,
4262                                D->getASTContext());
4263     }
4264   }
4265 }
4266 
checkUnsafeBufferUsage(const Decl * D,UnsafeBufferUsageHandler & Handler,bool EmitSuggestions)4267 void clang::checkUnsafeBufferUsage(const Decl *D,
4268                                    UnsafeBufferUsageHandler &Handler,
4269                                    bool EmitSuggestions) {
4270 #ifndef NDEBUG
4271   Handler.clearDebugNotes();
4272 #endif
4273 
4274   assert(D);
4275 
4276   SmallVector<Stmt *> Stmts;
4277 
4278   if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
4279     // We do not want to visit a Lambda expression defined inside a method
4280     // independently. Instead, it should be visited along with the outer method.
4281     // FIXME: do we want to do the same thing for `BlockDecl`s?
4282     if (const auto *MD = dyn_cast<CXXMethodDecl>(D)) {
4283       if (MD->getParent()->isLambda() && MD->getParent()->isLocalClass())
4284         return;
4285     }
4286 
4287     for (FunctionDecl *FReDecl : FD->redecls()) {
4288       if (FReDecl->isExternC()) {
4289         // Do not emit fixit suggestions for functions declared in an
4290         // extern "C" block.
4291         EmitSuggestions = false;
4292         break;
4293       }
4294     }
4295 
4296     Stmts.push_back(FD->getBody());
4297 
4298     if (const auto *ID = dyn_cast<CXXConstructorDecl>(D)) {
4299       for (const CXXCtorInitializer *CI : ID->inits()) {
4300         Stmts.push_back(CI->getInit());
4301       }
4302     }
4303   } else if (isa<BlockDecl>(D) || isa<ObjCMethodDecl>(D)) {
4304     Stmts.push_back(D->getBody());
4305   }
4306 
4307   assert(!Stmts.empty());
4308 
4309   FixableGadgetList FixableGadgets;
4310   WarningGadgetList WarningGadgets;
4311   DeclUseTracker Tracker;
4312   for (Stmt *S : Stmts) {
4313     findGadgets(S, D->getASTContext(), Handler, EmitSuggestions, FixableGadgets,
4314                 WarningGadgets, Tracker);
4315   }
4316   applyGadgets(D, std::move(FixableGadgets), std::move(WarningGadgets),
4317                std::move(Tracker), Handler, EmitSuggestions);
4318 }
4319