xref: /freebsd/contrib/llvm-project/clang/lib/Analysis/UnsafeBufferUsage.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- UnsafeBufferUsage.cpp - Replace pointers with modern C++ -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Analysis/Analyses/UnsafeBufferUsage.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/AST/Decl.h"
12 #include "clang/AST/Expr.h"
13 #include "clang/AST/RecursiveASTVisitor.h"
14 #include "clang/AST/Stmt.h"
15 #include "clang/AST/StmtVisitor.h"
16 #include "clang/ASTMatchers/ASTMatchFinder.h"
17 #include "clang/ASTMatchers/ASTMatchers.h"
18 #include "clang/Basic/CharInfo.h"
19 #include "clang/Basic/SourceLocation.h"
20 #include "clang/Lex/Lexer.h"
21 #include "clang/Lex/Preprocessor.h"
22 #include "llvm/ADT/APSInt.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include <memory>
27 #include <optional>
28 #include <queue>
29 #include <sstream>
30 
31 using namespace llvm;
32 using namespace clang;
33 using namespace ast_matchers;
34 
35 #ifndef NDEBUG
36 namespace {
37 class StmtDebugPrinter
38     : public ConstStmtVisitor<StmtDebugPrinter, std::string> {
39 public:
VisitStmt(const Stmt * S)40   std::string VisitStmt(const Stmt *S) { return S->getStmtClassName(); }
41 
VisitBinaryOperator(const BinaryOperator * BO)42   std::string VisitBinaryOperator(const BinaryOperator *BO) {
43     return "BinaryOperator(" + BO->getOpcodeStr().str() + ")";
44   }
45 
VisitUnaryOperator(const UnaryOperator * UO)46   std::string VisitUnaryOperator(const UnaryOperator *UO) {
47     return "UnaryOperator(" + UO->getOpcodeStr(UO->getOpcode()).str() + ")";
48   }
49 
VisitImplicitCastExpr(const ImplicitCastExpr * ICE)50   std::string VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
51     return "ImplicitCastExpr(" + std::string(ICE->getCastKindName()) + ")";
52   }
53 };
54 
55 // Returns a string of ancestor `Stmt`s of the given `DRE` in such a form:
56 // "DRE ==> parent-of-DRE ==> grandparent-of-DRE ==> ...".
getDREAncestorString(const DeclRefExpr * DRE,ASTContext & Ctx)57 static std::string getDREAncestorString(const DeclRefExpr *DRE,
58                                         ASTContext &Ctx) {
59   std::stringstream SS;
60   const Stmt *St = DRE;
61   StmtDebugPrinter StmtPriner;
62 
63   do {
64     SS << StmtPriner.Visit(St);
65 
66     DynTypedNodeList StParents = Ctx.getParents(*St);
67 
68     if (StParents.size() > 1)
69       return "unavailable due to multiple parents";
70     if (StParents.size() == 0)
71       break;
72     St = StParents.begin()->get<Stmt>();
73     if (St)
74       SS << " ==> ";
75   } while (St);
76   return SS.str();
77 }
78 } // namespace
79 #endif /* NDEBUG */
80 
81 namespace clang::ast_matchers {
82 // A `RecursiveASTVisitor` that traverses all descendants of a given node "n"
83 // except for those belonging to a different callable of "n".
84 class MatchDescendantVisitor
85     : public RecursiveASTVisitor<MatchDescendantVisitor> {
86 public:
87   typedef RecursiveASTVisitor<MatchDescendantVisitor> VisitorBase;
88 
89   // Creates an AST visitor that matches `Matcher` on all
90   // descendants of a given node "n" except for the ones
91   // belonging to a different callable of "n".
MatchDescendantVisitor(const internal::DynTypedMatcher * Matcher,internal::ASTMatchFinder * Finder,internal::BoundNodesTreeBuilder * Builder,internal::ASTMatchFinder::BindKind Bind,const bool ignoreUnevaluatedContext)92   MatchDescendantVisitor(const internal::DynTypedMatcher *Matcher,
93                          internal::ASTMatchFinder *Finder,
94                          internal::BoundNodesTreeBuilder *Builder,
95                          internal::ASTMatchFinder::BindKind Bind,
96                          const bool ignoreUnevaluatedContext)
97       : Matcher(Matcher), Finder(Finder), Builder(Builder), Bind(Bind),
98         Matches(false), ignoreUnevaluatedContext(ignoreUnevaluatedContext) {}
99 
100   // Returns true if a match is found in a subtree of `DynNode`, which belongs
101   // to the same callable of `DynNode`.
findMatch(const DynTypedNode & DynNode)102   bool findMatch(const DynTypedNode &DynNode) {
103     Matches = false;
104     if (const Stmt *StmtNode = DynNode.get<Stmt>()) {
105       TraverseStmt(const_cast<Stmt *>(StmtNode));
106       *Builder = ResultBindings;
107       return Matches;
108     }
109     return false;
110   }
111 
112   // The following are overriding methods from the base visitor class.
113   // They are public only to allow CRTP to work. They are *not *part
114   // of the public API of this class.
115 
116   // For the matchers so far used in safe buffers, we only need to match
117   // `Stmt`s.  To override more as needed.
118 
TraverseDecl(Decl * Node)119   bool TraverseDecl(Decl *Node) {
120     if (!Node)
121       return true;
122     if (!match(*Node))
123       return false;
124     // To skip callables:
125     if (isa<FunctionDecl, BlockDecl, ObjCMethodDecl>(Node))
126       return true;
127     // Traverse descendants
128     return VisitorBase::TraverseDecl(Node);
129   }
130 
TraverseGenericSelectionExpr(GenericSelectionExpr * Node)131   bool TraverseGenericSelectionExpr(GenericSelectionExpr *Node) {
132     // These are unevaluated, except the result expression.
133     if (ignoreUnevaluatedContext)
134       return TraverseStmt(Node->getResultExpr());
135     return VisitorBase::TraverseGenericSelectionExpr(Node);
136   }
137 
TraverseUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr * Node)138   bool TraverseUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr *Node) {
139     // Unevaluated context.
140     if (ignoreUnevaluatedContext)
141       return true;
142     return VisitorBase::TraverseUnaryExprOrTypeTraitExpr(Node);
143   }
144 
TraverseTypeOfExprTypeLoc(TypeOfExprTypeLoc Node)145   bool TraverseTypeOfExprTypeLoc(TypeOfExprTypeLoc Node) {
146     // Unevaluated context.
147     if (ignoreUnevaluatedContext)
148       return true;
149     return VisitorBase::TraverseTypeOfExprTypeLoc(Node);
150   }
151 
TraverseDecltypeTypeLoc(DecltypeTypeLoc Node)152   bool TraverseDecltypeTypeLoc(DecltypeTypeLoc Node) {
153     // Unevaluated context.
154     if (ignoreUnevaluatedContext)
155       return true;
156     return VisitorBase::TraverseDecltypeTypeLoc(Node);
157   }
158 
TraverseCXXNoexceptExpr(CXXNoexceptExpr * Node)159   bool TraverseCXXNoexceptExpr(CXXNoexceptExpr *Node) {
160     // Unevaluated context.
161     if (ignoreUnevaluatedContext)
162       return true;
163     return VisitorBase::TraverseCXXNoexceptExpr(Node);
164   }
165 
TraverseCXXTypeidExpr(CXXTypeidExpr * Node)166   bool TraverseCXXTypeidExpr(CXXTypeidExpr *Node) {
167     // Unevaluated context.
168     if (ignoreUnevaluatedContext)
169       return true;
170     return VisitorBase::TraverseCXXTypeidExpr(Node);
171   }
172 
TraverseStmt(Stmt * Node,DataRecursionQueue * Queue=nullptr)173   bool TraverseStmt(Stmt *Node, DataRecursionQueue *Queue = nullptr) {
174     if (!Node)
175       return true;
176     if (!match(*Node))
177       return false;
178     return VisitorBase::TraverseStmt(Node);
179   }
180 
shouldVisitTemplateInstantiations() const181   bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const182   bool shouldVisitImplicitCode() const {
183     // TODO: let's ignore implicit code for now
184     return false;
185   }
186 
187 private:
188   // Sets 'Matched' to true if 'Matcher' matches 'Node'
189   //
190   // Returns 'true' if traversal should continue after this function
191   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
match(const T & Node)192   template <typename T> bool match(const T &Node) {
193     internal::BoundNodesTreeBuilder RecursiveBuilder(*Builder);
194 
195     if (Matcher->matches(DynTypedNode::create(Node), Finder,
196                          &RecursiveBuilder)) {
197       ResultBindings.addMatch(RecursiveBuilder);
198       Matches = true;
199       if (Bind != internal::ASTMatchFinder::BK_All)
200         return false; // Abort as soon as a match is found.
201     }
202     return true;
203   }
204 
205   const internal::DynTypedMatcher *const Matcher;
206   internal::ASTMatchFinder *const Finder;
207   internal::BoundNodesTreeBuilder *const Builder;
208   internal::BoundNodesTreeBuilder ResultBindings;
209   const internal::ASTMatchFinder::BindKind Bind;
210   bool Matches;
211   bool ignoreUnevaluatedContext;
212 };
213 
214 // Because we're dealing with raw pointers, let's define what we mean by that.
hasPointerType()215 static auto hasPointerType() {
216   return hasType(hasCanonicalType(pointerType()));
217 }
218 
hasArrayType()219 static auto hasArrayType() { return hasType(hasCanonicalType(arrayType())); }
220 
AST_MATCHER_P(Stmt,forEachDescendantEvaluatedStmt,internal::Matcher<Stmt>,innerMatcher)221 AST_MATCHER_P(Stmt, forEachDescendantEvaluatedStmt, internal::Matcher<Stmt>,
222               innerMatcher) {
223   const DynTypedMatcher &DTM = static_cast<DynTypedMatcher>(innerMatcher);
224 
225   MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All,
226                                  true);
227   return Visitor.findMatch(DynTypedNode::create(Node));
228 }
229 
AST_MATCHER_P(Stmt,forEachDescendantStmt,internal::Matcher<Stmt>,innerMatcher)230 AST_MATCHER_P(Stmt, forEachDescendantStmt, internal::Matcher<Stmt>,
231               innerMatcher) {
232   const DynTypedMatcher &DTM = static_cast<DynTypedMatcher>(innerMatcher);
233 
234   MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All,
235                                  false);
236   return Visitor.findMatch(DynTypedNode::create(Node));
237 }
238 
239 // Matches a `Stmt` node iff the node is in a safe-buffer opt-out region
AST_MATCHER_P(Stmt,notInSafeBufferOptOut,const UnsafeBufferUsageHandler *,Handler)240 AST_MATCHER_P(Stmt, notInSafeBufferOptOut, const UnsafeBufferUsageHandler *,
241               Handler) {
242   return !Handler->isSafeBufferOptOut(Node.getBeginLoc());
243 }
244 
AST_MATCHER_P(Stmt,ignoreUnsafeBufferInContainer,const UnsafeBufferUsageHandler *,Handler)245 AST_MATCHER_P(Stmt, ignoreUnsafeBufferInContainer,
246               const UnsafeBufferUsageHandler *, Handler) {
247   return Handler->ignoreUnsafeBufferInContainer(Node.getBeginLoc());
248 }
249 
// Matches a cast expression whose direct operand, as returned by
// `CastExpr::getSubExpr()`, matches `innerMatcher`.
AST_MATCHER_P(CastExpr, castSubExpr, internal::Matcher<Expr>, innerMatcher) {
  const Expr &Operand = *Node.getSubExpr();
  return innerMatcher.matches(Operand, Finder, Builder);
}

// Matches a `UnaryOperator` whose operator is pre-increment (`++x`).
AST_MATCHER(UnaryOperator, isPreInc) {
  const UnaryOperator::Opcode Opc = Node.getOpcode();
  return Opc == UO_PreInc;
}
258 
259 // Returns a matcher that matches any expression 'e' such that `innerMatcher`
260 // matches 'e' and 'e' is in an Unspecified Lvalue Context.
isInUnspecifiedLvalueContext(internal::Matcher<Expr> innerMatcher)261 static auto isInUnspecifiedLvalueContext(internal::Matcher<Expr> innerMatcher) {
262   // clang-format off
263   return
264     expr(anyOf(
265       implicitCastExpr(
266         hasCastKind(CastKind::CK_LValueToRValue),
267         castSubExpr(innerMatcher)),
268       binaryOperator(
269         hasAnyOperatorName("="),
270         hasLHS(innerMatcher)
271       )
272     ));
273   // clang-format on
274 }
275 
276 // Returns a matcher that matches any expression `e` such that `InnerMatcher`
277 // matches `e` and `e` is in an Unspecified Pointer Context (UPC).
278 static internal::Matcher<Stmt>
isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher)279 isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) {
280   // A UPC can be
281   // 1. an argument of a function call (except the callee has [[unsafe_...]]
282   //    attribute), or
283   // 2. the operand of a pointer-to-(integer or bool) cast operation; or
284   // 3. the operand of a comparator operation; or
285   // 4. the operand of a pointer subtraction operation
286   //    (i.e., computing the distance between two pointers); or ...
287 
288   // clang-format off
289   auto CallArgMatcher = callExpr(
290     forEachArgumentWithParamType(
291       InnerMatcher,
292       isAnyPointer() /* array also decays to pointer type*/),
293     unless(callee(
294       functionDecl(hasAttr(attr::UnsafeBufferUsage)))));
295 
296   auto CastOperandMatcher =
297       castExpr(anyOf(hasCastKind(CastKind::CK_PointerToIntegral),
298 		     hasCastKind(CastKind::CK_PointerToBoolean)),
299 	       castSubExpr(allOf(hasPointerType(), InnerMatcher)));
300 
301   auto CompOperandMatcher =
302       binaryOperator(hasAnyOperatorName("!=", "==", "<", "<=", ">", ">="),
303                      eachOf(hasLHS(allOf(hasPointerType(), InnerMatcher)),
304                             hasRHS(allOf(hasPointerType(), InnerMatcher))));
305 
306   // A matcher that matches pointer subtractions:
307   auto PtrSubtractionMatcher =
308       binaryOperator(hasOperatorName("-"),
309 		     // Note that here we need both LHS and RHS to be
310 		     // pointer. Then the inner matcher can match any of
311 		     // them:
312 		     allOf(hasLHS(hasPointerType()),
313 			   hasRHS(hasPointerType())),
314 		     eachOf(hasLHS(InnerMatcher),
315 			    hasRHS(InnerMatcher)));
316   // clang-format on
317 
318   return stmt(anyOf(CallArgMatcher, CastOperandMatcher, CompOperandMatcher,
319                     PtrSubtractionMatcher));
320   // FIXME: any more cases? (UPC excludes the RHS of an assignment.  For now we
321   // don't have to check that.)
322 }
323 
324 // Returns a matcher that matches any expression 'e' such that `innerMatcher`
325 // matches 'e' and 'e' is in an unspecified untyped context (i.e the expression
326 // 'e' isn't evaluated to an RValue). For example, consider the following code:
327 //    int *p = new int[4];
328 //    int *q = new int[4];
329 //    if ((p = q)) {}
330 //    p = q;
331 // The expression `p = q` in the conditional of the `if` statement
332 // `if ((p = q))` is evaluated as an RValue, whereas the expression `p = q;`
333 // in the assignment statement is in an untyped context.
334 static internal::Matcher<Stmt>
isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher)335 isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher) {
336   // An unspecified context can be
337   // 1. A compound statement,
338   // 2. The body of an if statement
339   // 3. Body of a loop
340   auto CompStmt = compoundStmt(forEach(InnerMatcher));
341   auto IfStmtThen = ifStmt(hasThen(InnerMatcher));
342   auto IfStmtElse = ifStmt(hasElse(InnerMatcher));
343   // FIXME: Handle loop bodies.
344   return stmt(anyOf(CompStmt, IfStmtThen, IfStmtElse));
345 }
346 
347 // Given a two-param std::span construct call, matches iff the call has the
348 // following forms:
349 //   1. `std::span<T>{new T[n], n}`, where `n` is a literal or a DRE
350 //   2. `std::span<T>{new T, 1}`
351 //   3. `std::span<T>{&var, 1}`
352 //   4. `std::span<T>{a, n}`, where `a` is of an array-of-T with constant size
353 //   `n`
354 //   5. `std::span<T>{any, 0}`
AST_MATCHER(CXXConstructExpr,isSafeSpanTwoParamConstruct)355 AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) {
356   assert(Node.getNumArgs() == 2 &&
357          "expecting a two-parameter std::span constructor");
358   const Expr *Arg0 = Node.getArg(0)->IgnoreImplicit();
359   const Expr *Arg1 = Node.getArg(1)->IgnoreImplicit();
360   auto HaveEqualConstantValues = [&Finder](const Expr *E0, const Expr *E1) {
361     if (auto E0CV = E0->getIntegerConstantExpr(Finder->getASTContext()))
362       if (auto E1CV = E1->getIntegerConstantExpr(Finder->getASTContext())) {
363         return APSInt::compareValues(*E0CV, *E1CV) == 0;
364       }
365     return false;
366   };
367   auto AreSameDRE = [](const Expr *E0, const Expr *E1) {
368     if (auto *DRE0 = dyn_cast<DeclRefExpr>(E0))
369       if (auto *DRE1 = dyn_cast<DeclRefExpr>(E1)) {
370         return DRE0->getDecl() == DRE1->getDecl();
371       }
372     return false;
373   };
374   std::optional<APSInt> Arg1CV =
375       Arg1->getIntegerConstantExpr(Finder->getASTContext());
376 
377   if (Arg1CV && Arg1CV->isZero())
378     // Check form 5:
379     return true;
380   switch (Arg0->IgnoreImplicit()->getStmtClass()) {
381   case Stmt::CXXNewExprClass:
382     if (auto Size = cast<CXXNewExpr>(Arg0)->getArraySize()) {
383       // Check form 1:
384       return AreSameDRE((*Size)->IgnoreImplicit(), Arg1) ||
385              HaveEqualConstantValues(*Size, Arg1);
386     }
387     // TODO: what's placeholder type? avoid it for now.
388     if (!cast<CXXNewExpr>(Arg0)->hasPlaceholderType()) {
389       // Check form 2:
390       return Arg1CV && Arg1CV->isOne();
391     }
392     break;
393   case Stmt::UnaryOperatorClass:
394     if (cast<UnaryOperator>(Arg0)->getOpcode() ==
395         UnaryOperator::Opcode::UO_AddrOf)
396       // Check form 3:
397       return Arg1CV && Arg1CV->isOne();
398     break;
399   default:
400     break;
401   }
402 
403   QualType Arg0Ty = Arg0->IgnoreImplicit()->getType();
404 
405   if (Arg0Ty->isConstantArrayType()) {
406     const APSInt ConstArrSize =
407         APSInt(cast<ConstantArrayType>(Arg0Ty)->getSize());
408 
409     // Check form 4:
410     return Arg1CV && APSInt::compareValues(ConstArrSize, *Arg1CV) == 0;
411   }
412   return false;
413 }
414 
AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) {
  // FIXME: Proper solution:
  //  - refactor Sema::CheckArrayAccess
  //    - split safe/OOB/unknown decision logic from diagnostics emitting code
  //    -  e. g. "Try harder to find a NamedDecl to point at in the note."
  //    already duplicated
  //  - call both from Sema and from here
  //
  // Until then, a subscript is considered safe only when
  //   - the base (parens and implicit casts stripped) names a declaration of
  //     constant-size array type, and
  //   - the index is an integer literal whose value has its sign bit clear
  //     and is smaller than the array size.

  const auto *BaseDRE =
      dyn_cast<DeclRefExpr>(Node.getBase()->IgnoreParenImpCasts());
  const ValueDecl *BaseDecl = BaseDRE ? BaseDRE->getDecl() : nullptr;
  if (!BaseDecl)
    return false;

  const auto *CATy =
      Finder->getASTContext().getAsConstantArrayType(BaseDecl->getType());
  const auto *IdxLit = dyn_cast<IntegerLiteral>(Node.getIdx());
  if (!CATy || !IdxLit)
    return false;

  // FIXME: ArrIdx.isNegative() we could immediately emit an error as that's a
  // bug
  const APInt ArrIdx = IdxLit->getValue();
  return ArrIdx.isNonNegative() &&
         ArrIdx.getLimitedValue() < CATy->getLimitedSize();
}
445 
446 } // namespace clang::ast_matchers
447 
448 namespace {
449 // Because the analysis revolves around variables and their types, we'll need to
450 // track uses of variables (aka DeclRefExprs).
451 using DeclUseList = SmallVector<const DeclRefExpr *, 1>;
452 
453 // Convenience typedef.
454 using FixItList = SmallVector<FixItHint, 4>;
455 } // namespace
456 
457 namespace {
458 /// Gadget is an individual operation in the code that may be of interest to
459 /// this analysis. Each (non-abstract) subclass corresponds to a specific
460 /// rigid AST structure that constitutes an operation on a pointer-type object.
461 /// Discovery of a gadget in the code corresponds to claiming that we understand
462 /// what this part of code is doing well enough to potentially improve it.
463 /// Gadgets can be warning (immediately deserving a warning) or fixable (not
464 /// always deserving a warning per se, but requires our attention to identify
465 /// it warrants a fixit).
466 class Gadget {
467 public:
468   enum class Kind {
469 #define GADGET(x) x,
470 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
471   };
472 
473   /// Common type of ASTMatchers used for discovering gadgets.
474   /// Useful for implementing the static matcher() methods
475   /// that are expected from all non-abstract subclasses.
476   using Matcher = decltype(stmt());
477 
Gadget(Kind K)478   Gadget(Kind K) : K(K) {}
479 
getKind() const480   Kind getKind() const { return K; }
481 
482 #ifndef NDEBUG
getDebugName() const483   StringRef getDebugName() const {
484     switch (K) {
485 #define GADGET(x)                                                              \
486   case Kind::x:                                                                \
487     return #x;
488 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
489     }
490     llvm_unreachable("Unhandled Gadget::Kind enum");
491   }
492 #endif
493 
494   virtual bool isWarningGadget() const = 0;
495   // TODO remove this method from WarningGadget interface. It's only used for
496   // debug prints in FixableGadget.
497   virtual SourceLocation getSourceLoc() const = 0;
498 
499   /// Returns the list of pointer-type variables on which this gadget performs
500   /// its operation. Typically, there's only one variable. This isn't a list
501   /// of all DeclRefExprs in the gadget's AST!
502   virtual DeclUseList getClaimedVarUseSites() const = 0;
503 
504   virtual ~Gadget() = default;
505 
506 private:
507   Kind K;
508 };
509 
510 /// Warning gadgets correspond to unsafe code patterns that warrants
511 /// an immediate warning.
512 class WarningGadget : public Gadget {
513 public:
WarningGadget(Kind K)514   WarningGadget(Kind K) : Gadget(K) {}
515 
classof(const Gadget * G)516   static bool classof(const Gadget *G) { return G->isWarningGadget(); }
isWarningGadget() const517   bool isWarningGadget() const final { return true; }
518 
519   virtual void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
520                                      bool IsRelatedToDecl,
521                                      ASTContext &Ctx) const = 0;
522 };
523 
524 /// Fixable gadgets correspond to code patterns that aren't always unsafe but
525 /// need to be properly recognized in order to emit fixes. For example, if a raw
526 /// pointer-type variable is replaced by a safe C++ container, every use of such
527 /// variable must be carefully considered and possibly updated.
528 class FixableGadget : public Gadget {
529 public:
FixableGadget(Kind K)530   FixableGadget(Kind K) : Gadget(K) {}
531 
classof(const Gadget * G)532   static bool classof(const Gadget *G) { return !G->isWarningGadget(); }
isWarningGadget() const533   bool isWarningGadget() const final { return false; }
534 
535   /// Returns a fixit that would fix the current gadget according to
536   /// the current strategy. Returns std::nullopt if the fix cannot be produced;
537   /// returns an empty list if no fixes are necessary.
getFixits(const FixitStrategy &) const538   virtual std::optional<FixItList> getFixits(const FixitStrategy &) const {
539     return std::nullopt;
540   }
541 
542   /// Returns a list of two elements where the first element is the LHS of a
543   /// pointer assignment statement and the second element is the RHS. This
544   /// two-element list represents the fact that the LHS buffer gets its bounds
545   /// information from the RHS buffer. This information will be used later to
546   /// group all those variables whose types must be modified together to prevent
547   /// type mismatches.
548   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const549   getStrategyImplications() const {
550     return std::nullopt;
551   }
552 };
553 
toSupportedVariable()554 static auto toSupportedVariable() { return to(varDecl()); }
555 
556 using FixableGadgetList = std::vector<std::unique_ptr<FixableGadget>>;
557 using WarningGadgetList = std::vector<std::unique_ptr<WarningGadget>>;
558 
559 /// An increment of a pointer-type value is unsafe as it may run the pointer
560 /// out of bounds.
561 class IncrementGadget : public WarningGadget {
562   static constexpr const char *const OpTag = "op";
563   const UnaryOperator *Op;
564 
565 public:
IncrementGadget(const MatchFinder::MatchResult & Result)566   IncrementGadget(const MatchFinder::MatchResult &Result)
567       : WarningGadget(Kind::Increment),
568         Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {}
569 
classof(const Gadget * G)570   static bool classof(const Gadget *G) {
571     return G->getKind() == Kind::Increment;
572   }
573 
matcher()574   static Matcher matcher() {
575     return stmt(
576         unaryOperator(hasOperatorName("++"),
577                       hasUnaryOperand(ignoringParenImpCasts(hasPointerType())))
578             .bind(OpTag));
579   }
580 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const581   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
582                              bool IsRelatedToDecl,
583                              ASTContext &Ctx) const override {
584     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
585   }
getSourceLoc() const586   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
587 
getClaimedVarUseSites() const588   DeclUseList getClaimedVarUseSites() const override {
589     SmallVector<const DeclRefExpr *, 2> Uses;
590     if (const auto *DRE =
591             dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
592       Uses.push_back(DRE);
593     }
594 
595     return std::move(Uses);
596   }
597 };
598 
599 /// A decrement of a pointer-type value is unsafe as it may run the pointer
600 /// out of bounds.
601 class DecrementGadget : public WarningGadget {
602   static constexpr const char *const OpTag = "op";
603   const UnaryOperator *Op;
604 
605 public:
DecrementGadget(const MatchFinder::MatchResult & Result)606   DecrementGadget(const MatchFinder::MatchResult &Result)
607       : WarningGadget(Kind::Decrement),
608         Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {}
609 
classof(const Gadget * G)610   static bool classof(const Gadget *G) {
611     return G->getKind() == Kind::Decrement;
612   }
613 
matcher()614   static Matcher matcher() {
615     return stmt(
616         unaryOperator(hasOperatorName("--"),
617                       hasUnaryOperand(ignoringParenImpCasts(hasPointerType())))
618             .bind(OpTag));
619   }
620 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const621   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
622                              bool IsRelatedToDecl,
623                              ASTContext &Ctx) const override {
624     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
625   }
getSourceLoc() const626   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
627 
getClaimedVarUseSites() const628   DeclUseList getClaimedVarUseSites() const override {
629     if (const auto *DRE =
630             dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
631       return {DRE};
632     }
633 
634     return {};
635   }
636 };
637 
638 /// Array subscript expressions on raw pointers as if they're arrays. Unsafe as
639 /// it doesn't have any bounds checks for the array.
640 class ArraySubscriptGadget : public WarningGadget {
641   static constexpr const char *const ArraySubscrTag = "ArraySubscript";
642   const ArraySubscriptExpr *ASE;
643 
644 public:
ArraySubscriptGadget(const MatchFinder::MatchResult & Result)645   ArraySubscriptGadget(const MatchFinder::MatchResult &Result)
646       : WarningGadget(Kind::ArraySubscript),
647         ASE(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ArraySubscrTag)) {}
648 
classof(const Gadget * G)649   static bool classof(const Gadget *G) {
650     return G->getKind() == Kind::ArraySubscript;
651   }
652 
matcher()653   static Matcher matcher() {
654     // clang-format off
655       return stmt(arraySubscriptExpr(
656             hasBase(ignoringParenImpCasts(
657               anyOf(hasPointerType(), hasArrayType()))),
658             unless(anyOf(
659               isSafeArraySubscript(),
660               hasIndex(
661                   anyOf(integerLiteral(equals(0)), arrayInitIndexExpr())
662               )
663             ))).bind(ArraySubscrTag));
664     // clang-format on
665   }
666 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const667   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
668                              bool IsRelatedToDecl,
669                              ASTContext &Ctx) const override {
670     Handler.handleUnsafeOperation(ASE, IsRelatedToDecl, Ctx);
671   }
getSourceLoc() const672   SourceLocation getSourceLoc() const override { return ASE->getBeginLoc(); }
673 
getClaimedVarUseSites() const674   DeclUseList getClaimedVarUseSites() const override {
675     if (const auto *DRE =
676             dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts())) {
677       return {DRE};
678     }
679 
680     return {};
681   }
682 };
683 
684 /// A pointer arithmetic expression of one of the forms:
685 ///  \code
686 ///  ptr + n | n + ptr | ptr - n | ptr += n | ptr -= n
687 ///  \endcode
688 class PointerArithmeticGadget : public WarningGadget {
689   static constexpr const char *const PointerArithmeticTag = "ptrAdd";
690   static constexpr const char *const PointerArithmeticPointerTag = "ptrAddPtr";
691   const BinaryOperator *PA; // pointer arithmetic expression
692   const Expr *Ptr;          // the pointer expression in `PA`
693 
694 public:
PointerArithmeticGadget(const MatchFinder::MatchResult & Result)695   PointerArithmeticGadget(const MatchFinder::MatchResult &Result)
696       : WarningGadget(Kind::PointerArithmetic),
697         PA(Result.Nodes.getNodeAs<BinaryOperator>(PointerArithmeticTag)),
698         Ptr(Result.Nodes.getNodeAs<Expr>(PointerArithmeticPointerTag)) {}
699 
classof(const Gadget * G)700   static bool classof(const Gadget *G) {
701     return G->getKind() == Kind::PointerArithmetic;
702   }
703 
matcher()704   static Matcher matcher() {
705     auto HasIntegerType = anyOf(hasType(isInteger()), hasType(enumType()));
706     auto PtrAtRight =
707         allOf(hasOperatorName("+"),
708               hasRHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)),
709               hasLHS(HasIntegerType));
710     auto PtrAtLeft =
711         allOf(anyOf(hasOperatorName("+"), hasOperatorName("-"),
712                     hasOperatorName("+="), hasOperatorName("-=")),
713               hasLHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)),
714               hasRHS(HasIntegerType));
715 
716     return stmt(binaryOperator(anyOf(PtrAtLeft, PtrAtRight))
717                     .bind(PointerArithmeticTag));
718   }
719 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const720   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
721                              bool IsRelatedToDecl,
722                              ASTContext &Ctx) const override {
723     Handler.handleUnsafeOperation(PA, IsRelatedToDecl, Ctx);
724   }
getSourceLoc() const725   SourceLocation getSourceLoc() const override { return PA->getBeginLoc(); }
726 
getClaimedVarUseSites() const727   DeclUseList getClaimedVarUseSites() const override {
728     if (const auto *DRE = dyn_cast<DeclRefExpr>(Ptr->IgnoreParenImpCasts())) {
729       return {DRE};
730     }
731 
732     return {};
733   }
734   // FIXME: pointer adding zero should be fine
735   // FIXME: this gadge will need a fix-it
736 };
737 
738 class SpanTwoParamConstructorGadget : public WarningGadget {
739   static constexpr const char *const SpanTwoParamConstructorTag =
740       "spanTwoParamConstructor";
741   const CXXConstructExpr *Ctor; // the span constructor expression
742 
743 public:
SpanTwoParamConstructorGadget(const MatchFinder::MatchResult & Result)744   SpanTwoParamConstructorGadget(const MatchFinder::MatchResult &Result)
745       : WarningGadget(Kind::SpanTwoParamConstructor),
746         Ctor(Result.Nodes.getNodeAs<CXXConstructExpr>(
747             SpanTwoParamConstructorTag)) {}
748 
classof(const Gadget * G)749   static bool classof(const Gadget *G) {
750     return G->getKind() == Kind::SpanTwoParamConstructor;
751   }
752 
matcher()753   static Matcher matcher() {
754     auto HasTwoParamSpanCtorDecl = hasDeclaration(
755         cxxConstructorDecl(hasDeclContext(isInStdNamespace()), hasName("span"),
756                            parameterCountIs(2)));
757 
758     return stmt(cxxConstructExpr(HasTwoParamSpanCtorDecl,
759                                  unless(isSafeSpanTwoParamConstruct()))
760                     .bind(SpanTwoParamConstructorTag));
761   }
762 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const763   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
764                              bool IsRelatedToDecl,
765                              ASTContext &Ctx) const override {
766     Handler.handleUnsafeOperationInContainer(Ctor, IsRelatedToDecl, Ctx);
767   }
getSourceLoc() const768   SourceLocation getSourceLoc() const override { return Ctor->getBeginLoc(); }
769 
getClaimedVarUseSites() const770   DeclUseList getClaimedVarUseSites() const override {
771     // If the constructor call is of the form `std::span{var, n}`, `var` is
772     // considered an unsafe variable.
773     if (auto *DRE = dyn_cast<DeclRefExpr>(Ctor->getArg(0))) {
774       if (isa<VarDecl>(DRE->getDecl()))
775         return {DRE};
776     }
777     return {};
778   }
779 };
780 
781 /// A pointer initialization expression of the form:
782 ///  \code
783 ///  int *p = q;
784 ///  \endcode
785 class PointerInitGadget : public FixableGadget {
786 private:
787   static constexpr const char *const PointerInitLHSTag = "ptrInitLHS";
788   static constexpr const char *const PointerInitRHSTag = "ptrInitRHS";
789   const VarDecl *PtrInitLHS;     // the LHS pointer expression in `PI`
790   const DeclRefExpr *PtrInitRHS; // the RHS pointer expression in `PI`
791 
792 public:
PointerInitGadget(const MatchFinder::MatchResult & Result)793   PointerInitGadget(const MatchFinder::MatchResult &Result)
794       : FixableGadget(Kind::PointerInit),
795         PtrInitLHS(Result.Nodes.getNodeAs<VarDecl>(PointerInitLHSTag)),
796         PtrInitRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerInitRHSTag)) {}
797 
classof(const Gadget * G)798   static bool classof(const Gadget *G) {
799     return G->getKind() == Kind::PointerInit;
800   }
801 
matcher()802   static Matcher matcher() {
803     auto PtrInitStmt = declStmt(hasSingleDecl(
804         varDecl(hasInitializer(ignoringImpCasts(
805                     declRefExpr(hasPointerType(), toSupportedVariable())
806                         .bind(PointerInitRHSTag))))
807             .bind(PointerInitLHSTag)));
808 
809     return stmt(PtrInitStmt);
810   }
811 
812   virtual std::optional<FixItList>
813   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const814   SourceLocation getSourceLoc() const override {
815     return PtrInitRHS->getBeginLoc();
816   }
817 
getClaimedVarUseSites() const818   virtual DeclUseList getClaimedVarUseSites() const override {
819     return DeclUseList{PtrInitRHS};
820   }
821 
822   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const823   getStrategyImplications() const override {
824     return std::make_pair(PtrInitLHS, cast<VarDecl>(PtrInitRHS->getDecl()));
825   }
826 };
827 
828 /// A pointer assignment expression of the form:
829 ///  \code
830 ///  p = q;
831 ///  \endcode
832 /// where both `p` and `q` are pointers.
833 class PtrToPtrAssignmentGadget : public FixableGadget {
834 private:
835   static constexpr const char *const PointerAssignLHSTag = "ptrLHS";
836   static constexpr const char *const PointerAssignRHSTag = "ptrRHS";
837   const DeclRefExpr *PtrLHS; // the LHS pointer expression in `PA`
838   const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA`
839 
840 public:
PtrToPtrAssignmentGadget(const MatchFinder::MatchResult & Result)841   PtrToPtrAssignmentGadget(const MatchFinder::MatchResult &Result)
842       : FixableGadget(Kind::PtrToPtrAssignment),
843         PtrLHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)),
844         PtrRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {}
845 
classof(const Gadget * G)846   static bool classof(const Gadget *G) {
847     return G->getKind() == Kind::PtrToPtrAssignment;
848   }
849 
matcher()850   static Matcher matcher() {
851     auto PtrAssignExpr = binaryOperator(
852         allOf(hasOperatorName("="),
853               hasRHS(ignoringParenImpCasts(
854                   declRefExpr(hasPointerType(), toSupportedVariable())
855                       .bind(PointerAssignRHSTag))),
856               hasLHS(declRefExpr(hasPointerType(), toSupportedVariable())
857                          .bind(PointerAssignLHSTag))));
858 
859     return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr));
860   }
861 
862   virtual std::optional<FixItList>
863   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const864   SourceLocation getSourceLoc() const override { return PtrLHS->getBeginLoc(); }
865 
getClaimedVarUseSites() const866   virtual DeclUseList getClaimedVarUseSites() const override {
867     return DeclUseList{PtrLHS, PtrRHS};
868   }
869 
870   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const871   getStrategyImplications() const override {
872     return std::make_pair(cast<VarDecl>(PtrLHS->getDecl()),
873                           cast<VarDecl>(PtrRHS->getDecl()));
874   }
875 };
876 
877 /// An assignment expression of the form:
878 ///  \code
879 ///  ptr = array;
880 ///  \endcode
881 /// where `p` is a pointer and `array` is a constant size array.
882 class CArrayToPtrAssignmentGadget : public FixableGadget {
883 private:
884   static constexpr const char *const PointerAssignLHSTag = "ptrLHS";
885   static constexpr const char *const PointerAssignRHSTag = "ptrRHS";
886   const DeclRefExpr *PtrLHS; // the LHS pointer expression in `PA`
887   const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA`
888 
889 public:
CArrayToPtrAssignmentGadget(const MatchFinder::MatchResult & Result)890   CArrayToPtrAssignmentGadget(const MatchFinder::MatchResult &Result)
891       : FixableGadget(Kind::CArrayToPtrAssignment),
892         PtrLHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)),
893         PtrRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {}
894 
classof(const Gadget * G)895   static bool classof(const Gadget *G) {
896     return G->getKind() == Kind::CArrayToPtrAssignment;
897   }
898 
matcher()899   static Matcher matcher() {
900     auto PtrAssignExpr = binaryOperator(
901         allOf(hasOperatorName("="),
902               hasRHS(ignoringParenImpCasts(
903                   declRefExpr(hasType(hasCanonicalType(constantArrayType())),
904                               toSupportedVariable())
905                       .bind(PointerAssignRHSTag))),
906               hasLHS(declRefExpr(hasPointerType(), toSupportedVariable())
907                          .bind(PointerAssignLHSTag))));
908 
909     return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr));
910   }
911 
912   virtual std::optional<FixItList>
913   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const914   SourceLocation getSourceLoc() const override { return PtrLHS->getBeginLoc(); }
915 
getClaimedVarUseSites() const916   virtual DeclUseList getClaimedVarUseSites() const override {
917     return DeclUseList{PtrLHS, PtrRHS};
918   }
919 
920   virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const921   getStrategyImplications() const override {
922     return {};
923   }
924 };
925 
926 /// A call of a function or method that performs unchecked buffer operations
927 /// over one of its pointer parameters.
928 class UnsafeBufferUsageAttrGadget : public WarningGadget {
929   constexpr static const char *const OpTag = "call_expr";
930   const CallExpr *Op;
931 
932 public:
UnsafeBufferUsageAttrGadget(const MatchFinder::MatchResult & Result)933   UnsafeBufferUsageAttrGadget(const MatchFinder::MatchResult &Result)
934       : WarningGadget(Kind::UnsafeBufferUsageAttr),
935         Op(Result.Nodes.getNodeAs<CallExpr>(OpTag)) {}
936 
classof(const Gadget * G)937   static bool classof(const Gadget *G) {
938     return G->getKind() == Kind::UnsafeBufferUsageAttr;
939   }
940 
matcher()941   static Matcher matcher() {
942     auto HasUnsafeFnDecl =
943         callee(functionDecl(hasAttr(attr::UnsafeBufferUsage)));
944     return stmt(callExpr(HasUnsafeFnDecl).bind(OpTag));
945   }
946 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const947   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
948                              bool IsRelatedToDecl,
949                              ASTContext &Ctx) const override {
950     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
951   }
getSourceLoc() const952   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
953 
getClaimedVarUseSites() const954   DeclUseList getClaimedVarUseSites() const override { return {}; }
955 };
956 
957 /// A call of a constructor that performs unchecked buffer operations
958 /// over one of its pointer parameters, or constructs a class object that will
959 /// perform buffer operations that depend on the correctness of the parameters.
960 class UnsafeBufferUsageCtorAttrGadget : public WarningGadget {
961   constexpr static const char *const OpTag = "cxx_construct_expr";
962   const CXXConstructExpr *Op;
963 
964 public:
UnsafeBufferUsageCtorAttrGadget(const MatchFinder::MatchResult & Result)965   UnsafeBufferUsageCtorAttrGadget(const MatchFinder::MatchResult &Result)
966       : WarningGadget(Kind::UnsafeBufferUsageCtorAttr),
967         Op(Result.Nodes.getNodeAs<CXXConstructExpr>(OpTag)) {}
968 
classof(const Gadget * G)969   static bool classof(const Gadget *G) {
970     return G->getKind() == Kind::UnsafeBufferUsageCtorAttr;
971   }
972 
matcher()973   static Matcher matcher() {
974     auto HasUnsafeCtorDecl =
975         hasDeclaration(cxxConstructorDecl(hasAttr(attr::UnsafeBufferUsage)));
976     // std::span(ptr, size) ctor is handled by SpanTwoParamConstructorGadget.
977     auto HasTwoParamSpanCtorDecl = SpanTwoParamConstructorGadget::matcher();
978     return stmt(
979         cxxConstructExpr(HasUnsafeCtorDecl, unless(HasTwoParamSpanCtorDecl))
980             .bind(OpTag));
981   }
982 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const983   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
984                              bool IsRelatedToDecl,
985                              ASTContext &Ctx) const override {
986     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
987   }
getSourceLoc() const988   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
989 
getClaimedVarUseSites() const990   DeclUseList getClaimedVarUseSites() const override { return {}; }
991 };
992 
993 // Warning gadget for unsafe invocation of span::data method.
994 // Triggers when the pointer returned by the invocation is immediately
995 // cast to a larger type.
996 
997 class DataInvocationGadget : public WarningGadget {
998   constexpr static const char *const OpTag = "data_invocation_expr";
999   const ExplicitCastExpr *Op;
1000 
1001 public:
DataInvocationGadget(const MatchFinder::MatchResult & Result)1002   DataInvocationGadget(const MatchFinder::MatchResult &Result)
1003       : WarningGadget(Kind::DataInvocation),
1004         Op(Result.Nodes.getNodeAs<ExplicitCastExpr>(OpTag)) {}
1005 
classof(const Gadget * G)1006   static bool classof(const Gadget *G) {
1007     return G->getKind() == Kind::DataInvocation;
1008   }
1009 
matcher()1010   static Matcher matcher() {
1011     Matcher callExpr = cxxMemberCallExpr(
1012         callee(cxxMethodDecl(hasName("data"), ofClass(hasName("std::span")))));
1013     return stmt(
1014         explicitCastExpr(anyOf(has(callExpr), has(parenExpr(has(callExpr)))))
1015             .bind(OpTag));
1016   }
1017 
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1018   void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1019                              bool IsRelatedToDecl,
1020                              ASTContext &Ctx) const override {
1021     Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1022   }
getSourceLoc() const1023   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1024 
getClaimedVarUseSites() const1025   DeclUseList getClaimedVarUseSites() const override { return {}; }
1026 };
1027 
1028 // Represents expressions of the form `DRE[*]` in the Unspecified Lvalue
1029 // Context (see `isInUnspecifiedLvalueContext`).
1030 // Note here `[]` is the built-in subscript operator.
1031 class ULCArraySubscriptGadget : public FixableGadget {
1032 private:
1033   static constexpr const char *const ULCArraySubscriptTag =
1034       "ArraySubscriptUnderULC";
1035   const ArraySubscriptExpr *Node;
1036 
1037 public:
ULCArraySubscriptGadget(const MatchFinder::MatchResult & Result)1038   ULCArraySubscriptGadget(const MatchFinder::MatchResult &Result)
1039       : FixableGadget(Kind::ULCArraySubscript),
1040         Node(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ULCArraySubscriptTag)) {
1041     assert(Node != nullptr && "Expecting a non-null matching result");
1042   }
1043 
classof(const Gadget * G)1044   static bool classof(const Gadget *G) {
1045     return G->getKind() == Kind::ULCArraySubscript;
1046   }
1047 
matcher()1048   static Matcher matcher() {
1049     auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType());
1050     auto BaseIsArrayOrPtrDRE = hasBase(
1051         ignoringParenImpCasts(declRefExpr(ArrayOrPtr, toSupportedVariable())));
1052     auto Target =
1053         arraySubscriptExpr(BaseIsArrayOrPtrDRE).bind(ULCArraySubscriptTag);
1054 
1055     return expr(isInUnspecifiedLvalueContext(Target));
1056   }
1057 
1058   virtual std::optional<FixItList>
1059   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1060   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1061 
getClaimedVarUseSites() const1062   virtual DeclUseList getClaimedVarUseSites() const override {
1063     if (const auto *DRE =
1064             dyn_cast<DeclRefExpr>(Node->getBase()->IgnoreImpCasts())) {
1065       return {DRE};
1066     }
1067     return {};
1068   }
1069 };
1070 
1071 // Fixable gadget to handle stand alone pointers of the form `UPC(DRE)` in the
1072 // unspecified pointer context (isInUnspecifiedPointerContext). The gadget emits
1073 // fixit of the form `UPC(DRE.data())`.
1074 class UPCStandalonePointerGadget : public FixableGadget {
1075 private:
1076   static constexpr const char *const DeclRefExprTag = "StandalonePointer";
1077   const DeclRefExpr *Node;
1078 
1079 public:
UPCStandalonePointerGadget(const MatchFinder::MatchResult & Result)1080   UPCStandalonePointerGadget(const MatchFinder::MatchResult &Result)
1081       : FixableGadget(Kind::UPCStandalonePointer),
1082         Node(Result.Nodes.getNodeAs<DeclRefExpr>(DeclRefExprTag)) {
1083     assert(Node != nullptr && "Expecting a non-null matching result");
1084   }
1085 
classof(const Gadget * G)1086   static bool classof(const Gadget *G) {
1087     return G->getKind() == Kind::UPCStandalonePointer;
1088   }
1089 
matcher()1090   static Matcher matcher() {
1091     auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType());
1092     auto target = expr(ignoringParenImpCasts(
1093         declRefExpr(allOf(ArrayOrPtr, toSupportedVariable()))
1094             .bind(DeclRefExprTag)));
1095     return stmt(isInUnspecifiedPointerContext(target));
1096   }
1097 
1098   virtual std::optional<FixItList>
1099   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1100   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1101 
getClaimedVarUseSites() const1102   virtual DeclUseList getClaimedVarUseSites() const override { return {Node}; }
1103 };
1104 
1105 class PointerDereferenceGadget : public FixableGadget {
1106   static constexpr const char *const BaseDeclRefExprTag = "BaseDRE";
1107   static constexpr const char *const OperatorTag = "op";
1108 
1109   const DeclRefExpr *BaseDeclRefExpr = nullptr;
1110   const UnaryOperator *Op = nullptr;
1111 
1112 public:
PointerDereferenceGadget(const MatchFinder::MatchResult & Result)1113   PointerDereferenceGadget(const MatchFinder::MatchResult &Result)
1114       : FixableGadget(Kind::PointerDereference),
1115         BaseDeclRefExpr(
1116             Result.Nodes.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)),
1117         Op(Result.Nodes.getNodeAs<UnaryOperator>(OperatorTag)) {}
1118 
classof(const Gadget * G)1119   static bool classof(const Gadget *G) {
1120     return G->getKind() == Kind::PointerDereference;
1121   }
1122 
matcher()1123   static Matcher matcher() {
1124     auto Target =
1125         unaryOperator(
1126             hasOperatorName("*"),
1127             has(expr(ignoringParenImpCasts(
1128                 declRefExpr(toSupportedVariable()).bind(BaseDeclRefExprTag)))))
1129             .bind(OperatorTag);
1130 
1131     return expr(isInUnspecifiedLvalueContext(Target));
1132   }
1133 
getClaimedVarUseSites() const1134   DeclUseList getClaimedVarUseSites() const override {
1135     return {BaseDeclRefExpr};
1136   }
1137 
1138   virtual std::optional<FixItList>
1139   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1140   SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1141 };
1142 
1143 // Represents expressions of the form `&DRE[any]` in the Unspecified Pointer
1144 // Context (see `isInUnspecifiedPointerContext`).
1145 // Note here `[]` is the built-in subscript operator.
1146 class UPCAddressofArraySubscriptGadget : public FixableGadget {
1147 private:
1148   static constexpr const char *const UPCAddressofArraySubscriptTag =
1149       "AddressofArraySubscriptUnderUPC";
1150   const UnaryOperator *Node; // the `&DRE[any]` node
1151 
1152 public:
UPCAddressofArraySubscriptGadget(const MatchFinder::MatchResult & Result)1153   UPCAddressofArraySubscriptGadget(const MatchFinder::MatchResult &Result)
1154       : FixableGadget(Kind::ULCArraySubscript),
1155         Node(Result.Nodes.getNodeAs<UnaryOperator>(
1156             UPCAddressofArraySubscriptTag)) {
1157     assert(Node != nullptr && "Expecting a non-null matching result");
1158   }
1159 
classof(const Gadget * G)1160   static bool classof(const Gadget *G) {
1161     return G->getKind() == Kind::UPCAddressofArraySubscript;
1162   }
1163 
matcher()1164   static Matcher matcher() {
1165     return expr(isInUnspecifiedPointerContext(expr(ignoringImpCasts(
1166         unaryOperator(
1167             hasOperatorName("&"),
1168             hasUnaryOperand(arraySubscriptExpr(hasBase(
1169                 ignoringParenImpCasts(declRefExpr(toSupportedVariable()))))))
1170             .bind(UPCAddressofArraySubscriptTag)))));
1171   }
1172 
1173   virtual std::optional<FixItList>
1174   getFixits(const FixitStrategy &) const override;
getSourceLoc() const1175   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1176 
getClaimedVarUseSites() const1177   virtual DeclUseList getClaimedVarUseSites() const override {
1178     const auto *ArraySubst = cast<ArraySubscriptExpr>(Node->getSubExpr());
1179     const auto *DRE =
1180         cast<DeclRefExpr>(ArraySubst->getBase()->IgnoreParenImpCasts());
1181     return {DRE};
1182   }
1183 };
1184 } // namespace
1185 
1186 namespace {
1187 // An auxiliary tracking facility for the fixit analysis. It helps connect
1188 // declarations to its uses and make sure we've covered all uses with our
1189 // analysis before we try to fix the declaration.
1190 class DeclUseTracker {
1191   using UseSetTy = SmallSet<const DeclRefExpr *, 16>;
1192   using DefMapTy = DenseMap<const VarDecl *, const DeclStmt *>;
1193 
1194   // Allocate on the heap for easier move.
1195   std::unique_ptr<UseSetTy> Uses{std::make_unique<UseSetTy>()};
1196   DefMapTy Defs{};
1197 
1198 public:
1199   DeclUseTracker() = default;
1200   DeclUseTracker(const DeclUseTracker &) = delete; // Let's avoid copies.
1201   DeclUseTracker &operator=(const DeclUseTracker &) = delete;
1202   DeclUseTracker(DeclUseTracker &&) = default;
1203   DeclUseTracker &operator=(DeclUseTracker &&) = default;
1204 
1205   // Start tracking a freshly discovered DRE.
discoverUse(const DeclRefExpr * DRE)1206   void discoverUse(const DeclRefExpr *DRE) { Uses->insert(DRE); }
1207 
1208   // Stop tracking the DRE as it's been fully figured out.
claimUse(const DeclRefExpr * DRE)1209   void claimUse(const DeclRefExpr *DRE) {
1210     assert(Uses->count(DRE) &&
1211            "DRE not found or claimed by multiple matchers!");
1212     Uses->erase(DRE);
1213   }
1214 
1215   // A variable is unclaimed if at least one use is unclaimed.
hasUnclaimedUses(const VarDecl * VD) const1216   bool hasUnclaimedUses(const VarDecl *VD) const {
1217     // FIXME: Can this be less linear? Maybe maintain a map from VDs to DREs?
1218     return any_of(*Uses, [VD](const DeclRefExpr *DRE) {
1219       return DRE->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl();
1220     });
1221   }
1222 
getUnclaimedUses(const VarDecl * VD) const1223   UseSetTy getUnclaimedUses(const VarDecl *VD) const {
1224     UseSetTy ReturnSet;
1225     for (auto use : *Uses) {
1226       if (use->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl()) {
1227         ReturnSet.insert(use);
1228       }
1229     }
1230     return ReturnSet;
1231   }
1232 
discoverDecl(const DeclStmt * DS)1233   void discoverDecl(const DeclStmt *DS) {
1234     for (const Decl *D : DS->decls()) {
1235       if (const auto *VD = dyn_cast<VarDecl>(D)) {
1236         // FIXME: Assertion temporarily disabled due to a bug in
1237         // ASTMatcher internal behavior in presence of GNU
1238         // statement-expressions. We need to properly investigate this
1239         // because it can screw up our algorithm in other ways.
1240         // assert(Defs.count(VD) == 0 && "Definition already discovered!");
1241         Defs[VD] = DS;
1242       }
1243     }
1244   }
1245 
lookupDecl(const VarDecl * VD) const1246   const DeclStmt *lookupDecl(const VarDecl *VD) const {
1247     return Defs.lookup(VD);
1248   }
1249 };
1250 } // namespace
1251 
1252 // Representing a pointer type expression of the form `++Ptr` in an Unspecified
1253 // Pointer Context (UPC):
1254 class UPCPreIncrementGadget : public FixableGadget {
1255 private:
1256   static constexpr const char *const UPCPreIncrementTag =
1257       "PointerPreIncrementUnderUPC";
1258   const UnaryOperator *Node; // the `++Ptr` node
1259 
1260 public:
UPCPreIncrementGadget(const MatchFinder::MatchResult & Result)1261   UPCPreIncrementGadget(const MatchFinder::MatchResult &Result)
1262       : FixableGadget(Kind::UPCPreIncrement),
1263         Node(Result.Nodes.getNodeAs<UnaryOperator>(UPCPreIncrementTag)) {
1264     assert(Node != nullptr && "Expecting a non-null matching result");
1265   }
1266 
classof(const Gadget * G)1267   static bool classof(const Gadget *G) {
1268     return G->getKind() == Kind::UPCPreIncrement;
1269   }
1270 
matcher()1271   static Matcher matcher() {
1272     // Note here we match `++Ptr` for any expression `Ptr` of pointer type.
1273     // Although currently we can only provide fix-its when `Ptr` is a DRE, we
1274     // can have the matcher be general, so long as `getClaimedVarUseSites` does
1275     // things right.
1276     return stmt(isInUnspecifiedPointerContext(expr(ignoringImpCasts(
1277         unaryOperator(isPreInc(),
1278                       hasUnaryOperand(declRefExpr(toSupportedVariable())))
1279             .bind(UPCPreIncrementTag)))));
1280   }
1281 
1282   virtual std::optional<FixItList>
1283   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1284   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1285 
getClaimedVarUseSites() const1286   virtual DeclUseList getClaimedVarUseSites() const override {
1287     return {dyn_cast<DeclRefExpr>(Node->getSubExpr())};
1288   }
1289 };
1290 
1291 // Representing a pointer type expression of the form `Ptr += n` in an
1292 // Unspecified Untyped Context (UUC):
1293 class UUCAddAssignGadget : public FixableGadget {
1294 private:
1295   static constexpr const char *const UUCAddAssignTag =
1296       "PointerAddAssignUnderUUC";
1297   static constexpr const char *const OffsetTag = "Offset";
1298 
1299   const BinaryOperator *Node; // the `Ptr += n` node
1300   const Expr *Offset = nullptr;
1301 
1302 public:
UUCAddAssignGadget(const MatchFinder::MatchResult & Result)1303   UUCAddAssignGadget(const MatchFinder::MatchResult &Result)
1304       : FixableGadget(Kind::UUCAddAssign),
1305         Node(Result.Nodes.getNodeAs<BinaryOperator>(UUCAddAssignTag)),
1306         Offset(Result.Nodes.getNodeAs<Expr>(OffsetTag)) {
1307     assert(Node != nullptr && "Expecting a non-null matching result");
1308   }
1309 
classof(const Gadget * G)1310   static bool classof(const Gadget *G) {
1311     return G->getKind() == Kind::UUCAddAssign;
1312   }
1313 
matcher()1314   static Matcher matcher() {
1315     // clang-format off
1316     return stmt(isInUnspecifiedUntypedContext(expr(ignoringImpCasts(
1317         binaryOperator(hasOperatorName("+="),
1318                        hasLHS(
1319                         declRefExpr(
1320                           hasPointerType(),
1321                           toSupportedVariable())),
1322                        hasRHS(expr().bind(OffsetTag)))
1323             .bind(UUCAddAssignTag)))));
1324     // clang-format on
1325   }
1326 
1327   virtual std::optional<FixItList>
1328   getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1329   SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1330 
getClaimedVarUseSites() const1331   virtual DeclUseList getClaimedVarUseSites() const override {
1332     return {dyn_cast<DeclRefExpr>(Node->getLHS())};
1333   }
1334 };
1335 
1336 // Representing a fixable expression of the form `*(ptr + 123)` or `*(123 +
1337 // ptr)`:
1338 class DerefSimplePtrArithFixableGadget : public FixableGadget {
1339   static constexpr const char *const BaseDeclRefExprTag = "BaseDRE";
1340   static constexpr const char *const DerefOpTag = "DerefOp";
1341   static constexpr const char *const AddOpTag = "AddOp";
1342   static constexpr const char *const OffsetTag = "Offset";
1343 
1344   const DeclRefExpr *BaseDeclRefExpr = nullptr;
1345   const UnaryOperator *DerefOp = nullptr;
1346   const BinaryOperator *AddOp = nullptr;
1347   const IntegerLiteral *Offset = nullptr;
1348 
1349 public:
DerefSimplePtrArithFixableGadget(const MatchFinder::MatchResult & Result)1350   DerefSimplePtrArithFixableGadget(const MatchFinder::MatchResult &Result)
1351       : FixableGadget(Kind::DerefSimplePtrArithFixable),
1352         BaseDeclRefExpr(
1353             Result.Nodes.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)),
1354         DerefOp(Result.Nodes.getNodeAs<UnaryOperator>(DerefOpTag)),
1355         AddOp(Result.Nodes.getNodeAs<BinaryOperator>(AddOpTag)),
1356         Offset(Result.Nodes.getNodeAs<IntegerLiteral>(OffsetTag)) {}
1357 
matcher()1358   static Matcher matcher() {
1359     // clang-format off
1360     auto ThePtr = expr(hasPointerType(),
1361                        ignoringImpCasts(declRefExpr(toSupportedVariable()).
1362                                         bind(BaseDeclRefExprTag)));
1363     auto PlusOverPtrAndInteger = expr(anyOf(
1364           binaryOperator(hasOperatorName("+"), hasLHS(ThePtr),
1365                          hasRHS(integerLiteral().bind(OffsetTag)))
1366                          .bind(AddOpTag),
1367           binaryOperator(hasOperatorName("+"), hasRHS(ThePtr),
1368                          hasLHS(integerLiteral().bind(OffsetTag)))
1369                          .bind(AddOpTag)));
1370     return isInUnspecifiedLvalueContext(unaryOperator(
1371         hasOperatorName("*"),
1372         hasUnaryOperand(ignoringParens(PlusOverPtrAndInteger)))
1373         .bind(DerefOpTag));
1374     // clang-format on
1375   }
1376 
1377   virtual std::optional<FixItList>
1378   getFixits(const FixitStrategy &s) const final;
getSourceLoc() const1379   SourceLocation getSourceLoc() const override {
1380     return DerefOp->getBeginLoc();
1381   }
1382 
getClaimedVarUseSites() const1383   virtual DeclUseList getClaimedVarUseSites() const final {
1384     return {BaseDeclRefExpr};
1385   }
1386 };
1387 
1388 /// Scan the function and return a list of gadgets found with provided kits.
1389 static std::tuple<FixableGadgetList, WarningGadgetList, DeclUseTracker>
findGadgets(const Decl * D,const UnsafeBufferUsageHandler & Handler,bool EmitSuggestions)1390 findGadgets(const Decl *D, const UnsafeBufferUsageHandler &Handler,
1391             bool EmitSuggestions) {
1392 
1393   struct GadgetFinderCallback : MatchFinder::MatchCallback {
1394     FixableGadgetList FixableGadgets;
1395     WarningGadgetList WarningGadgets;
1396     DeclUseTracker Tracker;
1397 
1398     void run(const MatchFinder::MatchResult &Result) override {
1399       // In debug mode, assert that we've found exactly one gadget.
1400       // This helps us avoid conflicts in .bind() tags.
1401 #if NDEBUG
1402 #define NEXT return
1403 #else
1404       [[maybe_unused]] int numFound = 0;
1405 #define NEXT ++numFound
1406 #endif
1407 
1408       if (const auto *DRE = Result.Nodes.getNodeAs<DeclRefExpr>("any_dre")) {
1409         Tracker.discoverUse(DRE);
1410         NEXT;
1411       }
1412 
1413       if (const auto *DS = Result.Nodes.getNodeAs<DeclStmt>("any_ds")) {
1414         Tracker.discoverDecl(DS);
1415         NEXT;
1416       }
1417 
1418       // Figure out which matcher we've found, and call the appropriate
1419       // subclass constructor.
1420       // FIXME: Can we do this more logarithmically?
1421 #define FIXABLE_GADGET(name)                                                   \
1422   if (Result.Nodes.getNodeAs<Stmt>(#name)) {                                   \
1423     FixableGadgets.push_back(std::make_unique<name##Gadget>(Result));          \
1424     NEXT;                                                                      \
1425   }
1426 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1427 #define WARNING_GADGET(name)                                                   \
1428   if (Result.Nodes.getNodeAs<Stmt>(#name)) {                                   \
1429     WarningGadgets.push_back(std::make_unique<name##Gadget>(Result));          \
1430     NEXT;                                                                      \
1431   }
1432 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1433 
1434       assert(numFound >= 1 && "Gadgets not found in match result!");
1435       assert(numFound <= 1 && "Conflicting bind tags in gadgets!");
1436     }
1437   };
1438 
1439   MatchFinder M;
1440   GadgetFinderCallback CB;
1441 
1442   // clang-format off
1443   M.addMatcher(
1444       stmt(
1445         forEachDescendantEvaluatedStmt(stmt(anyOf(
1446           // Add Gadget::matcher() for every gadget in the registry.
1447 #define WARNING_GADGET(x)                                                      \
1448           allOf(x ## Gadget::matcher().bind(#x),                               \
1449                 notInSafeBufferOptOut(&Handler)),
1450 #define WARNING_CONTAINER_GADGET(x)                                            \
1451           allOf(x ## Gadget::matcher().bind(#x),                               \
1452                 notInSafeBufferOptOut(&Handler),                               \
1453                 unless(ignoreUnsafeBufferInContainer(&Handler))),
1454 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1455             // Avoid a hanging comma.
1456             unless(stmt())
1457         )))
1458     ),
1459     &CB
1460   );
1461   // clang-format on
1462 
1463   if (EmitSuggestions) {
1464     // clang-format off
1465     M.addMatcher(
1466         stmt(
1467           forEachDescendantStmt(stmt(eachOf(
1468 #define FIXABLE_GADGET(x)                                                      \
1469             x ## Gadget::matcher().bind(#x),
1470 #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1471             // In parallel, match all DeclRefExprs so that to find out
1472             // whether there are any uncovered by gadgets.
1473             declRefExpr(anyOf(hasPointerType(), hasArrayType()),
1474                         to(anyOf(varDecl(), bindingDecl()))).bind("any_dre"),
1475             // Also match DeclStmts because we'll need them when fixing
1476             // their underlying VarDecls that otherwise don't have
1477             // any backreferences to DeclStmts.
1478             declStmt().bind("any_ds")
1479           )))
1480       ),
1481       &CB
1482     );
1483     // clang-format on
1484   }
1485 
1486   M.match(*D->getBody(), D->getASTContext());
1487   return {std::move(CB.FixableGadgets), std::move(CB.WarningGadgets),
1488           std::move(CB.Tracker)};
1489 }
1490 
// Orders AST nodes by source location: primarily by the raw encoding of their
// begin locations, with ties broken by the raw encoding of their end
// locations.
//
// The tie-breaker matters for declarations.  In Clang, the declarators of one
// multi-declarator statement report the same begin location (the start of the
// shared decl-specifiers).  For example, in `int *p, *q;` both VarDecls begin
// at `int`.  Comparing begin locations alone makes such nodes equivalent, so a
// map keyed with this comparator would silently merge distinct variables into
// one entry.  The order remains deterministic because it depends only on
// source locations.
template <typename NodeTy> struct CompareNode {
  bool operator()(const NodeTy *N1, const NodeTy *N2) const {
    const auto Begin1 = N1->getBeginLoc().getRawEncoding();
    const auto Begin2 = N2->getBeginLoc().getRawEncoding();
    if (Begin1 != Begin2)
      return Begin1 < Begin2;
    return N1->getEndLoc().getRawEncoding() < N2->getEndLoc().getRawEncoding();
  }
};
1498 
1499 struct WarningGadgetSets {
1500   std::map<const VarDecl *, std::set<const WarningGadget *>,
1501            // To keep keys sorted by their locations in the map so that the
1502            // order is deterministic:
1503            CompareNode<VarDecl>>
1504       byVar;
1505   // These Gadgets are not related to pointer variables (e. g. temporaries).
1506   llvm::SmallVector<const WarningGadget *, 16> noVar;
1507 };
1508 
1509 static WarningGadgetSets
groupWarningGadgetsByVar(const WarningGadgetList & AllUnsafeOperations)1510 groupWarningGadgetsByVar(const WarningGadgetList &AllUnsafeOperations) {
1511   WarningGadgetSets result;
1512   // If some gadgets cover more than one
1513   // variable, they'll appear more than once in the map.
1514   for (auto &G : AllUnsafeOperations) {
1515     DeclUseList ClaimedVarUseSites = G->getClaimedVarUseSites();
1516 
1517     bool AssociatedWithVarDecl = false;
1518     for (const DeclRefExpr *DRE : ClaimedVarUseSites) {
1519       if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1520         result.byVar[VD].insert(G.get());
1521         AssociatedWithVarDecl = true;
1522       }
1523     }
1524 
1525     if (!AssociatedWithVarDecl) {
1526       result.noVar.push_back(G.get());
1527       continue;
1528     }
1529   }
1530   return result;
1531 }
1532 
1533 struct FixableGadgetSets {
1534   std::map<const VarDecl *, std::set<const FixableGadget *>,
1535            // To keep keys sorted by their locations in the map so that the
1536            // order is deterministic:
1537            CompareNode<VarDecl>>
1538       byVar;
1539 };
1540 
1541 static FixableGadgetSets
groupFixablesByVar(FixableGadgetList && AllFixableOperations)1542 groupFixablesByVar(FixableGadgetList &&AllFixableOperations) {
1543   FixableGadgetSets FixablesForUnsafeVars;
1544   for (auto &F : AllFixableOperations) {
1545     DeclUseList DREs = F->getClaimedVarUseSites();
1546 
1547     for (const DeclRefExpr *DRE : DREs) {
1548       if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1549         FixablesForUnsafeVars.byVar[VD].insert(F.get());
1550       }
1551     }
1552   }
1553   return FixablesForUnsafeVars;
1554 }
1555 
anyConflict(const SmallVectorImpl<FixItHint> & FixIts,const SourceManager & SM)1556 bool clang::internal::anyConflict(const SmallVectorImpl<FixItHint> &FixIts,
1557                                   const SourceManager &SM) {
1558   // A simple interval overlap detection algorithm.  Sorts all ranges by their
1559   // begin location then finds the first overlap in one pass.
1560   std::vector<const FixItHint *> All; // a copy of `FixIts`
1561 
1562   for (const FixItHint &H : FixIts)
1563     All.push_back(&H);
1564   std::sort(All.begin(), All.end(),
1565             [&SM](const FixItHint *H1, const FixItHint *H2) {
1566               return SM.isBeforeInTranslationUnit(H1->RemoveRange.getBegin(),
1567                                                   H2->RemoveRange.getBegin());
1568             });
1569 
1570   const FixItHint *CurrHint = nullptr;
1571 
1572   for (const FixItHint *Hint : All) {
1573     if (!CurrHint ||
1574         SM.isBeforeInTranslationUnit(CurrHint->RemoveRange.getEnd(),
1575                                      Hint->RemoveRange.getBegin())) {
1576       // Either to initialize `CurrHint` or `CurrHint` does not
1577       // overlap with `Hint`:
1578       CurrHint = Hint;
1579     } else
1580       // In case `Hint` overlaps the `CurrHint`, we found at least one
1581       // conflict:
1582       return true;
1583   }
1584   return false;
1585 }
1586 
1587 std::optional<FixItList>
getFixits(const FixitStrategy & S) const1588 PtrToPtrAssignmentGadget::getFixits(const FixitStrategy &S) const {
1589   const auto *LeftVD = cast<VarDecl>(PtrLHS->getDecl());
1590   const auto *RightVD = cast<VarDecl>(PtrRHS->getDecl());
1591   switch (S.lookup(LeftVD)) {
1592   case FixitStrategy::Kind::Span:
1593     if (S.lookup(RightVD) == FixitStrategy::Kind::Span)
1594       return FixItList{};
1595     return std::nullopt;
1596   case FixitStrategy::Kind::Wontfix:
1597     return std::nullopt;
1598   case FixitStrategy::Kind::Iterator:
1599   case FixitStrategy::Kind::Array:
1600     return std::nullopt;
1601   case FixitStrategy::Kind::Vector:
1602     llvm_unreachable("unsupported strategies for FixableGadgets");
1603   }
1604   return std::nullopt;
1605 }
1606 
1607 /// \returns fixit that adds .data() call after \DRE.
1608 static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx,
1609                                                        const DeclRefExpr *DRE);
1610 
1611 std::optional<FixItList>
getFixits(const FixitStrategy & S) const1612 CArrayToPtrAssignmentGadget::getFixits(const FixitStrategy &S) const {
1613   const auto *LeftVD = cast<VarDecl>(PtrLHS->getDecl());
1614   const auto *RightVD = cast<VarDecl>(PtrRHS->getDecl());
1615   // TLDR: Implementing fixits for non-Wontfix strategy on both LHS and RHS is
1616   // non-trivial.
1617   //
1618   // CArrayToPtrAssignmentGadget doesn't have strategy implications because
1619   // constant size array propagates its bounds. Because of that LHS and RHS are
1620   // addressed by two different fixits.
1621   //
1622   // At the same time FixitStrategy S doesn't reflect what group a fixit belongs
1623   // to and can't be generally relied on in multi-variable Fixables!
1624   //
1625   // E. g. If an instance of this gadget is fixing variable on LHS then the
1626   // variable on RHS is fixed by a different fixit and its strategy for LHS
1627   // fixit is as if Wontfix.
1628   //
1629   // The only exception is Wontfix strategy for a given variable as that is
1630   // valid for any fixit produced for the given input source code.
1631   if (S.lookup(LeftVD) == FixitStrategy::Kind::Span) {
1632     if (S.lookup(RightVD) == FixitStrategy::Kind::Wontfix) {
1633       return FixItList{};
1634     }
1635   } else if (S.lookup(LeftVD) == FixitStrategy::Kind::Wontfix) {
1636     if (S.lookup(RightVD) == FixitStrategy::Kind::Array) {
1637       return createDataFixit(RightVD->getASTContext(), PtrRHS);
1638     }
1639   }
1640   return std::nullopt;
1641 }
1642 
1643 std::optional<FixItList>
getFixits(const FixitStrategy & S) const1644 PointerInitGadget::getFixits(const FixitStrategy &S) const {
1645   const auto *LeftVD = PtrInitLHS;
1646   const auto *RightVD = cast<VarDecl>(PtrInitRHS->getDecl());
1647   switch (S.lookup(LeftVD)) {
1648   case FixitStrategy::Kind::Span:
1649     if (S.lookup(RightVD) == FixitStrategy::Kind::Span)
1650       return FixItList{};
1651     return std::nullopt;
1652   case FixitStrategy::Kind::Wontfix:
1653     return std::nullopt;
1654   case FixitStrategy::Kind::Iterator:
1655   case FixitStrategy::Kind::Array:
1656     return std::nullopt;
1657   case FixitStrategy::Kind::Vector:
1658     llvm_unreachable("unsupported strategies for FixableGadgets");
1659   }
1660   return std::nullopt;
1661 }
1662 
isNonNegativeIntegerExpr(const Expr * Expr,const VarDecl * VD,const ASTContext & Ctx)1663 static bool isNonNegativeIntegerExpr(const Expr *Expr, const VarDecl *VD,
1664                                      const ASTContext &Ctx) {
1665   if (auto ConstVal = Expr->getIntegerConstantExpr(Ctx)) {
1666     if (ConstVal->isNegative())
1667       return false;
1668   } else if (!Expr->getType()->isUnsignedIntegerType())
1669     return false;
1670   return true;
1671 }
1672 
1673 std::optional<FixItList>
getFixits(const FixitStrategy & S) const1674 ULCArraySubscriptGadget::getFixits(const FixitStrategy &S) const {
1675   if (const auto *DRE =
1676           dyn_cast<DeclRefExpr>(Node->getBase()->IgnoreImpCasts()))
1677     if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1678       switch (S.lookup(VD)) {
1679       case FixitStrategy::Kind::Span: {
1680 
1681         // If the index has a negative constant value, we give up as no valid
1682         // fix-it can be generated:
1683         const ASTContext &Ctx = // FIXME: we need ASTContext to be passed in!
1684             VD->getASTContext();
1685         if (!isNonNegativeIntegerExpr(Node->getIdx(), VD, Ctx))
1686           return std::nullopt;
1687         // no-op is a good fix-it, otherwise
1688         return FixItList{};
1689       }
1690       case FixitStrategy::Kind::Array:
1691         return FixItList{};
1692       case FixitStrategy::Kind::Wontfix:
1693       case FixitStrategy::Kind::Iterator:
1694       case FixitStrategy::Kind::Vector:
1695         llvm_unreachable("unsupported strategies for FixableGadgets");
1696       }
1697     }
1698   return std::nullopt;
1699 }
1700 
1701 static std::optional<FixItList> // forward declaration
1702 fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator *Node);
1703 
1704 std::optional<FixItList>
getFixits(const FixitStrategy & S) const1705 UPCAddressofArraySubscriptGadget::getFixits(const FixitStrategy &S) const {
1706   auto DREs = getClaimedVarUseSites();
1707   const auto *VD = cast<VarDecl>(DREs.front()->getDecl());
1708 
1709   switch (S.lookup(VD)) {
1710   case FixitStrategy::Kind::Span:
1711     return fixUPCAddressofArraySubscriptWithSpan(Node);
1712   case FixitStrategy::Kind::Wontfix:
1713   case FixitStrategy::Kind::Iterator:
1714   case FixitStrategy::Kind::Array:
1715     return std::nullopt;
1716   case FixitStrategy::Kind::Vector:
1717     llvm_unreachable("unsupported strategies for FixableGadgets");
1718   }
1719   return std::nullopt; // something went wrong, no fix-it
1720 }
1721 
1722 // FIXME: this function should be customizable through format
getEndOfLine()1723 static StringRef getEndOfLine() {
1724   static const char *const EOL = "\n";
1725   return EOL;
1726 }
1727 
1728 // Returns the text indicating that the user needs to provide input there:
getUserFillPlaceHolder(StringRef HintTextToUser="placeholder")1729 std::string getUserFillPlaceHolder(StringRef HintTextToUser = "placeholder") {
1730   std::string s = std::string("<# ");
1731   s += HintTextToUser;
1732   s += " #>";
1733   return s;
1734 }
1735 
1736 // Return the source location of the last character of the AST `Node`.
1737 template <typename NodeTy>
1738 static std::optional<SourceLocation>
getEndCharLoc(const NodeTy * Node,const SourceManager & SM,const LangOptions & LangOpts)1739 getEndCharLoc(const NodeTy *Node, const SourceManager &SM,
1740               const LangOptions &LangOpts) {
1741   unsigned TkLen = Lexer::MeasureTokenLength(Node->getEndLoc(), SM, LangOpts);
1742   SourceLocation Loc = Node->getEndLoc().getLocWithOffset(TkLen - 1);
1743 
1744   if (Loc.isValid())
1745     return Loc;
1746 
1747   return std::nullopt;
1748 }
1749 
1750 // Return the source location just past the last character of the AST `Node`.
1751 template <typename NodeTy>
getPastLoc(const NodeTy * Node,const SourceManager & SM,const LangOptions & LangOpts)1752 static std::optional<SourceLocation> getPastLoc(const NodeTy *Node,
1753                                                 const SourceManager &SM,
1754                                                 const LangOptions &LangOpts) {
1755   SourceLocation Loc =
1756       Lexer::getLocForEndOfToken(Node->getEndLoc(), 0, SM, LangOpts);
1757   if (Loc.isValid())
1758     return Loc;
1759   return std::nullopt;
1760 }
1761 
1762 // Return text representation of an `Expr`.
getExprText(const Expr * E,const SourceManager & SM,const LangOptions & LangOpts)1763 static std::optional<StringRef> getExprText(const Expr *E,
1764                                             const SourceManager &SM,
1765                                             const LangOptions &LangOpts) {
1766   std::optional<SourceLocation> LastCharLoc = getPastLoc(E, SM, LangOpts);
1767 
1768   if (LastCharLoc)
1769     return Lexer::getSourceText(
1770         CharSourceRange::getCharRange(E->getBeginLoc(), *LastCharLoc), SM,
1771         LangOpts);
1772 
1773   return std::nullopt;
1774 }
1775 
1776 // Returns the literal text in `SourceRange SR`, if `SR` is a valid range.
getRangeText(SourceRange SR,const SourceManager & SM,const LangOptions & LangOpts)1777 static std::optional<StringRef> getRangeText(SourceRange SR,
1778                                              const SourceManager &SM,
1779                                              const LangOptions &LangOpts) {
1780   bool Invalid = false;
1781   CharSourceRange CSR = CharSourceRange::getCharRange(SR);
1782   StringRef Text = Lexer::getSourceText(CSR, SM, LangOpts, &Invalid);
1783 
1784   if (!Invalid)
1785     return Text;
1786   return std::nullopt;
1787 }
1788 
1789 // Returns the begin location of the identifier of the given variable
1790 // declaration.
getVarDeclIdentifierLoc(const VarDecl * VD)1791 static SourceLocation getVarDeclIdentifierLoc(const VarDecl *VD) {
1792   // According to the implementation of `VarDecl`, `VD->getLocation()` actually
1793   // returns the begin location of the identifier of the declaration:
1794   return VD->getLocation();
1795 }
1796 
1797 // Returns the literal text of the identifier of the given variable declaration.
1798 static std::optional<StringRef>
getVarDeclIdentifierText(const VarDecl * VD,const SourceManager & SM,const LangOptions & LangOpts)1799 getVarDeclIdentifierText(const VarDecl *VD, const SourceManager &SM,
1800                          const LangOptions &LangOpts) {
1801   SourceLocation ParmIdentBeginLoc = getVarDeclIdentifierLoc(VD);
1802   SourceLocation ParmIdentEndLoc =
1803       Lexer::getLocForEndOfToken(ParmIdentBeginLoc, 0, SM, LangOpts);
1804 
1805   if (ParmIdentEndLoc.isMacroID() &&
1806       !Lexer::isAtEndOfMacroExpansion(ParmIdentEndLoc, SM, LangOpts))
1807     return std::nullopt;
1808   return getRangeText({ParmIdentBeginLoc, ParmIdentEndLoc}, SM, LangOpts);
1809 }
1810 
1811 // We cannot fix a variable declaration if it has some other specifiers than the
1812 // type specifier.  Because the source ranges of those specifiers could overlap
1813 // with the source range that is being replaced using fix-its.  Especially when
1814 // we often cannot obtain accurate source ranges of cv-qualified type
1815 // specifiers.
1816 // FIXME: also deal with type attributes
hasUnsupportedSpecifiers(const VarDecl * VD,const SourceManager & SM)1817 static bool hasUnsupportedSpecifiers(const VarDecl *VD,
1818                                      const SourceManager &SM) {
1819   // AttrRangeOverlapping: true if at least one attribute of `VD` overlaps the
1820   // source range of `VD`:
1821   bool AttrRangeOverlapping = llvm::any_of(VD->attrs(), [&](Attr *At) -> bool {
1822     return !(SM.isBeforeInTranslationUnit(At->getRange().getEnd(),
1823                                           VD->getBeginLoc())) &&
1824            !(SM.isBeforeInTranslationUnit(VD->getEndLoc(),
1825                                           At->getRange().getBegin()));
1826   });
1827   return VD->isInlineSpecified() || VD->isConstexpr() ||
1828          VD->hasConstantInitialization() || !VD->hasLocalStorage() ||
1829          AttrRangeOverlapping;
1830 }
1831 
1832 // Returns the `SourceRange` of `D`.  The reason why this function exists is
1833 // that `D->getSourceRange()` may return a range where the end location is the
1834 // starting location of the last token.  The end location of the source range
1835 // returned by this function is the last location of the last token.
getSourceRangeToTokenEnd(const Decl * D,const SourceManager & SM,const LangOptions & LangOpts)1836 static SourceRange getSourceRangeToTokenEnd(const Decl *D,
1837                                             const SourceManager &SM,
1838                                             const LangOptions &LangOpts) {
1839   SourceLocation Begin = D->getBeginLoc();
1840   SourceLocation
1841       End = // `D->getEndLoc` should always return the starting location of the
1842       // last token, so we should get the end of the token
1843       Lexer::getLocForEndOfToken(D->getEndLoc(), 0, SM, LangOpts);
1844 
1845   return SourceRange(Begin, End);
1846 }
1847 
1848 // Returns the text of the pointee type of `T` from a `VarDecl` of a pointer
1849 // type. The text is obtained through from `TypeLoc`s.  Since `TypeLoc` does not
1850 // have source ranges of qualifiers ( The `QualifiedTypeLoc` looks hacky too me
1851 // :( ), `Qualifiers` of the pointee type is returned separately through the
1852 // output parameter `QualifiersToAppend`.
1853 static std::optional<std::string>
getPointeeTypeText(const VarDecl * VD,const SourceManager & SM,const LangOptions & LangOpts,std::optional<Qualifiers> * QualifiersToAppend)1854 getPointeeTypeText(const VarDecl *VD, const SourceManager &SM,
1855                    const LangOptions &LangOpts,
1856                    std::optional<Qualifiers> *QualifiersToAppend) {
1857   QualType Ty = VD->getType();
1858   QualType PteTy;
1859 
1860   assert(Ty->isPointerType() && !Ty->isFunctionPointerType() &&
1861          "Expecting a VarDecl of type of pointer to object type");
1862   PteTy = Ty->getPointeeType();
1863 
1864   TypeLoc TyLoc = VD->getTypeSourceInfo()->getTypeLoc().getUnqualifiedLoc();
1865   TypeLoc PteTyLoc;
1866 
1867   // We only deal with the cases that we know `TypeLoc::getNextTypeLoc` returns
1868   // the `TypeLoc` of the pointee type:
1869   switch (TyLoc.getTypeLocClass()) {
1870   case TypeLoc::ConstantArray:
1871   case TypeLoc::IncompleteArray:
1872   case TypeLoc::VariableArray:
1873   case TypeLoc::DependentSizedArray:
1874   case TypeLoc::Decayed:
1875     assert(isa<ParmVarDecl>(VD) && "An array type shall not be treated as a "
1876                                    "pointer type unless it decays.");
1877     PteTyLoc = TyLoc.getNextTypeLoc();
1878     break;
1879   case TypeLoc::Pointer:
1880     PteTyLoc = TyLoc.castAs<PointerTypeLoc>().getPointeeLoc();
1881     break;
1882   default:
1883     return std::nullopt;
1884   }
1885   if (PteTyLoc.isNull())
1886     // Sometimes we cannot get a useful `TypeLoc` for the pointee type, e.g.,
1887     // when the pointer type is `auto`.
1888     return std::nullopt;
1889 
1890   SourceLocation IdentLoc = getVarDeclIdentifierLoc(VD);
1891 
1892   if (!(IdentLoc.isValid() && PteTyLoc.getSourceRange().isValid())) {
1893     // We are expecting these locations to be valid. But in some cases, they are
1894     // not all valid. It is a Clang bug to me and we are not responsible for
1895     // fixing it.  So we will just give up for now when it happens.
1896     return std::nullopt;
1897   }
1898 
1899   // Note that TypeLoc.getEndLoc() returns the begin location of the last token:
1900   SourceLocation PteEndOfTokenLoc =
1901       Lexer::getLocForEndOfToken(PteTyLoc.getEndLoc(), 0, SM, LangOpts);
1902 
1903   if (!PteEndOfTokenLoc.isValid())
1904     // Sometimes we cannot get the end location of the pointee type, e.g., when
1905     // there are macros involved.
1906     return std::nullopt;
1907   if (!SM.isBeforeInTranslationUnit(PteEndOfTokenLoc, IdentLoc)) {
1908     // We only deal with the cases where the source text of the pointee type
1909     // appears on the left-hand side of the variable identifier completely,
1910     // including the following forms:
1911     // `T ident`,
1912     // `T ident[]`, where `T` is any type.
1913     // Examples of excluded cases are `T (*ident)[]` or `T ident[][n]`.
1914     return std::nullopt;
1915   }
1916   if (PteTy.hasQualifiers()) {
1917     // TypeLoc does not provide source ranges for qualifiers (it says it's
1918     // intentional but seems fishy to me), so we cannot get the full text
1919     // `PteTy` via source ranges.
1920     *QualifiersToAppend = PteTy.getQualifiers();
1921   }
1922   return getRangeText({PteTyLoc.getBeginLoc(), PteEndOfTokenLoc}, SM, LangOpts)
1923       ->str();
1924 }
1925 
1926 // Returns the text of the name (with qualifiers) of a `FunctionDecl`.
getFunNameText(const FunctionDecl * FD,const SourceManager & SM,const LangOptions & LangOpts)1927 static std::optional<StringRef> getFunNameText(const FunctionDecl *FD,
1928                                                const SourceManager &SM,
1929                                                const LangOptions &LangOpts) {
1930   SourceLocation BeginLoc = FD->getQualifier()
1931                                 ? FD->getQualifierLoc().getBeginLoc()
1932                                 : FD->getNameInfo().getBeginLoc();
1933   // Note that `FD->getNameInfo().getEndLoc()` returns the begin location of the
1934   // last token:
1935   SourceLocation EndLoc = Lexer::getLocForEndOfToken(
1936       FD->getNameInfo().getEndLoc(), 0, SM, LangOpts);
1937   SourceRange NameRange{BeginLoc, EndLoc};
1938 
1939   return getRangeText(NameRange, SM, LangOpts);
1940 }
1941 
1942 // Returns the text representing a `std::span` type where the element type is
1943 // represented by `EltTyText`.
1944 //
1945 // Note the optional parameter `Qualifiers`: one needs to pass qualifiers
1946 // explicitly if the element type needs to be qualified.
1947 static std::string
getSpanTypeText(StringRef EltTyText,std::optional<Qualifiers> Quals=std::nullopt)1948 getSpanTypeText(StringRef EltTyText,
1949                 std::optional<Qualifiers> Quals = std::nullopt) {
1950   const char *const SpanOpen = "std::span<";
1951 
1952   if (Quals)
1953     return SpanOpen + EltTyText.str() + ' ' + Quals->getAsString() + '>';
1954   return SpanOpen + EltTyText.str() + '>';
1955 }
1956 
1957 std::optional<FixItList>
getFixits(const FixitStrategy & s) const1958 DerefSimplePtrArithFixableGadget::getFixits(const FixitStrategy &s) const {
1959   const VarDecl *VD = dyn_cast<VarDecl>(BaseDeclRefExpr->getDecl());
1960 
1961   if (VD && s.lookup(VD) == FixitStrategy::Kind::Span) {
1962     ASTContext &Ctx = VD->getASTContext();
1963     // std::span can't represent elements before its begin()
1964     if (auto ConstVal = Offset->getIntegerConstantExpr(Ctx))
1965       if (ConstVal->isNegative())
1966         return std::nullopt;
1967 
1968     // note that the expr may (oddly) has multiple layers of parens
1969     // example:
1970     //   *((..(pointer + 123)..))
1971     // goal:
1972     //   pointer[123]
1973     // Fix-It:
1974     //   remove '*('
1975     //   replace ' + ' with '['
1976     //   replace ')' with ']'
1977 
1978     // example:
1979     //   *((..(123 + pointer)..))
1980     // goal:
1981     //   123[pointer]
1982     // Fix-It:
1983     //   remove '*('
1984     //   replace ' + ' with '['
1985     //   replace ')' with ']'
1986 
1987     const Expr *LHS = AddOp->getLHS(), *RHS = AddOp->getRHS();
1988     const SourceManager &SM = Ctx.getSourceManager();
1989     const LangOptions &LangOpts = Ctx.getLangOpts();
1990     CharSourceRange StarWithTrailWhitespace =
1991         clang::CharSourceRange::getCharRange(DerefOp->getOperatorLoc(),
1992                                              LHS->getBeginLoc());
1993 
1994     std::optional<SourceLocation> LHSLocation = getPastLoc(LHS, SM, LangOpts);
1995     if (!LHSLocation)
1996       return std::nullopt;
1997 
1998     CharSourceRange PlusWithSurroundingWhitespace =
1999         clang::CharSourceRange::getCharRange(*LHSLocation, RHS->getBeginLoc());
2000 
2001     std::optional<SourceLocation> AddOpLocation =
2002         getPastLoc(AddOp, SM, LangOpts);
2003     std::optional<SourceLocation> DerefOpLocation =
2004         getPastLoc(DerefOp, SM, LangOpts);
2005 
2006     if (!AddOpLocation || !DerefOpLocation)
2007       return std::nullopt;
2008 
2009     CharSourceRange ClosingParenWithPrecWhitespace =
2010         clang::CharSourceRange::getCharRange(*AddOpLocation, *DerefOpLocation);
2011 
2012     return FixItList{
2013         {FixItHint::CreateRemoval(StarWithTrailWhitespace),
2014          FixItHint::CreateReplacement(PlusWithSurroundingWhitespace, "["),
2015          FixItHint::CreateReplacement(ClosingParenWithPrecWhitespace, "]")}};
2016   }
2017   return std::nullopt; // something wrong or unsupported, give up
2018 }
2019 
2020 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2021 PointerDereferenceGadget::getFixits(const FixitStrategy &S) const {
2022   const VarDecl *VD = cast<VarDecl>(BaseDeclRefExpr->getDecl());
2023   switch (S.lookup(VD)) {
2024   case FixitStrategy::Kind::Span: {
2025     ASTContext &Ctx = VD->getASTContext();
2026     SourceManager &SM = Ctx.getSourceManager();
2027     // Required changes: *(ptr); => (ptr[0]); and *ptr; => ptr[0]
2028     // Deletes the *operand
2029     CharSourceRange derefRange = clang::CharSourceRange::getCharRange(
2030         Op->getBeginLoc(), Op->getBeginLoc().getLocWithOffset(1));
2031     // Inserts the [0]
2032     if (auto LocPastOperand =
2033             getPastLoc(BaseDeclRefExpr, SM, Ctx.getLangOpts())) {
2034       return FixItList{{FixItHint::CreateRemoval(derefRange),
2035                         FixItHint::CreateInsertion(*LocPastOperand, "[0]")}};
2036     }
2037     break;
2038   }
2039   case FixitStrategy::Kind::Iterator:
2040   case FixitStrategy::Kind::Array:
2041     return std::nullopt;
2042   case FixitStrategy::Kind::Vector:
2043     llvm_unreachable("FixitStrategy not implemented yet!");
2044   case FixitStrategy::Kind::Wontfix:
2045     llvm_unreachable("Invalid strategy!");
2046   }
2047 
2048   return std::nullopt;
2049 }
2050 
createDataFixit(const ASTContext & Ctx,const DeclRefExpr * DRE)2051 static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx,
2052                                                        const DeclRefExpr *DRE) {
2053   const SourceManager &SM = Ctx.getSourceManager();
2054   // Inserts the .data() after the DRE
2055   std::optional<SourceLocation> EndOfOperand =
2056       getPastLoc(DRE, SM, Ctx.getLangOpts());
2057 
2058   if (EndOfOperand)
2059     return FixItList{{FixItHint::CreateInsertion(*EndOfOperand, ".data()")}};
2060 
2061   return std::nullopt;
2062 }
2063 
2064 // Generates fix-its replacing an expression of the form UPC(DRE) with
2065 // `DRE.data()`
2066 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2067 UPCStandalonePointerGadget::getFixits(const FixitStrategy &S) const {
2068   const auto VD = cast<VarDecl>(Node->getDecl());
2069   switch (S.lookup(VD)) {
2070   case FixitStrategy::Kind::Array:
2071   case FixitStrategy::Kind::Span: {
2072     return createDataFixit(VD->getASTContext(), Node);
2073     // FIXME: Points inside a macro expansion.
2074     break;
2075   }
2076   case FixitStrategy::Kind::Wontfix:
2077   case FixitStrategy::Kind::Iterator:
2078     return std::nullopt;
2079   case FixitStrategy::Kind::Vector:
2080     llvm_unreachable("unsupported strategies for FixableGadgets");
2081   }
2082 
2083   return std::nullopt;
2084 }
2085 
2086 // Generates fix-its replacing an expression of the form `&DRE[e]` with
2087 // `&DRE.data()[e]`:
2088 static std::optional<FixItList>
fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator * Node)2089 fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator *Node) {
2090   const auto *ArraySub = cast<ArraySubscriptExpr>(Node->getSubExpr());
2091   const auto *DRE = cast<DeclRefExpr>(ArraySub->getBase()->IgnoreImpCasts());
2092   // FIXME: this `getASTContext` call is costly, we should pass the
2093   // ASTContext in:
2094   const ASTContext &Ctx = DRE->getDecl()->getASTContext();
2095   const Expr *Idx = ArraySub->getIdx();
2096   const SourceManager &SM = Ctx.getSourceManager();
2097   const LangOptions &LangOpts = Ctx.getLangOpts();
2098   std::stringstream SS;
2099   bool IdxIsLitZero = false;
2100 
2101   if (auto ICE = Idx->getIntegerConstantExpr(Ctx))
2102     if ((*ICE).isZero())
2103       IdxIsLitZero = true;
2104   std::optional<StringRef> DreString = getExprText(DRE, SM, LangOpts);
2105   if (!DreString)
2106     return std::nullopt;
2107 
2108   if (IdxIsLitZero) {
2109     // If the index is literal zero, we produce the most concise fix-it:
2110     SS << (*DreString).str() << ".data()";
2111   } else {
2112     std::optional<StringRef> IndexString = getExprText(Idx, SM, LangOpts);
2113     if (!IndexString)
2114       return std::nullopt;
2115 
2116     SS << "&" << (*DreString).str() << ".data()"
2117        << "[" << (*IndexString).str() << "]";
2118   }
2119   return FixItList{
2120       FixItHint::CreateReplacement(Node->getSourceRange(), SS.str())};
2121 }
2122 
2123 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2124 UUCAddAssignGadget::getFixits(const FixitStrategy &S) const {
2125   DeclUseList DREs = getClaimedVarUseSites();
2126 
2127   if (DREs.size() != 1)
2128     return std::nullopt; // In cases of `Ptr += n` where `Ptr` is not a DRE, we
2129                          // give up
2130   if (const VarDecl *VD = dyn_cast<VarDecl>(DREs.front()->getDecl())) {
2131     if (S.lookup(VD) == FixitStrategy::Kind::Span) {
2132       FixItList Fixes;
2133 
2134       const Stmt *AddAssignNode = Node;
2135       StringRef varName = VD->getName();
2136       const ASTContext &Ctx = VD->getASTContext();
2137 
2138       if (!isNonNegativeIntegerExpr(Offset, VD, Ctx))
2139         return std::nullopt;
2140 
2141       // To transform UUC(p += n) to UUC(p = p.subspan(..)):
2142       bool NotParenExpr =
2143           (Offset->IgnoreParens()->getBeginLoc() == Offset->getBeginLoc());
2144       std::string SS = varName.str() + " = " + varName.str() + ".subspan";
2145       if (NotParenExpr)
2146         SS += "(";
2147 
2148       std::optional<SourceLocation> AddAssignLocation = getEndCharLoc(
2149           AddAssignNode, Ctx.getSourceManager(), Ctx.getLangOpts());
2150       if (!AddAssignLocation)
2151         return std::nullopt;
2152 
2153       Fixes.push_back(FixItHint::CreateReplacement(
2154           SourceRange(AddAssignNode->getBeginLoc(), Node->getOperatorLoc()),
2155           SS));
2156       if (NotParenExpr)
2157         Fixes.push_back(FixItHint::CreateInsertion(
2158             Offset->getEndLoc().getLocWithOffset(1), ")"));
2159       return Fixes;
2160     }
2161   }
2162   return std::nullopt; // Not in the cases that we can handle for now, give up.
2163 }
2164 
2165 std::optional<FixItList>
getFixits(const FixitStrategy & S) const2166 UPCPreIncrementGadget::getFixits(const FixitStrategy &S) const {
2167   DeclUseList DREs = getClaimedVarUseSites();
2168 
2169   if (DREs.size() != 1)
2170     return std::nullopt; // In cases of `++Ptr` where `Ptr` is not a DRE, we
2171                          // give up
2172   if (const VarDecl *VD = dyn_cast<VarDecl>(DREs.front()->getDecl())) {
2173     if (S.lookup(VD) == FixitStrategy::Kind::Span) {
2174       FixItList Fixes;
2175       std::stringstream SS;
2176       StringRef varName = VD->getName();
2177       const ASTContext &Ctx = VD->getASTContext();
2178 
2179       // To transform UPC(++p) to UPC((p = p.subspan(1)).data()):
2180       SS << "(" << varName.data() << " = " << varName.data()
2181          << ".subspan(1)).data()";
2182       std::optional<SourceLocation> PreIncLocation =
2183           getEndCharLoc(Node, Ctx.getSourceManager(), Ctx.getLangOpts());
2184       if (!PreIncLocation)
2185         return std::nullopt;
2186 
2187       Fixes.push_back(FixItHint::CreateReplacement(
2188           SourceRange(Node->getBeginLoc(), *PreIncLocation), SS.str()));
2189       return Fixes;
2190     }
2191   }
2192   return std::nullopt; // Not in the cases that we can handle for now, give up.
2193 }
2194 
2195 // For a non-null initializer `Init` of `T *` type, this function returns
2196 // `FixItHint`s producing a list initializer `{Init,  S}` as a part of a fix-it
2197 // to output stream.
2198 // In many cases, this function cannot figure out the actual extent `S`.  It
2199 // then will use a place holder to replace `S` to ask users to fill `S` in.  The
2200 // initializer shall be used to initialize a variable of type `std::span<T>`.
2201 // In some cases (e. g. constant size array) the initializer should remain
2202 // unchanged and the function returns empty list. In case the function can't
2203 // provide the right fixit it will return nullopt.
2204 //
2205 // FIXME: Support multi-level pointers
2206 //
2207 // Parameters:
2208 //   `Init` a pointer to the initializer expression
2209 //   `Ctx` a reference to the ASTContext
2210 static std::optional<FixItList>
FixVarInitializerWithSpan(const Expr * Init,ASTContext & Ctx,const StringRef UserFillPlaceHolder)2211 FixVarInitializerWithSpan(const Expr *Init, ASTContext &Ctx,
2212                           const StringRef UserFillPlaceHolder) {
2213   const SourceManager &SM = Ctx.getSourceManager();
2214   const LangOptions &LangOpts = Ctx.getLangOpts();
2215 
2216   // If `Init` has a constant value that is (or equivalent to) a
2217   // NULL pointer, we use the default constructor to initialize the span
2218   // object, i.e., a `std:span` variable declaration with no initializer.
2219   // So the fix-it is just to remove the initializer.
2220   if (Init->isNullPointerConstant(
2221           Ctx,
2222           // FIXME: Why does this function not ask for `const ASTContext
2223           // &`? It should. Maybe worth an NFC patch later.
2224           Expr::NullPointerConstantValueDependence::
2225               NPC_ValueDependentIsNotNull)) {
2226     std::optional<SourceLocation> InitLocation =
2227         getEndCharLoc(Init, SM, LangOpts);
2228     if (!InitLocation)
2229       return std::nullopt;
2230 
2231     SourceRange SR(Init->getBeginLoc(), *InitLocation);
2232 
2233     return FixItList{FixItHint::CreateRemoval(SR)};
2234   }
2235 
2236   FixItList FixIts{};
2237   std::string ExtentText = UserFillPlaceHolder.data();
2238   StringRef One = "1";
2239 
2240   // Insert `{` before `Init`:
2241   FixIts.push_back(FixItHint::CreateInsertion(Init->getBeginLoc(), "{"));
2242   // Try to get the data extent. Break into different cases:
2243   if (auto CxxNew = dyn_cast<CXXNewExpr>(Init->IgnoreImpCasts())) {
2244     // In cases `Init` is `new T[n]` and there is no explicit cast over
2245     // `Init`, we know that `Init` must evaluates to a pointer to `n` objects
2246     // of `T`. So the extent is `n` unless `n` has side effects.  Similar but
2247     // simpler for the case where `Init` is `new T`.
2248     if (const Expr *Ext = CxxNew->getArraySize().value_or(nullptr)) {
2249       if (!Ext->HasSideEffects(Ctx)) {
2250         std::optional<StringRef> ExtentString = getExprText(Ext, SM, LangOpts);
2251         if (!ExtentString)
2252           return std::nullopt;
2253         ExtentText = *ExtentString;
2254       }
2255     } else if (!CxxNew->isArray())
2256       // Although the initializer is not allocating a buffer, the pointer
2257       // variable could still be used in buffer access operations.
2258       ExtentText = One;
2259   } else if (Ctx.getAsConstantArrayType(Init->IgnoreImpCasts()->getType())) {
2260     // std::span has a single parameter constructor for initialization with
2261     // constant size array. The size is auto-deduced as the constructor is a
2262     // function template. The correct fixit is empty - no changes should happen.
2263     return FixItList{};
2264   } else {
2265     // In cases `Init` is of the form `&Var` after stripping of implicit
2266     // casts, where `&` is the built-in operator, the extent is 1.
2267     if (auto AddrOfExpr = dyn_cast<UnaryOperator>(Init->IgnoreImpCasts()))
2268       if (AddrOfExpr->getOpcode() == UnaryOperatorKind::UO_AddrOf &&
2269           isa_and_present<DeclRefExpr>(AddrOfExpr->getSubExpr()))
2270         ExtentText = One;
2271     // TODO: we can handle more cases, e.g., `&a[0]`, `&a`, `std::addressof`,
2272     // and explicit casting, etc. etc.
2273   }
2274 
2275   SmallString<32> StrBuffer{};
2276   std::optional<SourceLocation> LocPassInit = getPastLoc(Init, SM, LangOpts);
2277 
2278   if (!LocPassInit)
2279     return std::nullopt;
2280 
2281   StrBuffer.append(", ");
2282   StrBuffer.append(ExtentText);
2283   StrBuffer.append("}");
2284   FixIts.push_back(FixItHint::CreateInsertion(*LocPassInit, StrBuffer.str()));
2285   return FixIts;
2286 }
2287 
// Records a debug note on declaration `D` explaining (via `Msg`) why no fix-it
// was produced for it. The expansion site must have an
// `UnsafeBufferUsageHandler` named `Handler` in scope. In NDEBUG builds the
// macro expands to nothing.
#ifdef NDEBUG
#define DEBUG_NOTE_DECL_FAIL(D, Msg)
#else
#define DEBUG_NOTE_DECL_FAIL(D, Msg)                                           \
  Handler.addDebugNoteForVar(                                                  \
      (D), (D)->getBeginLoc(),                                                 \
      "failed to produce fixit for declaration '" + (D)->getNameAsString() +   \
          "'" + (Msg))
#endif
2296 
2297 // For the given variable declaration with a pointer-to-T type, returns the text
2298 // `std::span<T>`.  If it is unable to generate the text, returns
2299 // `std::nullopt`.
2300 static std::optional<std::string>
createSpanTypeForVarDecl(const VarDecl * VD,const ASTContext & Ctx)2301 createSpanTypeForVarDecl(const VarDecl *VD, const ASTContext &Ctx) {
2302   assert(VD->getType()->isPointerType());
2303 
2304   std::optional<Qualifiers> PteTyQualifiers = std::nullopt;
2305   std::optional<std::string> PteTyText = getPointeeTypeText(
2306       VD, Ctx.getSourceManager(), Ctx.getLangOpts(), &PteTyQualifiers);
2307 
2308   if (!PteTyText)
2309     return std::nullopt;
2310 
2311   std::string SpanTyText = "std::span<";
2312 
2313   SpanTyText.append(*PteTyText);
2314   // Append qualifiers to span element type if any:
2315   if (PteTyQualifiers) {
2316     SpanTyText.append(" ");
2317     SpanTyText.append(PteTyQualifiers->getAsString());
2318   }
2319   SpanTyText.append(">");
2320   return SpanTyText;
2321 }
2322 
2323 // For a `VarDecl` of the form `T  * var (= Init)?`, this
2324 // function generates fix-its that
2325 //  1) replace `T * var` with `std::span<T> var`; and
2326 //  2) change `Init` accordingly to a span constructor, if it exists.
2327 //
2328 // FIXME: support Multi-level pointers
2329 //
2330 // Parameters:
2331 //   `D` a pointer the variable declaration node
2332 //   `Ctx` a reference to the ASTContext
2333 //   `UserFillPlaceHolder` the user-input placeholder text
2334 // Returns:
2335 //    the non-empty fix-it list, if fix-its are successfuly generated; empty
2336 //    list otherwise.
fixLocalVarDeclWithSpan(const VarDecl * D,ASTContext & Ctx,const StringRef UserFillPlaceHolder,UnsafeBufferUsageHandler & Handler)2337 static FixItList fixLocalVarDeclWithSpan(const VarDecl *D, ASTContext &Ctx,
2338                                          const StringRef UserFillPlaceHolder,
2339                                          UnsafeBufferUsageHandler &Handler) {
2340   if (hasUnsupportedSpecifiers(D, Ctx.getSourceManager()))
2341     return {};
2342 
2343   FixItList FixIts{};
2344   std::optional<std::string> SpanTyText = createSpanTypeForVarDecl(D, Ctx);
2345 
2346   if (!SpanTyText) {
2347     DEBUG_NOTE_DECL_FAIL(D, " : failed to generate 'std::span' type");
2348     return {};
2349   }
2350 
2351   // Will hold the text for `std::span<T> Ident`:
2352   std::stringstream SS;
2353 
2354   SS << *SpanTyText;
2355   // Fix the initializer if it exists:
2356   if (const Expr *Init = D->getInit()) {
2357     std::optional<FixItList> InitFixIts =
2358         FixVarInitializerWithSpan(Init, Ctx, UserFillPlaceHolder);
2359     if (!InitFixIts)
2360       return {};
2361     FixIts.insert(FixIts.end(), std::make_move_iterator(InitFixIts->begin()),
2362                   std::make_move_iterator(InitFixIts->end()));
2363   }
2364   // For declaration of the form `T * ident = init;`, we want to replace
2365   // `T * ` with `std::span<T>`.
2366   // We ignore CV-qualifiers so for `T * const ident;` we also want to replace
2367   // just `T *` with `std::span<T>`.
2368   const SourceLocation EndLocForReplacement = D->getTypeSpecEndLoc();
2369   if (!EndLocForReplacement.isValid()) {
2370     DEBUG_NOTE_DECL_FAIL(D, " : failed to locate the end of the declaration");
2371     return {};
2372   }
2373   // The only exception is that for `T *ident` we'll add a single space between
2374   // "std::span<T>" and "ident".
2375   // FIXME: The condition is false for identifiers expended from macros.
2376   if (EndLocForReplacement.getLocWithOffset(1) == getVarDeclIdentifierLoc(D))
2377     SS << " ";
2378 
2379   FixIts.push_back(FixItHint::CreateReplacement(
2380       SourceRange(D->getBeginLoc(), EndLocForReplacement), SS.str()));
2381   return FixIts;
2382 }
2383 
hasConflictingOverload(const FunctionDecl * FD)2384 static bool hasConflictingOverload(const FunctionDecl *FD) {
2385   return !FD->getDeclContext()->lookup(FD->getDeclName()).isSingleResult();
2386 }
2387 
// For a `FunctionDecl`, whose `ParmVarDecl`s are being changed to have new
// types, this function produces fix-its to make the change self-contained.  Let
// 'F' be the entity defined by the original `FunctionDecl` and "NewF" be the
// entity defined by the `FunctionDecl` after the change to the parameters.
// Fix-its produced by this function are
//   1. Add the `[[clang::unsafe_buffer_usage]]` attribute to each declaration
//   of 'F';
//   2. Create a declaration of "NewF" next to each declaration of `F`;
//   3. Create a definition of "F" (as its original definition now belongs
//      to "NewF") next to its original definition.  The body of the created
//      definition calls "NewF".
//
// Example:
//
// void f(int *p);  // original declaration
// void f(int *p) { // original definition
//    p[5];
// }
//
// To change the parameter `p` to be of `std::span<int>` type, we
// also add overloads:
//
// [[clang::unsafe_buffer_usage]] void f(int *p); // original decl
// void f(std::span<int> p);                      // added overload decl
// void f(std::span<int> p) {     // original def where param is changed
//    p[5];
// }
// [[clang::unsafe_buffer_usage]] void f(int *p) {  // added def
//   return f(std::span(p, <# size #>));
// }
//
// Returns `std::nullopt` when the function has a conflicting overload, when a
// parameter has a strategy other than `Span`/`Wontfix`, when some source text
// cannot be obtained, or when no parameter needs fixing.  NOTE: in this
// function `return {};` returns an empty `std::optional` (i.e. `std::nullopt`),
// not an empty `FixItList`.
static std::optional<FixItList>
createOverloadsForFixedParams(const FixitStrategy &S, const FunctionDecl *FD,
                              const ASTContext &Ctx,
                              UnsafeBufferUsageHandler &Handler) {
  // FIXME: need to make this conflict checking better:
  if (hasConflictingOverload(FD))
    return std::nullopt;

  const SourceManager &SM = Ctx.getSourceManager();
  const LangOptions &LangOpts = Ctx.getLangOpts();
  const unsigned NumParms = FD->getNumParams();
  // `NewTysTexts[i]` is the new (`std::span`) type text of the i-th parameter;
  // `ParmsMask[i]` tells whether the i-th parameter is being fixed at all.
  std::vector<std::string> NewTysTexts(NumParms);
  std::vector<bool> ParmsMask(NumParms, false);
  bool AtLeastOneParmToFix = false;

  for (unsigned i = 0; i < NumParms; i++) {
    const ParmVarDecl *PVD = FD->getParamDecl(i);

    if (S.lookup(PVD) == FixitStrategy::Kind::Wontfix)
      continue;
    if (S.lookup(PVD) != FixitStrategy::Kind::Span)
      // Not supported, not supposed to happen:
      return std::nullopt;

    std::optional<Qualifiers> PteTyQuals = std::nullopt;
    std::optional<std::string> PteTyText =
        getPointeeTypeText(PVD, SM, LangOpts, &PteTyQuals);

    if (!PteTyText)
      // something wrong in obtaining the text of the pointee type, give up
      return std::nullopt;
    // FIXME: whether we should create std::span type depends on the
    // FixitStrategy.
    NewTysTexts[i] = getSpanTypeText(*PteTyText, PteTyQuals);
    ParmsMask[i] = true;
    AtLeastOneParmToFix = true;
  }
  if (!AtLeastOneParmToFix)
    // No need to create function overloads (this returns `std::nullopt`):
    return {};
  // FIXME Respect indentation of the original code.

  // A lambda that creates the text representation of a function declaration
  // with the new type signatures.  The text is inserted just past the end of
  // an existing declaration (see the loop at the end of this function),
  // presumably right before that declaration's terminating `;`.  Hence it
  // opens with `;` + newline and ends with `)`, leaving the original `;` to
  // terminate the new declaration.
  const auto NewOverloadSignatureCreator =
      [&SM, &LangOpts, &NewTysTexts,
       &ParmsMask](const FunctionDecl *FD) -> std::optional<std::string> {
    std::stringstream SS;

    SS << ";";
    SS << getEndOfLine().str();
    // Append: ret-type func-name "("
    // `param_begin()` can be dereferenced: at least one parameter is being
    // fixed, and all redeclarations of a function have the same parameters.
    if (auto Prefix = getRangeText(
            SourceRange(FD->getBeginLoc(), (*FD->param_begin())->getBeginLoc()),
            SM, LangOpts))
      SS << Prefix->str();
    else
      return std::nullopt; // give up
    // Append: parameter-type-list
    const unsigned NumParms = FD->getNumParams();

    for (unsigned i = 0; i < NumParms; i++) {
      const ParmVarDecl *Parm = FD->getParamDecl(i);

      if (Parm->isImplicit())
        continue;
      if (ParmsMask[i]) {
        // This `i`-th parameter will be fixed with `NewTysTexts[i]` being its
        // new type:
        SS << NewTysTexts[i];
        // print parameter name if provided:
        if (IdentifierInfo *II = Parm->getIdentifier())
          SS << ' ' << II->getName().str();
      } else if (auto ParmTypeText =
                     getRangeText(getSourceRangeToTokenEnd(Parm, SM, LangOpts),
                                  SM, LangOpts)) {
        // print the whole `Parm` without modification:
        SS << ParmTypeText->str();
      } else
        return std::nullopt; // something wrong, give up
      if (i != NumParms - 1)
        SS << ", ";
    }
    SS << ")";
    return SS.str();
  };

  // A lambda that creates the text representation of a function definition with
  // the original signature.  Its body forwards every argument to the new
  // overload, wrapping each fixed pointer argument as
  // `<span-type>(<name>, <size placeholder>)`.  It is only invoked on the
  // redeclaration that is the definition (see the loop below), which is
  // expected to have a body.
  const auto OldOverloadDefCreator =
      [&Handler, &SM, &LangOpts, &NewTysTexts,
       &ParmsMask](const FunctionDecl *FD) -> std::optional<std::string> {
    std::stringstream SS;

    SS << getEndOfLine().str();
    // Append: attr-name ret-type func-name "(" param-list ")" "{"
    if (auto FDPrefix = getRangeText(
            SourceRange(FD->getBeginLoc(), FD->getBody()->getBeginLoc()), SM,
            LangOpts))
      SS << Handler.getUnsafeBufferUsageAttributeTextAt(FD->getBeginLoc(), " ")
         << FDPrefix->str() << "{";
    else
      return std::nullopt;
    // Append: "return" func-name "("
    if (auto FunQualName = getFunNameText(FD, SM, LangOpts))
      SS << "return " << FunQualName->str() << "(";
    else
      return std::nullopt;

    // Append: arg-list
    const unsigned NumParms = FD->getNumParams();
    for (unsigned i = 0; i < NumParms; i++) {
      const ParmVarDecl *Parm = FD->getParamDecl(i);

      if (Parm->isImplicit())
        continue;
      // FIXME: If a parameter has no name, it is unused in the
      // definition. So we could just leave it as it is.
      if (!Parm->getIdentifier())
        // If a parameter of a function definition has no name:
        return std::nullopt;
      if (ParmsMask[i])
        // This is our spanified parameter!
        SS << NewTysTexts[i] << "(" << Parm->getIdentifier()->getName().str()
           << ", " << getUserFillPlaceHolder("size") << ")";
      else
        SS << Parm->getIdentifier()->getName().str();
      if (i != NumParms - 1)
        SS << ", ";
    }
    // finish call and the body
    SS << ");}" << getEndOfLine().str();
    // FIXME: 80-char line formatting?
    return SS.str();
  };

  FixItList FixIts{};
  // Every `return {};` in this loop gives up with `std::nullopt`.
  for (FunctionDecl *FReDecl : FD->redecls()) {
    std::optional<SourceLocation> Loc = getPastLoc(FReDecl, SM, LangOpts);

    if (!Loc)
      return {};
    if (FReDecl->isThisDeclarationADefinition()) {
      assert(FReDecl == FD && "inconsistent function definition");
      // Inserts a definition with the old signature to the end of
      // `FReDecl`:
      if (auto OldOverloadDef = OldOverloadDefCreator(FReDecl))
        FixIts.emplace_back(FixItHint::CreateInsertion(*Loc, *OldOverloadDef));
      else
        return {}; // give up
    } else {
      // Adds the unsafe-buffer attribute (if not already there) to `FReDecl`:
      if (!FReDecl->hasAttr<UnsafeBufferUsageAttr>()) {
        FixIts.emplace_back(FixItHint::CreateInsertion(
            FReDecl->getBeginLoc(), Handler.getUnsafeBufferUsageAttributeTextAt(
                                        FReDecl->getBeginLoc(), " ")));
      }
      // Inserts a declaration with the new signature to the end of `FReDecl`:
      if (auto NewOverloadDecl = NewOverloadSignatureCreator(FReDecl))
        FixIts.emplace_back(FixItHint::CreateInsertion(*Loc, *NewOverloadDecl));
      else
        return {};
    }
  }
  return FixIts;
}
2585 
2586 // To fix a `ParmVarDecl` to be of `std::span` type.
fixParamWithSpan(const ParmVarDecl * PVD,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2587 static FixItList fixParamWithSpan(const ParmVarDecl *PVD, const ASTContext &Ctx,
2588                                   UnsafeBufferUsageHandler &Handler) {
2589   if (hasUnsupportedSpecifiers(PVD, Ctx.getSourceManager())) {
2590     DEBUG_NOTE_DECL_FAIL(PVD, " : has unsupport specifier(s)");
2591     return {};
2592   }
2593   if (PVD->hasDefaultArg()) {
2594     // FIXME: generate fix-its for default values:
2595     DEBUG_NOTE_DECL_FAIL(PVD, " : has default arg");
2596     return {};
2597   }
2598 
2599   std::optional<Qualifiers> PteTyQualifiers = std::nullopt;
2600   std::optional<std::string> PteTyText = getPointeeTypeText(
2601       PVD, Ctx.getSourceManager(), Ctx.getLangOpts(), &PteTyQualifiers);
2602 
2603   if (!PteTyText) {
2604     DEBUG_NOTE_DECL_FAIL(PVD, " : invalid pointee type");
2605     return {};
2606   }
2607 
2608   std::optional<StringRef> PVDNameText = PVD->getIdentifier()->getName();
2609 
2610   if (!PVDNameText) {
2611     DEBUG_NOTE_DECL_FAIL(PVD, " : invalid identifier name");
2612     return {};
2613   }
2614 
2615   std::stringstream SS;
2616   std::optional<std::string> SpanTyText = createSpanTypeForVarDecl(PVD, Ctx);
2617 
2618   if (PteTyQualifiers)
2619     // Append qualifiers if they exist:
2620     SS << getSpanTypeText(*PteTyText, PteTyQualifiers);
2621   else
2622     SS << getSpanTypeText(*PteTyText);
2623   // Append qualifiers to the type of the parameter:
2624   if (PVD->getType().hasQualifiers())
2625     SS << ' ' << PVD->getType().getQualifiers().getAsString();
2626   // Append parameter's name:
2627   SS << ' ' << PVDNameText->str();
2628   // Add replacement fix-it:
2629   return {FixItHint::CreateReplacement(PVD->getSourceRange(), SS.str())};
2630 }
2631 
fixVariableWithSpan(const VarDecl * VD,const DeclUseTracker & Tracker,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2632 static FixItList fixVariableWithSpan(const VarDecl *VD,
2633                                      const DeclUseTracker &Tracker,
2634                                      ASTContext &Ctx,
2635                                      UnsafeBufferUsageHandler &Handler) {
2636   const DeclStmt *DS = Tracker.lookupDecl(VD);
2637   if (!DS) {
2638     DEBUG_NOTE_DECL_FAIL(VD,
2639                          " : variables declared this way not implemented yet");
2640     return {};
2641   }
2642   if (!DS->isSingleDecl()) {
2643     // FIXME: to support handling multiple `VarDecl`s in a single `DeclStmt`
2644     DEBUG_NOTE_DECL_FAIL(VD, " : multiple VarDecls");
2645     return {};
2646   }
2647   // Currently DS is an unused variable but we'll need it when
2648   // non-single decls are implemented, where the pointee type name
2649   // and the '*' are spread around the place.
2650   (void)DS;
2651 
2652   // FIXME: handle cases where DS has multiple declarations
2653   return fixLocalVarDeclWithSpan(VD, Ctx, getUserFillPlaceHolder(), Handler);
2654 }
2655 
fixVarDeclWithArray(const VarDecl * D,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2656 static FixItList fixVarDeclWithArray(const VarDecl *D, const ASTContext &Ctx,
2657                                      UnsafeBufferUsageHandler &Handler) {
2658   FixItList FixIts{};
2659 
2660   // Note: the code below expects the declaration to not use any type sugar like
2661   // typedef.
2662   if (auto CAT = dyn_cast<clang::ConstantArrayType>(D->getType())) {
2663     const QualType &ArrayEltT = CAT->getElementType();
2664     assert(!ArrayEltT.isNull() && "Trying to fix a non-array type variable!");
2665     // FIXME: support multi-dimensional arrays
2666     if (isa<clang::ArrayType>(ArrayEltT.getCanonicalType()))
2667       return {};
2668 
2669     const SourceLocation IdentifierLoc = getVarDeclIdentifierLoc(D);
2670 
2671     // Get the spelling of the element type as written in the source file
2672     // (including macros, etc.).
2673     auto MaybeElemTypeTxt =
2674         getRangeText({D->getBeginLoc(), IdentifierLoc}, Ctx.getSourceManager(),
2675                      Ctx.getLangOpts());
2676     if (!MaybeElemTypeTxt)
2677       return {};
2678     const llvm::StringRef ElemTypeTxt = MaybeElemTypeTxt->trim();
2679 
2680     // Find the '[' token.
2681     std::optional<Token> NextTok = Lexer::findNextToken(
2682         IdentifierLoc, Ctx.getSourceManager(), Ctx.getLangOpts());
2683     while (NextTok && !NextTok->is(tok::l_square) &&
2684            NextTok->getLocation() <= D->getSourceRange().getEnd())
2685       NextTok = Lexer::findNextToken(NextTok->getLocation(),
2686                                      Ctx.getSourceManager(), Ctx.getLangOpts());
2687     if (!NextTok)
2688       return {};
2689     const SourceLocation LSqBracketLoc = NextTok->getLocation();
2690 
2691     // Get the spelling of the array size as written in the source file
2692     // (including macros, etc.).
2693     auto MaybeArraySizeTxt = getRangeText(
2694         {LSqBracketLoc.getLocWithOffset(1), D->getTypeSpecEndLoc()},
2695         Ctx.getSourceManager(), Ctx.getLangOpts());
2696     if (!MaybeArraySizeTxt)
2697       return {};
2698     const llvm::StringRef ArraySizeTxt = MaybeArraySizeTxt->trim();
2699     if (ArraySizeTxt.empty()) {
2700       // FIXME: Support array size getting determined from the initializer.
2701       // Examples:
2702       //    int arr1[] = {0, 1, 2};
2703       //    int arr2{3, 4, 5};
2704       // We might be able to preserve the non-specified size with `auto` and
2705       // `std::to_array`:
2706       //    auto arr1 = std::to_array<int>({0, 1, 2});
2707       return {};
2708     }
2709 
2710     std::optional<StringRef> IdentText =
2711         getVarDeclIdentifierText(D, Ctx.getSourceManager(), Ctx.getLangOpts());
2712 
2713     if (!IdentText) {
2714       DEBUG_NOTE_DECL_FAIL(D, " : failed to locate the identifier");
2715       return {};
2716     }
2717 
2718     SmallString<32> Replacement;
2719     raw_svector_ostream OS(Replacement);
2720     OS << "std::array<" << ElemTypeTxt << ", " << ArraySizeTxt << "> "
2721        << IdentText->str();
2722 
2723     FixIts.push_back(FixItHint::CreateReplacement(
2724         SourceRange{D->getBeginLoc(), D->getTypeSpecEndLoc()}, OS.str()));
2725   }
2726 
2727   return FixIts;
2728 }
2729 
fixVariableWithArray(const VarDecl * VD,const DeclUseTracker & Tracker,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2730 static FixItList fixVariableWithArray(const VarDecl *VD,
2731                                       const DeclUseTracker &Tracker,
2732                                       const ASTContext &Ctx,
2733                                       UnsafeBufferUsageHandler &Handler) {
2734   const DeclStmt *DS = Tracker.lookupDecl(VD);
2735   assert(DS && "Fixing non-local variables not implemented yet!");
2736   if (!DS->isSingleDecl()) {
2737     // FIXME: to support handling multiple `VarDecl`s in a single `DeclStmt`
2738     return {};
2739   }
2740   // Currently DS is an unused variable but we'll need it when
2741   // non-single decls are implemented, where the pointee type name
2742   // and the '*' are spread around the place.
2743   (void)DS;
2744 
2745   // FIXME: handle cases where DS has multiple declarations
2746   return fixVarDeclWithArray(VD, Ctx, Handler);
2747 }
2748 
2749 // TODO: we should be consistent to use `std::nullopt` to represent no-fix due
2750 // to any unexpected problem.
2751 static FixItList
fixVariable(const VarDecl * VD,FixitStrategy::Kind K,const Decl * D,const DeclUseTracker & Tracker,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2752 fixVariable(const VarDecl *VD, FixitStrategy::Kind K,
2753             /* The function decl under analysis */ const Decl *D,
2754             const DeclUseTracker &Tracker, ASTContext &Ctx,
2755             UnsafeBufferUsageHandler &Handler) {
2756   if (const auto *PVD = dyn_cast<ParmVarDecl>(VD)) {
2757     auto *FD = dyn_cast<clang::FunctionDecl>(PVD->getDeclContext());
2758     if (!FD || FD != D) {
2759       // `FD != D` means that `PVD` belongs to a function that is not being
2760       // analyzed currently.  Thus `FD` may not be complete.
2761       DEBUG_NOTE_DECL_FAIL(VD, " : function not currently analyzed");
2762       return {};
2763     }
2764 
2765     // TODO If function has a try block we can't change params unless we check
2766     // also its catch block for their use.
2767     // FIXME We might support static class methods, some select methods,
2768     // operators and possibly lamdas.
2769     if (FD->isMain() || FD->isConstexpr() ||
2770         FD->getTemplatedKind() != FunctionDecl::TemplatedKind::TK_NonTemplate ||
2771         FD->isVariadic() ||
2772         // also covers call-operator of lamdas
2773         isa<CXXMethodDecl>(FD) ||
2774         // skip when the function body is a try-block
2775         (FD->hasBody() && isa<CXXTryStmt>(FD->getBody())) ||
2776         FD->isOverloadedOperator()) {
2777       DEBUG_NOTE_DECL_FAIL(VD, " : unsupported function decl");
2778       return {}; // TODO test all these cases
2779     }
2780   }
2781 
2782   switch (K) {
2783   case FixitStrategy::Kind::Span: {
2784     if (VD->getType()->isPointerType()) {
2785       if (const auto *PVD = dyn_cast<ParmVarDecl>(VD))
2786         return fixParamWithSpan(PVD, Ctx, Handler);
2787 
2788       if (VD->isLocalVarDecl())
2789         return fixVariableWithSpan(VD, Tracker, Ctx, Handler);
2790     }
2791     DEBUG_NOTE_DECL_FAIL(VD, " : not a pointer");
2792     return {};
2793   }
2794   case FixitStrategy::Kind::Array: {
2795     if (VD->isLocalVarDecl() &&
2796         isa<clang::ConstantArrayType>(VD->getType().getCanonicalType()))
2797       return fixVariableWithArray(VD, Tracker, Ctx, Handler);
2798 
2799     DEBUG_NOTE_DECL_FAIL(VD, " : not a local const-size array");
2800     return {};
2801   }
2802   case FixitStrategy::Kind::Iterator:
2803   case FixitStrategy::Kind::Vector:
2804     llvm_unreachable("FixitStrategy not implemented yet!");
2805   case FixitStrategy::Kind::Wontfix:
2806     llvm_unreachable("Invalid strategy!");
2807   }
2808   llvm_unreachable("Unknown strategy!");
2809 }
2810 
2811 // Returns true iff there exists a `FixItHint` 'h' in `FixIts` such that the
2812 // `RemoveRange` of 'h' overlaps with a macro use.
overlapWithMacro(const FixItList & FixIts)2813 static bool overlapWithMacro(const FixItList &FixIts) {
2814   // FIXME: For now we only check if the range (or the first token) is (part of)
2815   // a macro expansion.  Ideally, we want to check for all tokens in the range.
2816   return llvm::any_of(FixIts, [](const FixItHint &Hint) {
2817     auto Range = Hint.RemoveRange;
2818     if (Range.getBegin().isMacroID() || Range.getEnd().isMacroID())
2819       // If the range (or the first token) is (part of) a macro expansion:
2820       return true;
2821     return false;
2822   });
2823 }
2824 
2825 // Returns true iff `VD` is a parameter of the declaration `D`:
isParameterOf(const VarDecl * VD,const Decl * D)2826 static bool isParameterOf(const VarDecl *VD, const Decl *D) {
2827   return isa<ParmVarDecl>(VD) &&
2828          VD->getDeclContext() == dyn_cast<DeclContext>(D);
2829 }
2830 
2831 // Erases variables in `FixItsForVariable`, if such a variable has an unfixable
2832 // group mate.  A variable `v` is unfixable iff `FixItsForVariable` does not
2833 // contain `v`.
eraseVarsForUnfixableGroupMates(std::map<const VarDecl *,FixItList> & FixItsForVariable,const VariableGroupsManager & VarGrpMgr)2834 static void eraseVarsForUnfixableGroupMates(
2835     std::map<const VarDecl *, FixItList> &FixItsForVariable,
2836     const VariableGroupsManager &VarGrpMgr) {
2837   // Variables will be removed from `FixItsForVariable`:
2838   SmallVector<const VarDecl *, 8> ToErase;
2839 
2840   for (const auto &[VD, Ignore] : FixItsForVariable) {
2841     VarGrpRef Grp = VarGrpMgr.getGroupOfVar(VD);
2842     if (llvm::any_of(Grp,
2843                      [&FixItsForVariable](const VarDecl *GrpMember) -> bool {
2844                        return !FixItsForVariable.count(GrpMember);
2845                      })) {
2846       // At least one group member cannot be fixed, so we have to erase the
2847       // whole group:
2848       for (const VarDecl *Member : Grp)
2849         ToErase.push_back(Member);
2850     }
2851   }
2852   for (auto *VarToErase : ToErase)
2853     FixItsForVariable.erase(VarToErase);
2854 }
2855 
2856 // Returns the fix-its that create bounds-safe function overloads for the
2857 // function `D`, if `D`'s parameters will be changed to safe-types through
2858 // fix-its in `FixItsForVariable`.
2859 //
2860 // NOTE: In case `D`'s parameters will be changed but bounds-safe function
2861 // overloads cannot created, the whole group that contains the parameters will
2862 // be erased from `FixItsForVariable`.
createFunctionOverloadsForParms(std::map<const VarDecl *,FixItList> & FixItsForVariable,const VariableGroupsManager & VarGrpMgr,const FunctionDecl * FD,const FixitStrategy & S,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2863 static FixItList createFunctionOverloadsForParms(
2864     std::map<const VarDecl *, FixItList> &FixItsForVariable /* mutable */,
2865     const VariableGroupsManager &VarGrpMgr, const FunctionDecl *FD,
2866     const FixitStrategy &S, ASTContext &Ctx,
2867     UnsafeBufferUsageHandler &Handler) {
2868   FixItList FixItsSharedByParms{};
2869 
2870   std::optional<FixItList> OverloadFixes =
2871       createOverloadsForFixedParams(S, FD, Ctx, Handler);
2872 
2873   if (OverloadFixes) {
2874     FixItsSharedByParms.append(*OverloadFixes);
2875   } else {
2876     // Something wrong in generating `OverloadFixes`, need to remove the
2877     // whole group, where parameters are in, from `FixItsForVariable` (Note
2878     // that all parameters should be in the same group):
2879     for (auto *Member : VarGrpMgr.getGroupOfParms())
2880       FixItsForVariable.erase(Member);
2881   }
2882   return FixItsSharedByParms;
2883 }
2884 
2885 // Constructs self-contained fix-its for each variable in `FixablesForAllVars`.
2886 static std::map<const VarDecl *, FixItList>
getFixIts(FixableGadgetSets & FixablesForAllVars,const FixitStrategy & S,ASTContext & Ctx,const Decl * D,const DeclUseTracker & Tracker,UnsafeBufferUsageHandler & Handler,const VariableGroupsManager & VarGrpMgr)2887 getFixIts(FixableGadgetSets &FixablesForAllVars, const FixitStrategy &S,
2888           ASTContext &Ctx,
2889           /* The function decl under analysis */ const Decl *D,
2890           const DeclUseTracker &Tracker, UnsafeBufferUsageHandler &Handler,
2891           const VariableGroupsManager &VarGrpMgr) {
2892   // `FixItsForVariable` will map each variable to a set of fix-its directly
2893   // associated to the variable itself.  Fix-its of distinct variables in
2894   // `FixItsForVariable` are disjoint.
2895   std::map<const VarDecl *, FixItList> FixItsForVariable;
2896 
2897   // Populate `FixItsForVariable` with fix-its directly associated with each
2898   // variable.  Fix-its directly associated to a variable 'v' are the ones
2899   // produced by the `FixableGadget`s whose claimed variable is 'v'.
2900   for (const auto &[VD, Fixables] : FixablesForAllVars.byVar) {
2901     FixItsForVariable[VD] =
2902         fixVariable(VD, S.lookup(VD), D, Tracker, Ctx, Handler);
2903     // If we fail to produce Fix-It for the declaration we have to skip the
2904     // variable entirely.
2905     if (FixItsForVariable[VD].empty()) {
2906       FixItsForVariable.erase(VD);
2907       continue;
2908     }
2909     for (const auto &F : Fixables) {
2910       std::optional<FixItList> Fixits = F->getFixits(S);
2911 
2912       if (Fixits) {
2913         FixItsForVariable[VD].insert(FixItsForVariable[VD].end(),
2914                                      Fixits->begin(), Fixits->end());
2915         continue;
2916       }
2917 #ifndef NDEBUG
2918       Handler.addDebugNoteForVar(
2919           VD, F->getSourceLoc(),
2920           ("gadget '" + F->getDebugName() + "' refused to produce a fix")
2921               .str());
2922 #endif
2923       FixItsForVariable.erase(VD);
2924       break;
2925     }
2926   }
2927 
2928   // `FixItsForVariable` now contains only variables that can be
2929   // fixed. A variable can be fixed if its' declaration and all Fixables
2930   // associated to it can all be fixed.
2931 
2932   // To further remove from `FixItsForVariable` variables whose group mates
2933   // cannot be fixed...
2934   eraseVarsForUnfixableGroupMates(FixItsForVariable, VarGrpMgr);
2935   // Now `FixItsForVariable` gets further reduced: a variable is in
2936   // `FixItsForVariable` iff it can be fixed and all its group mates can be
2937   // fixed.
2938 
2939   // Fix-its of bounds-safe overloads of `D` are shared by parameters of `D`.
2940   // That is,  when fixing multiple parameters in one step,  these fix-its will
2941   // be applied only once (instead of being applied per parameter).
2942   FixItList FixItsSharedByParms{};
2943 
2944   if (auto *FD = dyn_cast<FunctionDecl>(D))
2945     FixItsSharedByParms = createFunctionOverloadsForParms(
2946         FixItsForVariable, VarGrpMgr, FD, S, Ctx, Handler);
2947 
2948   // The map that maps each variable `v` to fix-its for the whole group where
2949   // `v` is in:
2950   std::map<const VarDecl *, FixItList> FinalFixItsForVariable{
2951       FixItsForVariable};
2952 
2953   for (auto &[Var, Ignore] : FixItsForVariable) {
2954     bool AnyParm = false;
2955     const auto VarGroupForVD = VarGrpMgr.getGroupOfVar(Var, &AnyParm);
2956 
2957     for (const VarDecl *GrpMate : VarGroupForVD) {
2958       if (Var == GrpMate)
2959         continue;
2960       if (FixItsForVariable.count(GrpMate))
2961         FinalFixItsForVariable[Var].append(FixItsForVariable[GrpMate]);
2962     }
2963     if (AnyParm) {
2964       // This assertion should never fail.  Otherwise we have a bug.
2965       assert(!FixItsSharedByParms.empty() &&
2966              "Should not try to fix a parameter that does not belong to a "
2967              "FunctionDecl");
2968       FinalFixItsForVariable[Var].append(FixItsSharedByParms);
2969     }
2970   }
2971   // Fix-its that will be applied in one step shall NOT:
2972   // 1. overlap with macros or/and templates; or
2973   // 2. conflict with each other.
2974   // Otherwise, the fix-its will be dropped.
2975   for (auto Iter = FinalFixItsForVariable.begin();
2976        Iter != FinalFixItsForVariable.end();)
2977     if (overlapWithMacro(Iter->second) ||
2978         clang::internal::anyConflict(Iter->second, Ctx.getSourceManager())) {
2979       Iter = FinalFixItsForVariable.erase(Iter);
2980     } else
2981       Iter++;
2982   return FinalFixItsForVariable;
2983 }
2984 
2985 template <typename VarDeclIterTy>
2986 static FixitStrategy
getNaiveStrategy(llvm::iterator_range<VarDeclIterTy> UnsafeVars)2987 getNaiveStrategy(llvm::iterator_range<VarDeclIterTy> UnsafeVars) {
2988   FixitStrategy S;
2989   for (const VarDecl *VD : UnsafeVars) {
2990     if (isa<ConstantArrayType>(VD->getType().getCanonicalType()))
2991       S.set(VD, FixitStrategy::Kind::Array);
2992     else
2993       S.set(VD, FixitStrategy::Kind::Span);
2994   }
2995   return S;
2996 }
2997 
2998 //  Manages variable groups:
2999 class VariableGroupsManagerImpl : public VariableGroupsManager {
3000   const std::vector<VarGrpTy> Groups;
3001   const std::map<const VarDecl *, unsigned> &VarGrpMap;
3002   const llvm::SetVector<const VarDecl *> &GrpsUnionForParms;
3003 
3004 public:
VariableGroupsManagerImpl(const std::vector<VarGrpTy> & Groups,const std::map<const VarDecl *,unsigned> & VarGrpMap,const llvm::SetVector<const VarDecl * > & GrpsUnionForParms)3005   VariableGroupsManagerImpl(
3006       const std::vector<VarGrpTy> &Groups,
3007       const std::map<const VarDecl *, unsigned> &VarGrpMap,
3008       const llvm::SetVector<const VarDecl *> &GrpsUnionForParms)
3009       : Groups(Groups), VarGrpMap(VarGrpMap),
3010         GrpsUnionForParms(GrpsUnionForParms) {}
3011 
getGroupOfVar(const VarDecl * Var,bool * HasParm) const3012   VarGrpRef getGroupOfVar(const VarDecl *Var, bool *HasParm) const override {
3013     if (GrpsUnionForParms.contains(Var)) {
3014       if (HasParm)
3015         *HasParm = true;
3016       return GrpsUnionForParms.getArrayRef();
3017     }
3018     if (HasParm)
3019       *HasParm = false;
3020 
3021     auto It = VarGrpMap.find(Var);
3022 
3023     if (It == VarGrpMap.end())
3024       return std::nullopt;
3025     return Groups[It->second];
3026   }
3027 
getGroupOfParms() const3028   VarGrpRef getGroupOfParms() const override {
3029     return GrpsUnionForParms.getArrayRef();
3030   }
3031 };
3032 
checkUnsafeBufferUsage(const Decl * D,UnsafeBufferUsageHandler & Handler,bool EmitSuggestions)3033 void clang::checkUnsafeBufferUsage(const Decl *D,
3034                                    UnsafeBufferUsageHandler &Handler,
3035                                    bool EmitSuggestions) {
3036 #ifndef NDEBUG
3037   Handler.clearDebugNotes();
3038 #endif
3039 
3040   assert(D && D->getBody());
3041   // We do not want to visit a Lambda expression defined inside a method
3042   // independently. Instead, it should be visited along with the outer method.
3043   // FIXME: do we want to do the same thing for `BlockDecl`s?
3044   if (const auto *fd = dyn_cast<CXXMethodDecl>(D)) {
3045     if (fd->getParent()->isLambda() && fd->getParent()->isLocalClass())
3046       return;
3047   }
3048 
3049   // Do not emit fixit suggestions for functions declared in an
3050   // extern "C" block.
3051   if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
3052     for (FunctionDecl *FReDecl : FD->redecls()) {
3053       if (FReDecl->isExternC()) {
3054         EmitSuggestions = false;
3055         break;
3056       }
3057     }
3058   }
3059 
3060   WarningGadgetSets UnsafeOps;
3061   FixableGadgetSets FixablesForAllVars;
3062 
3063   auto [FixableGadgets, WarningGadgets, Tracker] =
3064       findGadgets(D, Handler, EmitSuggestions);
3065 
3066   if (!EmitSuggestions) {
3067     // Our job is very easy without suggestions. Just warn about
3068     // every problematic operation and consider it done. No need to deal
3069     // with fixable gadgets, no need to group operations by variable.
3070     for (const auto &G : WarningGadgets) {
3071       G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/false,
3072                                D->getASTContext());
3073     }
3074 
3075     // This return guarantees that most of the machine doesn't run when
3076     // suggestions aren't requested.
3077     assert(FixableGadgets.size() == 0 &&
3078            "Fixable gadgets found but suggestions not requested!");
3079     return;
3080   }
3081 
3082   // If no `WarningGadget`s ever matched, there is no unsafe operations in the
3083   //  function under the analysis. No need to fix any Fixables.
3084   if (!WarningGadgets.empty()) {
3085     // Gadgets "claim" variables they're responsible for. Once this loop
3086     // finishes, the tracker will only track DREs that weren't claimed by any
3087     // gadgets, i.e. not understood by the analysis.
3088     for (const auto &G : FixableGadgets) {
3089       for (const auto *DRE : G->getClaimedVarUseSites()) {
3090         Tracker.claimUse(DRE);
3091       }
3092     }
3093   }
3094 
3095   // If no `WarningGadget`s ever matched, there is no unsafe operations in the
3096   // function under the analysis.  Thus, it early returns here as there is
3097   // nothing needs to be fixed.
3098   //
3099   // Note this claim is based on the assumption that there is no unsafe
3100   // variable whose declaration is invisible from the analyzing function.
3101   // Otherwise, we need to consider if the uses of those unsafe varuables needs
3102   // fix.
3103   // So far, we are not fixing any global variables or class members. And,
3104   // lambdas will be analyzed along with the enclosing function. So this early
3105   // return is correct for now.
3106   if (WarningGadgets.empty())
3107     return;
3108 
3109   UnsafeOps = groupWarningGadgetsByVar(std::move(WarningGadgets));
3110   FixablesForAllVars = groupFixablesByVar(std::move(FixableGadgets));
3111 
3112   std::map<const VarDecl *, FixItList> FixItsForVariableGroup;
3113 
3114   // Filter out non-local vars and vars with unclaimed DeclRefExpr-s.
3115   for (auto it = FixablesForAllVars.byVar.cbegin();
3116        it != FixablesForAllVars.byVar.cend();) {
3117     // FIXME: need to deal with global variables later
3118     if ((!it->first->isLocalVarDecl() && !isa<ParmVarDecl>(it->first))) {
3119 #ifndef NDEBUG
3120       Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
3121                                  ("failed to produce fixit for '" +
3122                                   it->first->getNameAsString() +
3123                                   "' : neither local nor a parameter"));
3124 #endif
3125       it = FixablesForAllVars.byVar.erase(it);
3126     } else if (it->first->getType().getCanonicalType()->isReferenceType()) {
3127 #ifndef NDEBUG
3128       Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
3129                                  ("failed to produce fixit for '" +
3130                                   it->first->getNameAsString() +
3131                                   "' : has a reference type"));
3132 #endif
3133       it = FixablesForAllVars.byVar.erase(it);
3134     } else if (Tracker.hasUnclaimedUses(it->first)) {
3135       it = FixablesForAllVars.byVar.erase(it);
3136     } else if (it->first->isInitCapture()) {
3137 #ifndef NDEBUG
3138       Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
3139                                  ("failed to produce fixit for '" +
3140                                   it->first->getNameAsString() +
3141                                   "' : init capture"));
3142 #endif
3143       it = FixablesForAllVars.byVar.erase(it);
3144     } else {
3145       ++it;
3146     }
3147   }
3148 
3149 #ifndef NDEBUG
3150   for (const auto &it : UnsafeOps.byVar) {
3151     const VarDecl *const UnsafeVD = it.first;
3152     auto UnclaimedDREs = Tracker.getUnclaimedUses(UnsafeVD);
3153     if (UnclaimedDREs.empty())
3154       continue;
3155     const auto UnfixedVDName = UnsafeVD->getNameAsString();
3156     for (const clang::DeclRefExpr *UnclaimedDRE : UnclaimedDREs) {
3157       std::string UnclaimedUseTrace =
3158           getDREAncestorString(UnclaimedDRE, D->getASTContext());
3159 
3160       Handler.addDebugNoteForVar(
3161           UnsafeVD, UnclaimedDRE->getBeginLoc(),
3162           ("failed to produce fixit for '" + UnfixedVDName +
3163            "' : has an unclaimed use\nThe unclaimed DRE trace: " +
3164            UnclaimedUseTrace));
3165     }
3166   }
3167 #endif
3168 
3169   // Fixpoint iteration for pointer assignments
3170   using DepMapTy = DenseMap<const VarDecl *, llvm::SetVector<const VarDecl *>>;
3171   DepMapTy DependenciesMap{};
3172   DepMapTy PtrAssignmentGraph{};
3173 
3174   for (auto it : FixablesForAllVars.byVar) {
3175     for (const FixableGadget *fixable : it.second) {
3176       std::optional<std::pair<const VarDecl *, const VarDecl *>> ImplPair =
3177           fixable->getStrategyImplications();
3178       if (ImplPair) {
3179         std::pair<const VarDecl *, const VarDecl *> Impl = std::move(*ImplPair);
3180         PtrAssignmentGraph[Impl.first].insert(Impl.second);
3181       }
3182     }
3183   }
3184 
3185   /*
3186    The following code does a BFS traversal of the `PtrAssignmentGraph`
3187    considering all unsafe vars as starting nodes and constructs an undirected
3188    graph `DependenciesMap`. Constructing the `DependenciesMap` in this manner
3189    elimiates all variables that are unreachable from any unsafe var. In other
3190    words, this removes all dependencies that don't include any unsafe variable
3191    and consequently don't need any fixit generation.
3192    Note: A careful reader would observe that the code traverses
3193    `PtrAssignmentGraph` using `CurrentVar` but adds edges between `Var` and
3194    `Adj` and not between `CurrentVar` and `Adj`. Both approaches would
3195    achieve the same result but the one used here dramatically cuts the
3196    amount of hoops the second part of the algorithm needs to jump, given that
3197    a lot of these connections become "direct". The reader is advised not to
3198    imagine how the graph is transformed because of using `Var` instead of
3199    `CurrentVar`. The reader can continue reading as if `CurrentVar` was used,
3200    and think about why it's equivalent later.
3201    */
3202   std::set<const VarDecl *> VisitedVarsDirected{};
3203   for (const auto &[Var, ignore] : UnsafeOps.byVar) {
3204     if (VisitedVarsDirected.find(Var) == VisitedVarsDirected.end()) {
3205 
3206       std::queue<const VarDecl *> QueueDirected{};
3207       QueueDirected.push(Var);
3208       while (!QueueDirected.empty()) {
3209         const VarDecl *CurrentVar = QueueDirected.front();
3210         QueueDirected.pop();
3211         VisitedVarsDirected.insert(CurrentVar);
3212         auto AdjacentNodes = PtrAssignmentGraph[CurrentVar];
3213         for (const VarDecl *Adj : AdjacentNodes) {
3214           if (VisitedVarsDirected.find(Adj) == VisitedVarsDirected.end()) {
3215             QueueDirected.push(Adj);
3216           }
3217           DependenciesMap[Var].insert(Adj);
3218           DependenciesMap[Adj].insert(Var);
3219         }
3220       }
3221     }
3222   }
3223 
3224   // `Groups` stores the set of Connected Components in the graph.
3225   std::vector<VarGrpTy> Groups;
3226   // `VarGrpMap` maps variables that need fix to the groups (indexes) that the
3227   // variables belong to.  Group indexes refer to the elements in `Groups`.
3228   // `VarGrpMap` is complete in that every variable that needs fix is in it.
3229   std::map<const VarDecl *, unsigned> VarGrpMap;
3230   // The union group over the ones in "Groups" that contain parameters of `D`:
3231   llvm::SetVector<const VarDecl *>
3232       GrpsUnionForParms; // these variables need to be fixed in one step
3233 
3234   // Group Connected Components for Unsafe Vars
3235   // (Dependencies based on pointer assignments)
3236   std::set<const VarDecl *> VisitedVars{};
3237   for (const auto &[Var, ignore] : UnsafeOps.byVar) {
3238     if (VisitedVars.find(Var) == VisitedVars.end()) {
3239       VarGrpTy &VarGroup = Groups.emplace_back();
3240       std::queue<const VarDecl *> Queue{};
3241 
3242       Queue.push(Var);
3243       while (!Queue.empty()) {
3244         const VarDecl *CurrentVar = Queue.front();
3245         Queue.pop();
3246         VisitedVars.insert(CurrentVar);
3247         VarGroup.push_back(CurrentVar);
3248         auto AdjacentNodes = DependenciesMap[CurrentVar];
3249         for (const VarDecl *Adj : AdjacentNodes) {
3250           if (VisitedVars.find(Adj) == VisitedVars.end()) {
3251             Queue.push(Adj);
3252           }
3253         }
3254       }
3255 
3256       bool HasParm = false;
3257       unsigned GrpIdx = Groups.size() - 1;
3258 
3259       for (const VarDecl *V : VarGroup) {
3260         VarGrpMap[V] = GrpIdx;
3261         if (!HasParm && isParameterOf(V, D))
3262           HasParm = true;
3263       }
3264       if (HasParm)
3265         GrpsUnionForParms.insert(VarGroup.begin(), VarGroup.end());
3266     }
3267   }
3268 
3269   // Remove a `FixableGadget` if the associated variable is not in the graph
3270   // computed above.  We do not want to generate fix-its for such variables,
3271   // since they are neither warned nor reachable from a warned one.
3272   //
3273   // Note a variable is not warned if it is not directly used in any unsafe
3274   // operation. A variable `v` is NOT reachable from an unsafe variable, if it
3275   // does not exist another variable `u` such that `u` is warned and fixing `u`
3276   // (transitively) implicates fixing `v`.
3277   //
3278   // For example,
3279   // ```
3280   // void f(int * p) {
3281   //   int * a = p; *p = 0;
3282   // }
3283   // ```
3284   // `*p = 0` is a fixable gadget associated with a variable `p` that is neither
3285   // warned nor reachable from a warned one.  If we add `a[5] = 0` to the end of
3286   // the function above, `p` becomes reachable from a warned variable.
3287   for (auto I = FixablesForAllVars.byVar.begin();
3288        I != FixablesForAllVars.byVar.end();) {
3289     // Note `VisitedVars` contain all the variables in the graph:
3290     if (!VisitedVars.count((*I).first)) {
3291       // no such var in graph:
3292       I = FixablesForAllVars.byVar.erase(I);
3293     } else
3294       ++I;
3295   }
3296 
3297   // We assign strategies to variables that are 1) in the graph and 2) can be
3298   // fixed. Other variables have the default "Won't fix" strategy.
3299   FixitStrategy NaiveStrategy = getNaiveStrategy(llvm::make_filter_range(
3300       VisitedVars, [&FixablesForAllVars](const VarDecl *V) {
3301         // If a warned variable has no "Fixable", it is considered unfixable:
3302         return FixablesForAllVars.byVar.count(V);
3303       }));
3304   VariableGroupsManagerImpl VarGrpMgr(Groups, VarGrpMap, GrpsUnionForParms);
3305 
3306   if (isa<NamedDecl>(D))
3307     // The only case where `D` is not a `NamedDecl` is when `D` is a
3308     // `BlockDecl`. Let's not fix variables in blocks for now
3309     FixItsForVariableGroup =
3310         getFixIts(FixablesForAllVars, NaiveStrategy, D->getASTContext(), D,
3311                   Tracker, Handler, VarGrpMgr);
3312 
3313   for (const auto &G : UnsafeOps.noVar) {
3314     G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/false,
3315                              D->getASTContext());
3316   }
3317 
3318   for (const auto &[VD, WarningGadgets] : UnsafeOps.byVar) {
3319     auto FixItsIt = FixItsForVariableGroup.find(VD);
3320     Handler.handleUnsafeVariableGroup(VD, VarGrpMgr,
3321                                       FixItsIt != FixItsForVariableGroup.end()
3322                                           ? std::move(FixItsIt->second)
3323                                           : FixItList{},
3324                                       D, NaiveStrategy);
3325     for (const auto &G : WarningGadgets) {
3326       G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/true,
3327                                D->getASTContext());
3328     }
3329   }
3330 }
3331