xref: /freebsd/contrib/llvm-project/clang/lib/Analysis/UnsafeBufferUsage.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1  //===- UnsafeBufferUsage.cpp - Replace pointers with modern C++ -----------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  
9  #include "clang/Analysis/Analyses/UnsafeBufferUsage.h"
10  #include "clang/AST/ASTContext.h"
11  #include "clang/AST/Decl.h"
12  #include "clang/AST/Expr.h"
13  #include "clang/AST/RecursiveASTVisitor.h"
14  #include "clang/AST/Stmt.h"
15  #include "clang/AST/StmtVisitor.h"
16  #include "clang/ASTMatchers/ASTMatchFinder.h"
17  #include "clang/ASTMatchers/ASTMatchers.h"
18  #include "clang/Basic/CharInfo.h"
19  #include "clang/Basic/SourceLocation.h"
20  #include "clang/Lex/Lexer.h"
21  #include "clang/Lex/Preprocessor.h"
22  #include "llvm/ADT/APSInt.h"
23  #include "llvm/ADT/SmallVector.h"
24  #include "llvm/ADT/StringRef.h"
25  #include "llvm/Support/Casting.h"
26  #include <memory>
27  #include <optional>
28  #include <queue>
29  #include <sstream>
30  
31  using namespace llvm;
32  using namespace clang;
33  using namespace ast_matchers;
34  
35  #ifndef NDEBUG
36  namespace {
37  class StmtDebugPrinter
38      : public ConstStmtVisitor<StmtDebugPrinter, std::string> {
39  public:
VisitStmt(const Stmt * S)40    std::string VisitStmt(const Stmt *S) { return S->getStmtClassName(); }
41  
VisitBinaryOperator(const BinaryOperator * BO)42    std::string VisitBinaryOperator(const BinaryOperator *BO) {
43      return "BinaryOperator(" + BO->getOpcodeStr().str() + ")";
44    }
45  
VisitUnaryOperator(const UnaryOperator * UO)46    std::string VisitUnaryOperator(const UnaryOperator *UO) {
47      return "UnaryOperator(" + UO->getOpcodeStr(UO->getOpcode()).str() + ")";
48    }
49  
VisitImplicitCastExpr(const ImplicitCastExpr * ICE)50    std::string VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
51      return "ImplicitCastExpr(" + std::string(ICE->getCastKindName()) + ")";
52    }
53  };
54  
55  // Returns a string of ancestor `Stmt`s of the given `DRE` in such a form:
56  // "DRE ==> parent-of-DRE ==> grandparent-of-DRE ==> ...".
getDREAncestorString(const DeclRefExpr * DRE,ASTContext & Ctx)57  static std::string getDREAncestorString(const DeclRefExpr *DRE,
58                                          ASTContext &Ctx) {
59    std::stringstream SS;
60    const Stmt *St = DRE;
61    StmtDebugPrinter StmtPriner;
62  
63    do {
64      SS << StmtPriner.Visit(St);
65  
66      DynTypedNodeList StParents = Ctx.getParents(*St);
67  
68      if (StParents.size() > 1)
69        return "unavailable due to multiple parents";
70      if (StParents.size() == 0)
71        break;
72      St = StParents.begin()->get<Stmt>();
73      if (St)
74        SS << " ==> ";
75    } while (St);
76    return SS.str();
77  }
78  } // namespace
79  #endif /* NDEBUG */
80  
81  namespace clang::ast_matchers {
82  // A `RecursiveASTVisitor` that traverses all descendants of a given node "n"
83  // except for those belonging to a different callable of "n".
84  class MatchDescendantVisitor
85      : public RecursiveASTVisitor<MatchDescendantVisitor> {
86  public:
87    typedef RecursiveASTVisitor<MatchDescendantVisitor> VisitorBase;
88  
89    // Creates an AST visitor that matches `Matcher` on all
90    // descendants of a given node "n" except for the ones
91    // belonging to a different callable of "n".
MatchDescendantVisitor(const internal::DynTypedMatcher * Matcher,internal::ASTMatchFinder * Finder,internal::BoundNodesTreeBuilder * Builder,internal::ASTMatchFinder::BindKind Bind,const bool ignoreUnevaluatedContext)92    MatchDescendantVisitor(const internal::DynTypedMatcher *Matcher,
93                           internal::ASTMatchFinder *Finder,
94                           internal::BoundNodesTreeBuilder *Builder,
95                           internal::ASTMatchFinder::BindKind Bind,
96                           const bool ignoreUnevaluatedContext)
97        : Matcher(Matcher), Finder(Finder), Builder(Builder), Bind(Bind),
98          Matches(false), ignoreUnevaluatedContext(ignoreUnevaluatedContext) {}
99  
100    // Returns true if a match is found in a subtree of `DynNode`, which belongs
101    // to the same callable of `DynNode`.
findMatch(const DynTypedNode & DynNode)102    bool findMatch(const DynTypedNode &DynNode) {
103      Matches = false;
104      if (const Stmt *StmtNode = DynNode.get<Stmt>()) {
105        TraverseStmt(const_cast<Stmt *>(StmtNode));
106        *Builder = ResultBindings;
107        return Matches;
108      }
109      return false;
110    }
111  
112    // The following are overriding methods from the base visitor class.
113    // They are public only to allow CRTP to work. They are *not *part
114    // of the public API of this class.
115  
116    // For the matchers so far used in safe buffers, we only need to match
117    // `Stmt`s.  To override more as needed.
118  
TraverseDecl(Decl * Node)119    bool TraverseDecl(Decl *Node) {
120      if (!Node)
121        return true;
122      if (!match(*Node))
123        return false;
124      // To skip callables:
125      if (isa<FunctionDecl, BlockDecl, ObjCMethodDecl>(Node))
126        return true;
127      // Traverse descendants
128      return VisitorBase::TraverseDecl(Node);
129    }
130  
TraverseGenericSelectionExpr(GenericSelectionExpr * Node)131    bool TraverseGenericSelectionExpr(GenericSelectionExpr *Node) {
132      // These are unevaluated, except the result expression.
133      if (ignoreUnevaluatedContext)
134        return TraverseStmt(Node->getResultExpr());
135      return VisitorBase::TraverseGenericSelectionExpr(Node);
136    }
137  
TraverseUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr * Node)138    bool TraverseUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr *Node) {
139      // Unevaluated context.
140      if (ignoreUnevaluatedContext)
141        return true;
142      return VisitorBase::TraverseUnaryExprOrTypeTraitExpr(Node);
143    }
144  
TraverseTypeOfExprTypeLoc(TypeOfExprTypeLoc Node)145    bool TraverseTypeOfExprTypeLoc(TypeOfExprTypeLoc Node) {
146      // Unevaluated context.
147      if (ignoreUnevaluatedContext)
148        return true;
149      return VisitorBase::TraverseTypeOfExprTypeLoc(Node);
150    }
151  
TraverseDecltypeTypeLoc(DecltypeTypeLoc Node)152    bool TraverseDecltypeTypeLoc(DecltypeTypeLoc Node) {
153      // Unevaluated context.
154      if (ignoreUnevaluatedContext)
155        return true;
156      return VisitorBase::TraverseDecltypeTypeLoc(Node);
157    }
158  
TraverseCXXNoexceptExpr(CXXNoexceptExpr * Node)159    bool TraverseCXXNoexceptExpr(CXXNoexceptExpr *Node) {
160      // Unevaluated context.
161      if (ignoreUnevaluatedContext)
162        return true;
163      return VisitorBase::TraverseCXXNoexceptExpr(Node);
164    }
165  
TraverseCXXTypeidExpr(CXXTypeidExpr * Node)166    bool TraverseCXXTypeidExpr(CXXTypeidExpr *Node) {
167      // Unevaluated context.
168      if (ignoreUnevaluatedContext)
169        return true;
170      return VisitorBase::TraverseCXXTypeidExpr(Node);
171    }
172  
TraverseStmt(Stmt * Node,DataRecursionQueue * Queue=nullptr)173    bool TraverseStmt(Stmt *Node, DataRecursionQueue *Queue = nullptr) {
174      if (!Node)
175        return true;
176      if (!match(*Node))
177        return false;
178      return VisitorBase::TraverseStmt(Node);
179    }
180  
shouldVisitTemplateInstantiations() const181    bool shouldVisitTemplateInstantiations() const { return true; }
shouldVisitImplicitCode() const182    bool shouldVisitImplicitCode() const {
183      // TODO: let's ignore implicit code for now
184      return false;
185    }
186  
187  private:
188    // Sets 'Matched' to true if 'Matcher' matches 'Node'
189    //
190    // Returns 'true' if traversal should continue after this function
191    // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
match(const T & Node)192    template <typename T> bool match(const T &Node) {
193      internal::BoundNodesTreeBuilder RecursiveBuilder(*Builder);
194  
195      if (Matcher->matches(DynTypedNode::create(Node), Finder,
196                           &RecursiveBuilder)) {
197        ResultBindings.addMatch(RecursiveBuilder);
198        Matches = true;
199        if (Bind != internal::ASTMatchFinder::BK_All)
200          return false; // Abort as soon as a match is found.
201      }
202      return true;
203    }
204  
205    const internal::DynTypedMatcher *const Matcher;
206    internal::ASTMatchFinder *const Finder;
207    internal::BoundNodesTreeBuilder *const Builder;
208    internal::BoundNodesTreeBuilder ResultBindings;
209    const internal::ASTMatchFinder::BindKind Bind;
210    bool Matches;
211    bool ignoreUnevaluatedContext;
212  };
213  
214  // Because we're dealing with raw pointers, let's define what we mean by that.
hasPointerType()215  static auto hasPointerType() {
216    return hasType(hasCanonicalType(pointerType()));
217  }
218  
hasArrayType()219  static auto hasArrayType() { return hasType(hasCanonicalType(arrayType())); }
220  
// Runs `innerMatcher` over the subtree rooted at `Node` using
// `MatchDescendantVisitor` with unevaluated contexts ignored, collecting all
// matches (BK_All).  Matches iff at least one node in the subtree matched.
AST_MATCHER_P(Stmt, forEachDescendantEvaluatedStmt, internal::Matcher<Stmt>,
              innerMatcher) {
  const DynTypedMatcher ErasedMatcher =
      static_cast<DynTypedMatcher>(innerMatcher);

  MatchDescendantVisitor Visitor(&ErasedMatcher, Finder, Builder,
                                 ASTMatchFinder::BK_All,
                                 /*ignoreUnevaluatedContext=*/true);
  const DynTypedNode Root = DynTypedNode::create(Node);
  return Visitor.findMatch(Root);
}
229  
// Runs `innerMatcher` over the subtree rooted at `Node` using
// `MatchDescendantVisitor`, including unevaluated contexts, collecting all
// matches (BK_All).  Matches iff at least one node in the subtree matched.
AST_MATCHER_P(Stmt, forEachDescendantStmt, internal::Matcher<Stmt>,
              innerMatcher) {
  const DynTypedMatcher ErasedMatcher =
      static_cast<DynTypedMatcher>(innerMatcher);

  MatchDescendantVisitor Visitor(&ErasedMatcher, Finder, Builder,
                                 ASTMatchFinder::BK_All,
                                 /*ignoreUnevaluatedContext=*/false);
  const DynTypedNode Root = DynTypedNode::create(Node);
  return Visitor.findMatch(Root);
}
238  
239  // Matches a `Stmt` node iff the node is in a safe-buffer opt-out region
AST_MATCHER_P(Stmt,notInSafeBufferOptOut,const UnsafeBufferUsageHandler *,Handler)240  AST_MATCHER_P(Stmt, notInSafeBufferOptOut, const UnsafeBufferUsageHandler *,
241                Handler) {
242    return !Handler->isSafeBufferOptOut(Node.getBeginLoc());
243  }
244  
// Matches a `Stmt` node iff `Handler->ignoreUnsafeBufferInContainer` returns
// true for the node's begin location.
AST_MATCHER_P(Stmt, ignoreUnsafeBufferInContainer,
              const UnsafeBufferUsageHandler *, Handler) {
  const SourceLocation Begin = Node.getBeginLoc();
  return Handler->ignoreUnsafeBufferInContainer(Begin);
}
249  
// Matches a `CastExpr` whose direct sub-expression satisfies `innerMatcher`.
AST_MATCHER_P(CastExpr, castSubExpr, internal::Matcher<Expr>, innerMatcher) {
  const Expr *Operand = Node.getSubExpr();
  return innerMatcher.matches(*Operand, Finder, Builder);
}
253  
254  // Matches a `UnaryOperator` whose operator is pre-increment:
AST_MATCHER(UnaryOperator,isPreInc)255  AST_MATCHER(UnaryOperator, isPreInc) {
256    return Node.getOpcode() == UnaryOperator::Opcode::UO_PreInc;
257  }
258  
259  // Returns a matcher that matches any expression 'e' such that `innerMatcher`
260  // matches 'e' and 'e' is in an Unspecified Lvalue Context.
isInUnspecifiedLvalueContext(internal::Matcher<Expr> innerMatcher)261  static auto isInUnspecifiedLvalueContext(internal::Matcher<Expr> innerMatcher) {
262    // clang-format off
263    return
264      expr(anyOf(
265        implicitCastExpr(
266          hasCastKind(CastKind::CK_LValueToRValue),
267          castSubExpr(innerMatcher)),
268        binaryOperator(
269          hasAnyOperatorName("="),
270          hasLHS(innerMatcher)
271        )
272      ));
273    // clang-format on
274  }
275  
276  // Returns a matcher that matches any expression `e` such that `InnerMatcher`
277  // matches `e` and `e` is in an Unspecified Pointer Context (UPC).
278  static internal::Matcher<Stmt>
isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher)279  isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) {
280    // A UPC can be
281    // 1. an argument of a function call (except the callee has [[unsafe_...]]
282    //    attribute), or
283    // 2. the operand of a pointer-to-(integer or bool) cast operation; or
284    // 3. the operand of a comparator operation; or
285    // 4. the operand of a pointer subtraction operation
286    //    (i.e., computing the distance between two pointers); or ...
287  
288    // clang-format off
289    auto CallArgMatcher = callExpr(
290      forEachArgumentWithParamType(
291        InnerMatcher,
292        isAnyPointer() /* array also decays to pointer type*/),
293      unless(callee(
294        functionDecl(hasAttr(attr::UnsafeBufferUsage)))));
295  
296    auto CastOperandMatcher =
297        castExpr(anyOf(hasCastKind(CastKind::CK_PointerToIntegral),
298  		     hasCastKind(CastKind::CK_PointerToBoolean)),
299  	       castSubExpr(allOf(hasPointerType(), InnerMatcher)));
300  
301    auto CompOperandMatcher =
302        binaryOperator(hasAnyOperatorName("!=", "==", "<", "<=", ">", ">="),
303                       eachOf(hasLHS(allOf(hasPointerType(), InnerMatcher)),
304                              hasRHS(allOf(hasPointerType(), InnerMatcher))));
305  
306    // A matcher that matches pointer subtractions:
307    auto PtrSubtractionMatcher =
308        binaryOperator(hasOperatorName("-"),
309  		     // Note that here we need both LHS and RHS to be
310  		     // pointer. Then the inner matcher can match any of
311  		     // them:
312  		     allOf(hasLHS(hasPointerType()),
313  			   hasRHS(hasPointerType())),
314  		     eachOf(hasLHS(InnerMatcher),
315  			    hasRHS(InnerMatcher)));
316    // clang-format on
317  
318    return stmt(anyOf(CallArgMatcher, CastOperandMatcher, CompOperandMatcher,
319                      PtrSubtractionMatcher));
320    // FIXME: any more cases? (UPC excludes the RHS of an assignment.  For now we
321    // don't have to check that.)
322  }
323  
324  // Returns a matcher that matches any expression 'e' such that `innerMatcher`
325  // matches 'e' and 'e' is in an unspecified untyped context (i.e the expression
326  // 'e' isn't evaluated to an RValue). For example, consider the following code:
327  //    int *p = new int[4];
328  //    int *q = new int[4];
329  //    if ((p = q)) {}
330  //    p = q;
331  // The expression `p = q` in the conditional of the `if` statement
332  // `if ((p = q))` is evaluated as an RValue, whereas the expression `p = q;`
333  // in the assignment statement is in an untyped context.
334  static internal::Matcher<Stmt>
isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher)335  isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher) {
336    // An unspecified context can be
337    // 1. A compound statement,
338    // 2. The body of an if statement
339    // 3. Body of a loop
340    auto CompStmt = compoundStmt(forEach(InnerMatcher));
341    auto IfStmtThen = ifStmt(hasThen(InnerMatcher));
342    auto IfStmtElse = ifStmt(hasElse(InnerMatcher));
343    // FIXME: Handle loop bodies.
344    return stmt(anyOf(CompStmt, IfStmtThen, IfStmtElse));
345  }
346  
347  // Given a two-param std::span construct call, matches iff the call has the
348  // following forms:
349  //   1. `std::span<T>{new T[n], n}`, where `n` is a literal or a DRE
350  //   2. `std::span<T>{new T, 1}`
351  //   3. `std::span<T>{&var, 1}`
352  //   4. `std::span<T>{a, n}`, where `a` is of an array-of-T with constant size
353  //   `n`
354  //   5. `std::span<T>{any, 0}`
AST_MATCHER(CXXConstructExpr,isSafeSpanTwoParamConstruct)355  AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) {
356    assert(Node.getNumArgs() == 2 &&
357           "expecting a two-parameter std::span constructor");
358    const Expr *Arg0 = Node.getArg(0)->IgnoreImplicit();
359    const Expr *Arg1 = Node.getArg(1)->IgnoreImplicit();
360    auto HaveEqualConstantValues = [&Finder](const Expr *E0, const Expr *E1) {
361      if (auto E0CV = E0->getIntegerConstantExpr(Finder->getASTContext()))
362        if (auto E1CV = E1->getIntegerConstantExpr(Finder->getASTContext())) {
363          return APSInt::compareValues(*E0CV, *E1CV) == 0;
364        }
365      return false;
366    };
367    auto AreSameDRE = [](const Expr *E0, const Expr *E1) {
368      if (auto *DRE0 = dyn_cast<DeclRefExpr>(E0))
369        if (auto *DRE1 = dyn_cast<DeclRefExpr>(E1)) {
370          return DRE0->getDecl() == DRE1->getDecl();
371        }
372      return false;
373    };
374    std::optional<APSInt> Arg1CV =
375        Arg1->getIntegerConstantExpr(Finder->getASTContext());
376  
377    if (Arg1CV && Arg1CV->isZero())
378      // Check form 5:
379      return true;
380    switch (Arg0->IgnoreImplicit()->getStmtClass()) {
381    case Stmt::CXXNewExprClass:
382      if (auto Size = cast<CXXNewExpr>(Arg0)->getArraySize()) {
383        // Check form 1:
384        return AreSameDRE((*Size)->IgnoreImplicit(), Arg1) ||
385               HaveEqualConstantValues(*Size, Arg1);
386      }
387      // TODO: what's placeholder type? avoid it for now.
388      if (!cast<CXXNewExpr>(Arg0)->hasPlaceholderType()) {
389        // Check form 2:
390        return Arg1CV && Arg1CV->isOne();
391      }
392      break;
393    case Stmt::UnaryOperatorClass:
394      if (cast<UnaryOperator>(Arg0)->getOpcode() ==
395          UnaryOperator::Opcode::UO_AddrOf)
396        // Check form 3:
397        return Arg1CV && Arg1CV->isOne();
398      break;
399    default:
400      break;
401    }
402  
403    QualType Arg0Ty = Arg0->IgnoreImplicit()->getType();
404  
405    if (Arg0Ty->isConstantArrayType()) {
406      const APSInt ConstArrSize =
407          APSInt(cast<ConstantArrayType>(Arg0Ty)->getSize());
408  
409      // Check form 4:
410      return Arg1CV && APSInt::compareValues(ConstArrSize, *Arg1CV) == 0;
411    }
412    return false;
413  }
414  
// Matches an array subscript that is provably in bounds by a simple syntactic
// check: the base is a direct reference to a declaration of constant array
// type, and the index is a non-negative integer literal smaller than the
// array size.
AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) {
  // FIXME: Proper solution:
  //  - refactor Sema::CheckArrayAccess
  //    - split safe/OOB/unknown decision logic from diagnostics emitting code
  //    -  e. g. "Try harder to find a NamedDecl to point at in the note."
  //    already duplicated
  //  - call both from Sema and from here

  const Expr *Base = Node.getBase()->IgnoreParenImpCasts();
  const auto *BaseDRE = dyn_cast<DeclRefExpr>(Base);
  if (!BaseDRE || !BaseDRE->getDecl())
    return false;

  const auto *CATy = Finder->getASTContext().getAsConstantArrayType(
      BaseDRE->getDecl()->getType());
  if (!CATy)
    return false;

  // Only a plain integer literal index is understood.
  const auto *IdxLit = dyn_cast<IntegerLiteral>(Node.getIdx());
  if (!IdxLit)
    return false;

  const APInt ArrIdx = IdxLit->getValue();
  // FIXME: ArrIdx.isNegative() we could immediately emit an error as that's a
  // bug
  return ArrIdx.isNonNegative() &&
         ArrIdx.getLimitedValue() < CATy->getLimitedSize();
}
445  
446  } // namespace clang::ast_matchers
447  
448  namespace {
// Because the analysis revolves around variables and their types, we'll need to
// track uses of variables (aka DeclRefExprs).  A `DeclUseList` is a small
// list of such use sites, with inline storage for one element.
using DeclUseList = SmallVector<const DeclRefExpr *, 1>;

// Convenience typedef for a small list of fix-it hints, with inline storage
// for four elements.
using FixItList = SmallVector<FixItHint, 4>;
} // namespace
455  } // namespace
456  
457  namespace {
/// Gadget is an individual operation in the code that may be of interest to
/// this analysis. Each (non-abstract) subclass corresponds to a specific
/// rigid AST structure that constitutes an operation on a pointer-type object.
/// Discovery of a gadget in the code corresponds to claiming that we understand
/// what this part of code is doing well enough to potentially improve it.
/// Gadgets can be warning (immediately deserving a warning) or fixable (not
/// always deserving a warning per se, but requiring our attention to determine
/// whether they warrant a fixit).
class Gadget {
public:
  /// One enumerator per gadget listed in UnsafeBufferUsageGadgets.def.
  enum class Kind {
#define GADGET(x) x,
#include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
  };

  /// Common type of ASTMatchers used for discovering gadgets.
  /// Useful for implementing the static matcher() methods
  /// that are expected from all non-abstract subclasses.
  using Matcher = decltype(stmt());

  Gadget(Kind K) : K(K) {}

  /// Returns the kind tag this gadget was constructed with.
  Kind getKind() const { return K; }

#ifndef NDEBUG
  /// Returns the gadget's kind as a string (the enumerator name). Only
  /// available in builds with assertions enabled.
  StringRef getDebugName() const {
    switch (K) {
#define GADGET(x)                                                              \
  case Kind::x:                                                                \
    return #x;
#include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
    }
    llvm_unreachable("Unhandled Gadget::Kind enum");
  }
#endif

  /// Distinguishes warning gadgets from fixable gadgets.
  virtual bool isWarningGadget() const = 0;
  // TODO remove this method from WarningGadget interface. It's only used for
  // debug prints in FixableGadget.
  virtual SourceLocation getSourceLoc() const = 0;

  /// Returns the list of pointer-type variables on which this gadget performs
  /// its operation. Typically, there's only one variable. This isn't a list
  /// of all DeclRefExprs in the gadget's AST!
  virtual DeclUseList getClaimedVarUseSites() const = 0;

  virtual ~Gadget() = default;

private:
  Kind K;
};
509  
/// Warning gadgets correspond to unsafe code patterns that warrant
/// an immediate warning.
class WarningGadget : public Gadget {
public:
  WarningGadget(Kind K) : Gadget(K) {}

  /// LLVM-style RTTI hook: a gadget is a WarningGadget iff it reports itself
  /// as one.
  static bool classof(const Gadget *G) { return G->isWarningGadget(); }
  bool isWarningGadget() const final { return true; }

  /// Reports this gadget's unsafe operation through `Handler`.
  virtual void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
                                     bool IsRelatedToDecl,
                                     ASTContext &Ctx) const = 0;
};
523  
/// Fixable gadgets correspond to code patterns that aren't always unsafe but
/// need to be properly recognized in order to emit fixes. For example, if a raw
/// pointer-type variable is replaced by a safe C++ container, every use of such
/// variable must be carefully considered and possibly updated.
class FixableGadget : public Gadget {
public:
  FixableGadget(Kind K) : Gadget(K) {}

  /// LLVM-style RTTI hook: a gadget is a FixableGadget iff it is not a
  /// warning gadget.
  static bool classof(const Gadget *G) { return !G->isWarningGadget(); }
  bool isWarningGadget() const final { return false; }

  /// Returns a fixit that would fix the current gadget according to
  /// the current strategy. Returns std::nullopt if the fix cannot be produced;
  /// returns an empty list if no fixes are necessary.
  /// The default implementation produces no fix (std::nullopt).
  virtual std::optional<FixItList> getFixits(const FixitStrategy &) const {
    return std::nullopt;
  }

  /// Returns a list of two elements where the first element is the LHS of a
  /// pointer assignment statement and the second element is the RHS. This
  /// two-element list represents the fact that the LHS buffer gets its bounds
  /// information from the RHS buffer. This information will be used later to
  /// group all those variables whose types must be modified together to prevent
  /// type mismatches.
  /// The default implementation reports no implication (std::nullopt).
  virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
  getStrategyImplications() const {
    return std::nullopt;
  }
};
553  
toSupportedVariable()554  static auto toSupportedVariable() { return to(varDecl()); }
555  
// Owning lists of discovered gadgets, one per gadget category.
using FixableGadgetList = std::vector<std::unique_ptr<FixableGadget>>;
using WarningGadgetList = std::vector<std::unique_ptr<WarningGadget>>;
558  
559  /// An increment of a pointer-type value is unsafe as it may run the pointer
560  /// out of bounds.
561  class IncrementGadget : public WarningGadget {
562    static constexpr const char *const OpTag = "op";
563    const UnaryOperator *Op;
564  
565  public:
IncrementGadget(const MatchFinder::MatchResult & Result)566    IncrementGadget(const MatchFinder::MatchResult &Result)
567        : WarningGadget(Kind::Increment),
568          Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {}
569  
classof(const Gadget * G)570    static bool classof(const Gadget *G) {
571      return G->getKind() == Kind::Increment;
572    }
573  
matcher()574    static Matcher matcher() {
575      return stmt(
576          unaryOperator(hasOperatorName("++"),
577                        hasUnaryOperand(ignoringParenImpCasts(hasPointerType())))
578              .bind(OpTag));
579    }
580  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const581    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
582                               bool IsRelatedToDecl,
583                               ASTContext &Ctx) const override {
584      Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
585    }
getSourceLoc() const586    SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
587  
getClaimedVarUseSites() const588    DeclUseList getClaimedVarUseSites() const override {
589      SmallVector<const DeclRefExpr *, 2> Uses;
590      if (const auto *DRE =
591              dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
592        Uses.push_back(DRE);
593      }
594  
595      return std::move(Uses);
596    }
597  };
598  
599  /// A decrement of a pointer-type value is unsafe as it may run the pointer
600  /// out of bounds.
601  class DecrementGadget : public WarningGadget {
602    static constexpr const char *const OpTag = "op";
603    const UnaryOperator *Op;
604  
605  public:
DecrementGadget(const MatchFinder::MatchResult & Result)606    DecrementGadget(const MatchFinder::MatchResult &Result)
607        : WarningGadget(Kind::Decrement),
608          Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {}
609  
classof(const Gadget * G)610    static bool classof(const Gadget *G) {
611      return G->getKind() == Kind::Decrement;
612    }
613  
matcher()614    static Matcher matcher() {
615      return stmt(
616          unaryOperator(hasOperatorName("--"),
617                        hasUnaryOperand(ignoringParenImpCasts(hasPointerType())))
618              .bind(OpTag));
619    }
620  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const621    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
622                               bool IsRelatedToDecl,
623                               ASTContext &Ctx) const override {
624      Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
625    }
getSourceLoc() const626    SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
627  
getClaimedVarUseSites() const628    DeclUseList getClaimedVarUseSites() const override {
629      if (const auto *DRE =
630              dyn_cast<DeclRefExpr>(Op->getSubExpr()->IgnoreParenImpCasts())) {
631        return {DRE};
632      }
633  
634      return {};
635    }
636  };
637  
638  /// Array subscript expressions on raw pointers as if they're arrays. Unsafe as
639  /// it doesn't have any bounds checks for the array.
640  class ArraySubscriptGadget : public WarningGadget {
641    static constexpr const char *const ArraySubscrTag = "ArraySubscript";
642    const ArraySubscriptExpr *ASE;
643  
644  public:
ArraySubscriptGadget(const MatchFinder::MatchResult & Result)645    ArraySubscriptGadget(const MatchFinder::MatchResult &Result)
646        : WarningGadget(Kind::ArraySubscript),
647          ASE(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ArraySubscrTag)) {}
648  
classof(const Gadget * G)649    static bool classof(const Gadget *G) {
650      return G->getKind() == Kind::ArraySubscript;
651    }
652  
matcher()653    static Matcher matcher() {
654      // clang-format off
655        return stmt(arraySubscriptExpr(
656              hasBase(ignoringParenImpCasts(
657                anyOf(hasPointerType(), hasArrayType()))),
658              unless(anyOf(
659                isSafeArraySubscript(),
660                hasIndex(
661                    anyOf(integerLiteral(equals(0)), arrayInitIndexExpr())
662                )
663              ))).bind(ArraySubscrTag));
664      // clang-format on
665    }
666  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const667    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
668                               bool IsRelatedToDecl,
669                               ASTContext &Ctx) const override {
670      Handler.handleUnsafeOperation(ASE, IsRelatedToDecl, Ctx);
671    }
getSourceLoc() const672    SourceLocation getSourceLoc() const override { return ASE->getBeginLoc(); }
673  
getClaimedVarUseSites() const674    DeclUseList getClaimedVarUseSites() const override {
675      if (const auto *DRE =
676              dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts())) {
677        return {DRE};
678      }
679  
680      return {};
681    }
682  };
683  
684  /// A pointer arithmetic expression of one of the forms:
685  ///  \code
686  ///  ptr + n | n + ptr | ptr - n | ptr += n | ptr -= n
687  ///  \endcode
688  class PointerArithmeticGadget : public WarningGadget {
689    static constexpr const char *const PointerArithmeticTag = "ptrAdd";
690    static constexpr const char *const PointerArithmeticPointerTag = "ptrAddPtr";
691    const BinaryOperator *PA; // pointer arithmetic expression
692    const Expr *Ptr;          // the pointer expression in `PA`
693  
694  public:
PointerArithmeticGadget(const MatchFinder::MatchResult & Result)695    PointerArithmeticGadget(const MatchFinder::MatchResult &Result)
696        : WarningGadget(Kind::PointerArithmetic),
697          PA(Result.Nodes.getNodeAs<BinaryOperator>(PointerArithmeticTag)),
698          Ptr(Result.Nodes.getNodeAs<Expr>(PointerArithmeticPointerTag)) {}
699  
classof(const Gadget * G)700    static bool classof(const Gadget *G) {
701      return G->getKind() == Kind::PointerArithmetic;
702    }
703  
matcher()704    static Matcher matcher() {
705      auto HasIntegerType = anyOf(hasType(isInteger()), hasType(enumType()));
706      auto PtrAtRight =
707          allOf(hasOperatorName("+"),
708                hasRHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)),
709                hasLHS(HasIntegerType));
710      auto PtrAtLeft =
711          allOf(anyOf(hasOperatorName("+"), hasOperatorName("-"),
712                      hasOperatorName("+="), hasOperatorName("-=")),
713                hasLHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)),
714                hasRHS(HasIntegerType));
715  
716      return stmt(binaryOperator(anyOf(PtrAtLeft, PtrAtRight))
717                      .bind(PointerArithmeticTag));
718    }
719  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const720    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
721                               bool IsRelatedToDecl,
722                               ASTContext &Ctx) const override {
723      Handler.handleUnsafeOperation(PA, IsRelatedToDecl, Ctx);
724    }
getSourceLoc() const725    SourceLocation getSourceLoc() const override { return PA->getBeginLoc(); }
726  
getClaimedVarUseSites() const727    DeclUseList getClaimedVarUseSites() const override {
728      if (const auto *DRE = dyn_cast<DeclRefExpr>(Ptr->IgnoreParenImpCasts())) {
729        return {DRE};
730      }
731  
732      return {};
733    }
734    // FIXME: pointer adding zero should be fine
735    // FIXME: this gadge will need a fix-it
736  };
737  
738  class SpanTwoParamConstructorGadget : public WarningGadget {
739    static constexpr const char *const SpanTwoParamConstructorTag =
740        "spanTwoParamConstructor";
741    const CXXConstructExpr *Ctor; // the span constructor expression
742  
743  public:
SpanTwoParamConstructorGadget(const MatchFinder::MatchResult & Result)744    SpanTwoParamConstructorGadget(const MatchFinder::MatchResult &Result)
745        : WarningGadget(Kind::SpanTwoParamConstructor),
746          Ctor(Result.Nodes.getNodeAs<CXXConstructExpr>(
747              SpanTwoParamConstructorTag)) {}
748  
classof(const Gadget * G)749    static bool classof(const Gadget *G) {
750      return G->getKind() == Kind::SpanTwoParamConstructor;
751    }
752  
matcher()753    static Matcher matcher() {
754      auto HasTwoParamSpanCtorDecl = hasDeclaration(
755          cxxConstructorDecl(hasDeclContext(isInStdNamespace()), hasName("span"),
756                             parameterCountIs(2)));
757  
758      return stmt(cxxConstructExpr(HasTwoParamSpanCtorDecl,
759                                   unless(isSafeSpanTwoParamConstruct()))
760                      .bind(SpanTwoParamConstructorTag));
761    }
762  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const763    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
764                               bool IsRelatedToDecl,
765                               ASTContext &Ctx) const override {
766      Handler.handleUnsafeOperationInContainer(Ctor, IsRelatedToDecl, Ctx);
767    }
getSourceLoc() const768    SourceLocation getSourceLoc() const override { return Ctor->getBeginLoc(); }
769  
getClaimedVarUseSites() const770    DeclUseList getClaimedVarUseSites() const override {
771      // If the constructor call is of the form `std::span{var, n}`, `var` is
772      // considered an unsafe variable.
773      if (auto *DRE = dyn_cast<DeclRefExpr>(Ctor->getArg(0))) {
774        if (isa<VarDecl>(DRE->getDecl()))
775          return {DRE};
776      }
777      return {};
778    }
779  };
780  
781  /// A pointer initialization expression of the form:
782  ///  \code
783  ///  int *p = q;
784  ///  \endcode
785  class PointerInitGadget : public FixableGadget {
786  private:
787    static constexpr const char *const PointerInitLHSTag = "ptrInitLHS";
788    static constexpr const char *const PointerInitRHSTag = "ptrInitRHS";
789    const VarDecl *PtrInitLHS;     // the LHS pointer expression in `PI`
790    const DeclRefExpr *PtrInitRHS; // the RHS pointer expression in `PI`
791  
792  public:
PointerInitGadget(const MatchFinder::MatchResult & Result)793    PointerInitGadget(const MatchFinder::MatchResult &Result)
794        : FixableGadget(Kind::PointerInit),
795          PtrInitLHS(Result.Nodes.getNodeAs<VarDecl>(PointerInitLHSTag)),
796          PtrInitRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerInitRHSTag)) {}
797  
classof(const Gadget * G)798    static bool classof(const Gadget *G) {
799      return G->getKind() == Kind::PointerInit;
800    }
801  
matcher()802    static Matcher matcher() {
803      auto PtrInitStmt = declStmt(hasSingleDecl(
804          varDecl(hasInitializer(ignoringImpCasts(
805                      declRefExpr(hasPointerType(), toSupportedVariable())
806                          .bind(PointerInitRHSTag))))
807              .bind(PointerInitLHSTag)));
808  
809      return stmt(PtrInitStmt);
810    }
811  
812    virtual std::optional<FixItList>
813    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const814    SourceLocation getSourceLoc() const override {
815      return PtrInitRHS->getBeginLoc();
816    }
817  
getClaimedVarUseSites() const818    virtual DeclUseList getClaimedVarUseSites() const override {
819      return DeclUseList{PtrInitRHS};
820    }
821  
822    virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const823    getStrategyImplications() const override {
824      return std::make_pair(PtrInitLHS, cast<VarDecl>(PtrInitRHS->getDecl()));
825    }
826  };
827  
828  /// A pointer assignment expression of the form:
829  ///  \code
830  ///  p = q;
831  ///  \endcode
832  /// where both `p` and `q` are pointers.
833  class PtrToPtrAssignmentGadget : public FixableGadget {
834  private:
835    static constexpr const char *const PointerAssignLHSTag = "ptrLHS";
836    static constexpr const char *const PointerAssignRHSTag = "ptrRHS";
837    const DeclRefExpr *PtrLHS; // the LHS pointer expression in `PA`
838    const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA`
839  
840  public:
PtrToPtrAssignmentGadget(const MatchFinder::MatchResult & Result)841    PtrToPtrAssignmentGadget(const MatchFinder::MatchResult &Result)
842        : FixableGadget(Kind::PtrToPtrAssignment),
843          PtrLHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)),
844          PtrRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {}
845  
classof(const Gadget * G)846    static bool classof(const Gadget *G) {
847      return G->getKind() == Kind::PtrToPtrAssignment;
848    }
849  
matcher()850    static Matcher matcher() {
851      auto PtrAssignExpr = binaryOperator(
852          allOf(hasOperatorName("="),
853                hasRHS(ignoringParenImpCasts(
854                    declRefExpr(hasPointerType(), toSupportedVariable())
855                        .bind(PointerAssignRHSTag))),
856                hasLHS(declRefExpr(hasPointerType(), toSupportedVariable())
857                           .bind(PointerAssignLHSTag))));
858  
859      return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr));
860    }
861  
862    virtual std::optional<FixItList>
863    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const864    SourceLocation getSourceLoc() const override { return PtrLHS->getBeginLoc(); }
865  
getClaimedVarUseSites() const866    virtual DeclUseList getClaimedVarUseSites() const override {
867      return DeclUseList{PtrLHS, PtrRHS};
868    }
869  
870    virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const871    getStrategyImplications() const override {
872      return std::make_pair(cast<VarDecl>(PtrLHS->getDecl()),
873                            cast<VarDecl>(PtrRHS->getDecl()));
874    }
875  };
876  
877  /// An assignment expression of the form:
878  ///  \code
879  ///  ptr = array;
880  ///  \endcode
881  /// where `p` is a pointer and `array` is a constant size array.
882  class CArrayToPtrAssignmentGadget : public FixableGadget {
883  private:
884    static constexpr const char *const PointerAssignLHSTag = "ptrLHS";
885    static constexpr const char *const PointerAssignRHSTag = "ptrRHS";
886    const DeclRefExpr *PtrLHS; // the LHS pointer expression in `PA`
887    const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA`
888  
889  public:
CArrayToPtrAssignmentGadget(const MatchFinder::MatchResult & Result)890    CArrayToPtrAssignmentGadget(const MatchFinder::MatchResult &Result)
891        : FixableGadget(Kind::CArrayToPtrAssignment),
892          PtrLHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)),
893          PtrRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {}
894  
classof(const Gadget * G)895    static bool classof(const Gadget *G) {
896      return G->getKind() == Kind::CArrayToPtrAssignment;
897    }
898  
matcher()899    static Matcher matcher() {
900      auto PtrAssignExpr = binaryOperator(
901          allOf(hasOperatorName("="),
902                hasRHS(ignoringParenImpCasts(
903                    declRefExpr(hasType(hasCanonicalType(constantArrayType())),
904                                toSupportedVariable())
905                        .bind(PointerAssignRHSTag))),
906                hasLHS(declRefExpr(hasPointerType(), toSupportedVariable())
907                           .bind(PointerAssignLHSTag))));
908  
909      return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr));
910    }
911  
912    virtual std::optional<FixItList>
913    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const914    SourceLocation getSourceLoc() const override { return PtrLHS->getBeginLoc(); }
915  
getClaimedVarUseSites() const916    virtual DeclUseList getClaimedVarUseSites() const override {
917      return DeclUseList{PtrLHS, PtrRHS};
918    }
919  
920    virtual std::optional<std::pair<const VarDecl *, const VarDecl *>>
getStrategyImplications() const921    getStrategyImplications() const override {
922      return {};
923    }
924  };
925  
926  /// A call of a function or method that performs unchecked buffer operations
927  /// over one of its pointer parameters.
928  class UnsafeBufferUsageAttrGadget : public WarningGadget {
929    constexpr static const char *const OpTag = "call_expr";
930    const CallExpr *Op;
931  
932  public:
UnsafeBufferUsageAttrGadget(const MatchFinder::MatchResult & Result)933    UnsafeBufferUsageAttrGadget(const MatchFinder::MatchResult &Result)
934        : WarningGadget(Kind::UnsafeBufferUsageAttr),
935          Op(Result.Nodes.getNodeAs<CallExpr>(OpTag)) {}
936  
classof(const Gadget * G)937    static bool classof(const Gadget *G) {
938      return G->getKind() == Kind::UnsafeBufferUsageAttr;
939    }
940  
matcher()941    static Matcher matcher() {
942      auto HasUnsafeFnDecl =
943          callee(functionDecl(hasAttr(attr::UnsafeBufferUsage)));
944      return stmt(callExpr(HasUnsafeFnDecl).bind(OpTag));
945    }
946  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const947    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
948                               bool IsRelatedToDecl,
949                               ASTContext &Ctx) const override {
950      Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
951    }
getSourceLoc() const952    SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
953  
getClaimedVarUseSites() const954    DeclUseList getClaimedVarUseSites() const override { return {}; }
955  };
956  
957  /// A call of a constructor that performs unchecked buffer operations
958  /// over one of its pointer parameters, or constructs a class object that will
959  /// perform buffer operations that depend on the correctness of the parameters.
960  class UnsafeBufferUsageCtorAttrGadget : public WarningGadget {
961    constexpr static const char *const OpTag = "cxx_construct_expr";
962    const CXXConstructExpr *Op;
963  
964  public:
UnsafeBufferUsageCtorAttrGadget(const MatchFinder::MatchResult & Result)965    UnsafeBufferUsageCtorAttrGadget(const MatchFinder::MatchResult &Result)
966        : WarningGadget(Kind::UnsafeBufferUsageCtorAttr),
967          Op(Result.Nodes.getNodeAs<CXXConstructExpr>(OpTag)) {}
968  
classof(const Gadget * G)969    static bool classof(const Gadget *G) {
970      return G->getKind() == Kind::UnsafeBufferUsageCtorAttr;
971    }
972  
matcher()973    static Matcher matcher() {
974      auto HasUnsafeCtorDecl =
975          hasDeclaration(cxxConstructorDecl(hasAttr(attr::UnsafeBufferUsage)));
976      // std::span(ptr, size) ctor is handled by SpanTwoParamConstructorGadget.
977      auto HasTwoParamSpanCtorDecl = SpanTwoParamConstructorGadget::matcher();
978      return stmt(
979          cxxConstructExpr(HasUnsafeCtorDecl, unless(HasTwoParamSpanCtorDecl))
980              .bind(OpTag));
981    }
982  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const983    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
984                               bool IsRelatedToDecl,
985                               ASTContext &Ctx) const override {
986      Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
987    }
getSourceLoc() const988    SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
989  
getClaimedVarUseSites() const990    DeclUseList getClaimedVarUseSites() const override { return {}; }
991  };
992  
993  // Warning gadget for unsafe invocation of span::data method.
994  // Triggers when the pointer returned by the invocation is immediately
995  // cast to a larger type.
996  
997  class DataInvocationGadget : public WarningGadget {
998    constexpr static const char *const OpTag = "data_invocation_expr";
999    const ExplicitCastExpr *Op;
1000  
1001  public:
DataInvocationGadget(const MatchFinder::MatchResult & Result)1002    DataInvocationGadget(const MatchFinder::MatchResult &Result)
1003        : WarningGadget(Kind::DataInvocation),
1004          Op(Result.Nodes.getNodeAs<ExplicitCastExpr>(OpTag)) {}
1005  
classof(const Gadget * G)1006    static bool classof(const Gadget *G) {
1007      return G->getKind() == Kind::DataInvocation;
1008    }
1009  
matcher()1010    static Matcher matcher() {
1011      Matcher callExpr = cxxMemberCallExpr(
1012          callee(cxxMethodDecl(hasName("data"), ofClass(hasName("std::span")))));
1013      return stmt(
1014          explicitCastExpr(anyOf(has(callExpr), has(parenExpr(has(callExpr)))))
1015              .bind(OpTag));
1016    }
1017  
handleUnsafeOperation(UnsafeBufferUsageHandler & Handler,bool IsRelatedToDecl,ASTContext & Ctx) const1018    void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
1019                               bool IsRelatedToDecl,
1020                               ASTContext &Ctx) const override {
1021      Handler.handleUnsafeOperation(Op, IsRelatedToDecl, Ctx);
1022    }
getSourceLoc() const1023    SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1024  
getClaimedVarUseSites() const1025    DeclUseList getClaimedVarUseSites() const override { return {}; }
1026  };
1027  
1028  // Represents expressions of the form `DRE[*]` in the Unspecified Lvalue
1029  // Context (see `isInUnspecifiedLvalueContext`).
1030  // Note here `[]` is the built-in subscript operator.
1031  class ULCArraySubscriptGadget : public FixableGadget {
1032  private:
1033    static constexpr const char *const ULCArraySubscriptTag =
1034        "ArraySubscriptUnderULC";
1035    const ArraySubscriptExpr *Node;
1036  
1037  public:
ULCArraySubscriptGadget(const MatchFinder::MatchResult & Result)1038    ULCArraySubscriptGadget(const MatchFinder::MatchResult &Result)
1039        : FixableGadget(Kind::ULCArraySubscript),
1040          Node(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ULCArraySubscriptTag)) {
1041      assert(Node != nullptr && "Expecting a non-null matching result");
1042    }
1043  
classof(const Gadget * G)1044    static bool classof(const Gadget *G) {
1045      return G->getKind() == Kind::ULCArraySubscript;
1046    }
1047  
matcher()1048    static Matcher matcher() {
1049      auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType());
1050      auto BaseIsArrayOrPtrDRE = hasBase(
1051          ignoringParenImpCasts(declRefExpr(ArrayOrPtr, toSupportedVariable())));
1052      auto Target =
1053          arraySubscriptExpr(BaseIsArrayOrPtrDRE).bind(ULCArraySubscriptTag);
1054  
1055      return expr(isInUnspecifiedLvalueContext(Target));
1056    }
1057  
1058    virtual std::optional<FixItList>
1059    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1060    SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1061  
getClaimedVarUseSites() const1062    virtual DeclUseList getClaimedVarUseSites() const override {
1063      if (const auto *DRE =
1064              dyn_cast<DeclRefExpr>(Node->getBase()->IgnoreImpCasts())) {
1065        return {DRE};
1066      }
1067      return {};
1068    }
1069  };
1070  
1071  // Fixable gadget to handle stand alone pointers of the form `UPC(DRE)` in the
1072  // unspecified pointer context (isInUnspecifiedPointerContext). The gadget emits
1073  // fixit of the form `UPC(DRE.data())`.
1074  class UPCStandalonePointerGadget : public FixableGadget {
1075  private:
1076    static constexpr const char *const DeclRefExprTag = "StandalonePointer";
1077    const DeclRefExpr *Node;
1078  
1079  public:
UPCStandalonePointerGadget(const MatchFinder::MatchResult & Result)1080    UPCStandalonePointerGadget(const MatchFinder::MatchResult &Result)
1081        : FixableGadget(Kind::UPCStandalonePointer),
1082          Node(Result.Nodes.getNodeAs<DeclRefExpr>(DeclRefExprTag)) {
1083      assert(Node != nullptr && "Expecting a non-null matching result");
1084    }
1085  
classof(const Gadget * G)1086    static bool classof(const Gadget *G) {
1087      return G->getKind() == Kind::UPCStandalonePointer;
1088    }
1089  
matcher()1090    static Matcher matcher() {
1091      auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType());
1092      auto target = expr(ignoringParenImpCasts(
1093          declRefExpr(allOf(ArrayOrPtr, toSupportedVariable()))
1094              .bind(DeclRefExprTag)));
1095      return stmt(isInUnspecifiedPointerContext(target));
1096    }
1097  
1098    virtual std::optional<FixItList>
1099    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1100    SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1101  
getClaimedVarUseSites() const1102    virtual DeclUseList getClaimedVarUseSites() const override { return {Node}; }
1103  };
1104  
1105  class PointerDereferenceGadget : public FixableGadget {
1106    static constexpr const char *const BaseDeclRefExprTag = "BaseDRE";
1107    static constexpr const char *const OperatorTag = "op";
1108  
1109    const DeclRefExpr *BaseDeclRefExpr = nullptr;
1110    const UnaryOperator *Op = nullptr;
1111  
1112  public:
PointerDereferenceGadget(const MatchFinder::MatchResult & Result)1113    PointerDereferenceGadget(const MatchFinder::MatchResult &Result)
1114        : FixableGadget(Kind::PointerDereference),
1115          BaseDeclRefExpr(
1116              Result.Nodes.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)),
1117          Op(Result.Nodes.getNodeAs<UnaryOperator>(OperatorTag)) {}
1118  
classof(const Gadget * G)1119    static bool classof(const Gadget *G) {
1120      return G->getKind() == Kind::PointerDereference;
1121    }
1122  
matcher()1123    static Matcher matcher() {
1124      auto Target =
1125          unaryOperator(
1126              hasOperatorName("*"),
1127              has(expr(ignoringParenImpCasts(
1128                  declRefExpr(toSupportedVariable()).bind(BaseDeclRefExprTag)))))
1129              .bind(OperatorTag);
1130  
1131      return expr(isInUnspecifiedLvalueContext(Target));
1132    }
1133  
getClaimedVarUseSites() const1134    DeclUseList getClaimedVarUseSites() const override {
1135      return {BaseDeclRefExpr};
1136    }
1137  
1138    virtual std::optional<FixItList>
1139    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1140    SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); }
1141  };
1142  
1143  // Represents expressions of the form `&DRE[any]` in the Unspecified Pointer
1144  // Context (see `isInUnspecifiedPointerContext`).
1145  // Note here `[]` is the built-in subscript operator.
1146  class UPCAddressofArraySubscriptGadget : public FixableGadget {
1147  private:
1148    static constexpr const char *const UPCAddressofArraySubscriptTag =
1149        "AddressofArraySubscriptUnderUPC";
1150    const UnaryOperator *Node; // the `&DRE[any]` node
1151  
1152  public:
UPCAddressofArraySubscriptGadget(const MatchFinder::MatchResult & Result)1153    UPCAddressofArraySubscriptGadget(const MatchFinder::MatchResult &Result)
1154        : FixableGadget(Kind::ULCArraySubscript),
1155          Node(Result.Nodes.getNodeAs<UnaryOperator>(
1156              UPCAddressofArraySubscriptTag)) {
1157      assert(Node != nullptr && "Expecting a non-null matching result");
1158    }
1159  
classof(const Gadget * G)1160    static bool classof(const Gadget *G) {
1161      return G->getKind() == Kind::UPCAddressofArraySubscript;
1162    }
1163  
matcher()1164    static Matcher matcher() {
1165      return expr(isInUnspecifiedPointerContext(expr(ignoringImpCasts(
1166          unaryOperator(
1167              hasOperatorName("&"),
1168              hasUnaryOperand(arraySubscriptExpr(hasBase(
1169                  ignoringParenImpCasts(declRefExpr(toSupportedVariable()))))))
1170              .bind(UPCAddressofArraySubscriptTag)))));
1171    }
1172  
1173    virtual std::optional<FixItList>
1174    getFixits(const FixitStrategy &) const override;
getSourceLoc() const1175    SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1176  
getClaimedVarUseSites() const1177    virtual DeclUseList getClaimedVarUseSites() const override {
1178      const auto *ArraySubst = cast<ArraySubscriptExpr>(Node->getSubExpr());
1179      const auto *DRE =
1180          cast<DeclRefExpr>(ArraySubst->getBase()->IgnoreParenImpCasts());
1181      return {DRE};
1182    }
1183  };
1184  } // namespace
1185  
1186  namespace {
1187  // An auxiliary tracking facility for the fixit analysis. It helps connect
1188  // declarations to its uses and make sure we've covered all uses with our
1189  // analysis before we try to fix the declaration.
1190  class DeclUseTracker {
1191    using UseSetTy = SmallSet<const DeclRefExpr *, 16>;
1192    using DefMapTy = DenseMap<const VarDecl *, const DeclStmt *>;
1193  
1194    // Allocate on the heap for easier move.
1195    std::unique_ptr<UseSetTy> Uses{std::make_unique<UseSetTy>()};
1196    DefMapTy Defs{};
1197  
1198  public:
1199    DeclUseTracker() = default;
1200    DeclUseTracker(const DeclUseTracker &) = delete; // Let's avoid copies.
1201    DeclUseTracker &operator=(const DeclUseTracker &) = delete;
1202    DeclUseTracker(DeclUseTracker &&) = default;
1203    DeclUseTracker &operator=(DeclUseTracker &&) = default;
1204  
1205    // Start tracking a freshly discovered DRE.
discoverUse(const DeclRefExpr * DRE)1206    void discoverUse(const DeclRefExpr *DRE) { Uses->insert(DRE); }
1207  
1208    // Stop tracking the DRE as it's been fully figured out.
claimUse(const DeclRefExpr * DRE)1209    void claimUse(const DeclRefExpr *DRE) {
1210      assert(Uses->count(DRE) &&
1211             "DRE not found or claimed by multiple matchers!");
1212      Uses->erase(DRE);
1213    }
1214  
1215    // A variable is unclaimed if at least one use is unclaimed.
hasUnclaimedUses(const VarDecl * VD) const1216    bool hasUnclaimedUses(const VarDecl *VD) const {
1217      // FIXME: Can this be less linear? Maybe maintain a map from VDs to DREs?
1218      return any_of(*Uses, [VD](const DeclRefExpr *DRE) {
1219        return DRE->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl();
1220      });
1221    }
1222  
getUnclaimedUses(const VarDecl * VD) const1223    UseSetTy getUnclaimedUses(const VarDecl *VD) const {
1224      UseSetTy ReturnSet;
1225      for (auto use : *Uses) {
1226        if (use->getDecl()->getCanonicalDecl() == VD->getCanonicalDecl()) {
1227          ReturnSet.insert(use);
1228        }
1229      }
1230      return ReturnSet;
1231    }
1232  
discoverDecl(const DeclStmt * DS)1233    void discoverDecl(const DeclStmt *DS) {
1234      for (const Decl *D : DS->decls()) {
1235        if (const auto *VD = dyn_cast<VarDecl>(D)) {
1236          // FIXME: Assertion temporarily disabled due to a bug in
1237          // ASTMatcher internal behavior in presence of GNU
1238          // statement-expressions. We need to properly investigate this
1239          // because it can screw up our algorithm in other ways.
1240          // assert(Defs.count(VD) == 0 && "Definition already discovered!");
1241          Defs[VD] = DS;
1242        }
1243      }
1244    }
1245  
lookupDecl(const VarDecl * VD) const1246    const DeclStmt *lookupDecl(const VarDecl *VD) const {
1247      return Defs.lookup(VD);
1248    }
1249  };
1250  } // namespace
1251  
1252  // Representing a pointer type expression of the form `++Ptr` in an Unspecified
1253  // Pointer Context (UPC):
1254  class UPCPreIncrementGadget : public FixableGadget {
1255  private:
1256    static constexpr const char *const UPCPreIncrementTag =
1257        "PointerPreIncrementUnderUPC";
1258    const UnaryOperator *Node; // the `++Ptr` node
1259  
1260  public:
UPCPreIncrementGadget(const MatchFinder::MatchResult & Result)1261    UPCPreIncrementGadget(const MatchFinder::MatchResult &Result)
1262        : FixableGadget(Kind::UPCPreIncrement),
1263          Node(Result.Nodes.getNodeAs<UnaryOperator>(UPCPreIncrementTag)) {
1264      assert(Node != nullptr && "Expecting a non-null matching result");
1265    }
1266  
classof(const Gadget * G)1267    static bool classof(const Gadget *G) {
1268      return G->getKind() == Kind::UPCPreIncrement;
1269    }
1270  
matcher()1271    static Matcher matcher() {
1272      // Note here we match `++Ptr` for any expression `Ptr` of pointer type.
1273      // Although currently we can only provide fix-its when `Ptr` is a DRE, we
1274      // can have the matcher be general, so long as `getClaimedVarUseSites` does
1275      // things right.
1276      return stmt(isInUnspecifiedPointerContext(expr(ignoringImpCasts(
1277          unaryOperator(isPreInc(),
1278                        hasUnaryOperand(declRefExpr(toSupportedVariable())))
1279              .bind(UPCPreIncrementTag)))));
1280    }
1281  
1282    virtual std::optional<FixItList>
1283    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1284    SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1285  
getClaimedVarUseSites() const1286    virtual DeclUseList getClaimedVarUseSites() const override {
1287      return {dyn_cast<DeclRefExpr>(Node->getSubExpr())};
1288    }
1289  };
1290  
1291  // Representing a pointer type expression of the form `Ptr += n` in an
1292  // Unspecified Untyped Context (UUC):
1293  class UUCAddAssignGadget : public FixableGadget {
1294  private:
1295    static constexpr const char *const UUCAddAssignTag =
1296        "PointerAddAssignUnderUUC";
1297    static constexpr const char *const OffsetTag = "Offset";
1298  
1299    const BinaryOperator *Node; // the `Ptr += n` node
1300    const Expr *Offset = nullptr;
1301  
1302  public:
UUCAddAssignGadget(const MatchFinder::MatchResult & Result)1303    UUCAddAssignGadget(const MatchFinder::MatchResult &Result)
1304        : FixableGadget(Kind::UUCAddAssign),
1305          Node(Result.Nodes.getNodeAs<BinaryOperator>(UUCAddAssignTag)),
1306          Offset(Result.Nodes.getNodeAs<Expr>(OffsetTag)) {
1307      assert(Node != nullptr && "Expecting a non-null matching result");
1308    }
1309  
classof(const Gadget * G)1310    static bool classof(const Gadget *G) {
1311      return G->getKind() == Kind::UUCAddAssign;
1312    }
1313  
matcher()1314    static Matcher matcher() {
1315      // clang-format off
1316      return stmt(isInUnspecifiedUntypedContext(expr(ignoringImpCasts(
1317          binaryOperator(hasOperatorName("+="),
1318                         hasLHS(
1319                          declRefExpr(
1320                            hasPointerType(),
1321                            toSupportedVariable())),
1322                         hasRHS(expr().bind(OffsetTag)))
1323              .bind(UUCAddAssignTag)))));
1324      // clang-format on
1325    }
1326  
1327    virtual std::optional<FixItList>
1328    getFixits(const FixitStrategy &S) const override;
getSourceLoc() const1329    SourceLocation getSourceLoc() const override { return Node->getBeginLoc(); }
1330  
getClaimedVarUseSites() const1331    virtual DeclUseList getClaimedVarUseSites() const override {
1332      return {dyn_cast<DeclRefExpr>(Node->getLHS())};
1333    }
1334  };
1335  
1336  // Representing a fixable expression of the form `*(ptr + 123)` or `*(123 +
1337  // ptr)`:
1338  class DerefSimplePtrArithFixableGadget : public FixableGadget {
1339    static constexpr const char *const BaseDeclRefExprTag = "BaseDRE";
1340    static constexpr const char *const DerefOpTag = "DerefOp";
1341    static constexpr const char *const AddOpTag = "AddOp";
1342    static constexpr const char *const OffsetTag = "Offset";
1343  
1344    const DeclRefExpr *BaseDeclRefExpr = nullptr;
1345    const UnaryOperator *DerefOp = nullptr;
1346    const BinaryOperator *AddOp = nullptr;
1347    const IntegerLiteral *Offset = nullptr;
1348  
1349  public:
DerefSimplePtrArithFixableGadget(const MatchFinder::MatchResult & Result)1350    DerefSimplePtrArithFixableGadget(const MatchFinder::MatchResult &Result)
1351        : FixableGadget(Kind::DerefSimplePtrArithFixable),
1352          BaseDeclRefExpr(
1353              Result.Nodes.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)),
1354          DerefOp(Result.Nodes.getNodeAs<UnaryOperator>(DerefOpTag)),
1355          AddOp(Result.Nodes.getNodeAs<BinaryOperator>(AddOpTag)),
1356          Offset(Result.Nodes.getNodeAs<IntegerLiteral>(OffsetTag)) {}
1357  
matcher()1358    static Matcher matcher() {
1359      // clang-format off
1360      auto ThePtr = expr(hasPointerType(),
1361                         ignoringImpCasts(declRefExpr(toSupportedVariable()).
1362                                          bind(BaseDeclRefExprTag)));
1363      auto PlusOverPtrAndInteger = expr(anyOf(
1364            binaryOperator(hasOperatorName("+"), hasLHS(ThePtr),
1365                           hasRHS(integerLiteral().bind(OffsetTag)))
1366                           .bind(AddOpTag),
1367            binaryOperator(hasOperatorName("+"), hasRHS(ThePtr),
1368                           hasLHS(integerLiteral().bind(OffsetTag)))
1369                           .bind(AddOpTag)));
1370      return isInUnspecifiedLvalueContext(unaryOperator(
1371          hasOperatorName("*"),
1372          hasUnaryOperand(ignoringParens(PlusOverPtrAndInteger)))
1373          .bind(DerefOpTag));
1374      // clang-format on
1375    }
1376  
1377    virtual std::optional<FixItList>
1378    getFixits(const FixitStrategy &s) const final;
getSourceLoc() const1379    SourceLocation getSourceLoc() const override {
1380      return DerefOp->getBeginLoc();
1381    }
1382  
getClaimedVarUseSites() const1383    virtual DeclUseList getClaimedVarUseSites() const final {
1384      return {BaseDeclRefExpr};
1385    }
1386  };
1387  
1388  /// Scan the function and return a list of gadgets found with provided kits.
1389  static std::tuple<FixableGadgetList, WarningGadgetList, DeclUseTracker>
findGadgets(const Decl * D,const UnsafeBufferUsageHandler & Handler,bool EmitSuggestions)1390  findGadgets(const Decl *D, const UnsafeBufferUsageHandler &Handler,
1391              bool EmitSuggestions) {
1392  
1393    struct GadgetFinderCallback : MatchFinder::MatchCallback {
1394      FixableGadgetList FixableGadgets;
1395      WarningGadgetList WarningGadgets;
1396      DeclUseTracker Tracker;
1397  
1398      void run(const MatchFinder::MatchResult &Result) override {
1399        // In debug mode, assert that we've found exactly one gadget.
1400        // This helps us avoid conflicts in .bind() tags.
1401  #if NDEBUG
1402  #define NEXT return
1403  #else
1404        [[maybe_unused]] int numFound = 0;
1405  #define NEXT ++numFound
1406  #endif
1407  
1408        if (const auto *DRE = Result.Nodes.getNodeAs<DeclRefExpr>("any_dre")) {
1409          Tracker.discoverUse(DRE);
1410          NEXT;
1411        }
1412  
1413        if (const auto *DS = Result.Nodes.getNodeAs<DeclStmt>("any_ds")) {
1414          Tracker.discoverDecl(DS);
1415          NEXT;
1416        }
1417  
1418        // Figure out which matcher we've found, and call the appropriate
1419        // subclass constructor.
1420        // FIXME: Can we do this more logarithmically?
1421  #define FIXABLE_GADGET(name)                                                   \
1422    if (Result.Nodes.getNodeAs<Stmt>(#name)) {                                   \
1423      FixableGadgets.push_back(std::make_unique<name##Gadget>(Result));          \
1424      NEXT;                                                                      \
1425    }
1426  #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1427  #define WARNING_GADGET(name)                                                   \
1428    if (Result.Nodes.getNodeAs<Stmt>(#name)) {                                   \
1429      WarningGadgets.push_back(std::make_unique<name##Gadget>(Result));          \
1430      NEXT;                                                                      \
1431    }
1432  #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1433  
1434        assert(numFound >= 1 && "Gadgets not found in match result!");
1435        assert(numFound <= 1 && "Conflicting bind tags in gadgets!");
1436      }
1437    };
1438  
1439    MatchFinder M;
1440    GadgetFinderCallback CB;
1441  
1442    // clang-format off
1443    M.addMatcher(
1444        stmt(
1445          forEachDescendantEvaluatedStmt(stmt(anyOf(
1446            // Add Gadget::matcher() for every gadget in the registry.
1447  #define WARNING_GADGET(x)                                                      \
1448            allOf(x ## Gadget::matcher().bind(#x),                               \
1449                  notInSafeBufferOptOut(&Handler)),
1450  #define WARNING_CONTAINER_GADGET(x)                                            \
1451            allOf(x ## Gadget::matcher().bind(#x),                               \
1452                  notInSafeBufferOptOut(&Handler),                               \
1453                  unless(ignoreUnsafeBufferInContainer(&Handler))),
1454  #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1455              // Avoid a hanging comma.
1456              unless(stmt())
1457          )))
1458      ),
1459      &CB
1460    );
1461    // clang-format on
1462  
1463    if (EmitSuggestions) {
1464      // clang-format off
1465      M.addMatcher(
1466          stmt(
1467            forEachDescendantStmt(stmt(eachOf(
1468  #define FIXABLE_GADGET(x)                                                      \
1469              x ## Gadget::matcher().bind(#x),
1470  #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def"
1471              // In parallel, match all DeclRefExprs so that to find out
1472              // whether there are any uncovered by gadgets.
1473              declRefExpr(anyOf(hasPointerType(), hasArrayType()),
1474                          to(anyOf(varDecl(), bindingDecl()))).bind("any_dre"),
1475              // Also match DeclStmts because we'll need them when fixing
1476              // their underlying VarDecls that otherwise don't have
1477              // any backreferences to DeclStmts.
1478              declStmt().bind("any_ds")
1479            )))
1480        ),
1481        &CB
1482      );
1483      // clang-format on
1484    }
1485  
1486    M.match(*D->getBody(), D->getASTContext());
1487    return {std::move(CB.FixableGadgets), std::move(CB.WarningGadgets),
1488            std::move(CB.Tracker)};
1489  }
1490  
// Orders AST nodes by the raw encoding of the source location at which they
// begin.
template <typename NodeTy> struct CompareNode {
  bool operator()(const NodeTy *N1, const NodeTy *N2) const {
    const auto FirstEncoding = N1->getBeginLoc().getRawEncoding();
    const auto SecondEncoding = N2->getBeginLoc().getRawEncoding();
    return FirstEncoding < SecondEncoding;
  }
};
1498  
1499  struct WarningGadgetSets {
1500    std::map<const VarDecl *, std::set<const WarningGadget *>,
1501             // To keep keys sorted by their locations in the map so that the
1502             // order is deterministic:
1503             CompareNode<VarDecl>>
1504        byVar;
1505    // These Gadgets are not related to pointer variables (e. g. temporaries).
1506    llvm::SmallVector<const WarningGadget *, 16> noVar;
1507  };
1508  
1509  static WarningGadgetSets
groupWarningGadgetsByVar(const WarningGadgetList & AllUnsafeOperations)1510  groupWarningGadgetsByVar(const WarningGadgetList &AllUnsafeOperations) {
1511    WarningGadgetSets result;
1512    // If some gadgets cover more than one
1513    // variable, they'll appear more than once in the map.
1514    for (auto &G : AllUnsafeOperations) {
1515      DeclUseList ClaimedVarUseSites = G->getClaimedVarUseSites();
1516  
1517      bool AssociatedWithVarDecl = false;
1518      for (const DeclRefExpr *DRE : ClaimedVarUseSites) {
1519        if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1520          result.byVar[VD].insert(G.get());
1521          AssociatedWithVarDecl = true;
1522        }
1523      }
1524  
1525      if (!AssociatedWithVarDecl) {
1526        result.noVar.push_back(G.get());
1527        continue;
1528      }
1529    }
1530    return result;
1531  }
1532  
1533  struct FixableGadgetSets {
1534    std::map<const VarDecl *, std::set<const FixableGadget *>,
1535             // To keep keys sorted by their locations in the map so that the
1536             // order is deterministic:
1537             CompareNode<VarDecl>>
1538        byVar;
1539  };
1540  
1541  static FixableGadgetSets
groupFixablesByVar(FixableGadgetList && AllFixableOperations)1542  groupFixablesByVar(FixableGadgetList &&AllFixableOperations) {
1543    FixableGadgetSets FixablesForUnsafeVars;
1544    for (auto &F : AllFixableOperations) {
1545      DeclUseList DREs = F->getClaimedVarUseSites();
1546  
1547      for (const DeclRefExpr *DRE : DREs) {
1548        if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1549          FixablesForUnsafeVars.byVar[VD].insert(F.get());
1550        }
1551      }
1552    }
1553    return FixablesForUnsafeVars;
1554  }
1555  
anyConflict(const SmallVectorImpl<FixItHint> & FixIts,const SourceManager & SM)1556  bool clang::internal::anyConflict(const SmallVectorImpl<FixItHint> &FixIts,
1557                                    const SourceManager &SM) {
1558    // A simple interval overlap detection algorithm.  Sorts all ranges by their
1559    // begin location then finds the first overlap in one pass.
1560    std::vector<const FixItHint *> All; // a copy of `FixIts`
1561  
1562    for (const FixItHint &H : FixIts)
1563      All.push_back(&H);
1564    std::sort(All.begin(), All.end(),
1565              [&SM](const FixItHint *H1, const FixItHint *H2) {
1566                return SM.isBeforeInTranslationUnit(H1->RemoveRange.getBegin(),
1567                                                    H2->RemoveRange.getBegin());
1568              });
1569  
1570    const FixItHint *CurrHint = nullptr;
1571  
1572    for (const FixItHint *Hint : All) {
1573      if (!CurrHint ||
1574          SM.isBeforeInTranslationUnit(CurrHint->RemoveRange.getEnd(),
1575                                       Hint->RemoveRange.getBegin())) {
1576        // Either to initialize `CurrHint` or `CurrHint` does not
1577        // overlap with `Hint`:
1578        CurrHint = Hint;
1579      } else
1580        // In case `Hint` overlaps the `CurrHint`, we found at least one
1581        // conflict:
1582        return true;
1583    }
1584    return false;
1585  }
1586  
1587  std::optional<FixItList>
getFixits(const FixitStrategy & S) const1588  PtrToPtrAssignmentGadget::getFixits(const FixitStrategy &S) const {
1589    const auto *LeftVD = cast<VarDecl>(PtrLHS->getDecl());
1590    const auto *RightVD = cast<VarDecl>(PtrRHS->getDecl());
1591    switch (S.lookup(LeftVD)) {
1592    case FixitStrategy::Kind::Span:
1593      if (S.lookup(RightVD) == FixitStrategy::Kind::Span)
1594        return FixItList{};
1595      return std::nullopt;
1596    case FixitStrategy::Kind::Wontfix:
1597      return std::nullopt;
1598    case FixitStrategy::Kind::Iterator:
1599    case FixitStrategy::Kind::Array:
1600      return std::nullopt;
1601    case FixitStrategy::Kind::Vector:
1602      llvm_unreachable("unsupported strategies for FixableGadgets");
1603    }
1604    return std::nullopt;
1605  }
1606  
/// \returns fixit that adds .data() call after \p DRE.
1608  static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx,
1609                                                         const DeclRefExpr *DRE);
1610  
1611  std::optional<FixItList>
getFixits(const FixitStrategy & S) const1612  CArrayToPtrAssignmentGadget::getFixits(const FixitStrategy &S) const {
1613    const auto *LeftVD = cast<VarDecl>(PtrLHS->getDecl());
1614    const auto *RightVD = cast<VarDecl>(PtrRHS->getDecl());
1615    // TLDR: Implementing fixits for non-Wontfix strategy on both LHS and RHS is
1616    // non-trivial.
1617    //
1618    // CArrayToPtrAssignmentGadget doesn't have strategy implications because
1619    // constant size array propagates its bounds. Because of that LHS and RHS are
1620    // addressed by two different fixits.
1621    //
1622    // At the same time FixitStrategy S doesn't reflect what group a fixit belongs
1623    // to and can't be generally relied on in multi-variable Fixables!
1624    //
1625    // E. g. If an instance of this gadget is fixing variable on LHS then the
1626    // variable on RHS is fixed by a different fixit and its strategy for LHS
1627    // fixit is as if Wontfix.
1628    //
1629    // The only exception is Wontfix strategy for a given variable as that is
1630    // valid for any fixit produced for the given input source code.
1631    if (S.lookup(LeftVD) == FixitStrategy::Kind::Span) {
1632      if (S.lookup(RightVD) == FixitStrategy::Kind::Wontfix) {
1633        return FixItList{};
1634      }
1635    } else if (S.lookup(LeftVD) == FixitStrategy::Kind::Wontfix) {
1636      if (S.lookup(RightVD) == FixitStrategy::Kind::Array) {
1637        return createDataFixit(RightVD->getASTContext(), PtrRHS);
1638      }
1639    }
1640    return std::nullopt;
1641  }
1642  
1643  std::optional<FixItList>
getFixits(const FixitStrategy & S) const1644  PointerInitGadget::getFixits(const FixitStrategy &S) const {
1645    const auto *LeftVD = PtrInitLHS;
1646    const auto *RightVD = cast<VarDecl>(PtrInitRHS->getDecl());
1647    switch (S.lookup(LeftVD)) {
1648    case FixitStrategy::Kind::Span:
1649      if (S.lookup(RightVD) == FixitStrategy::Kind::Span)
1650        return FixItList{};
1651      return std::nullopt;
1652    case FixitStrategy::Kind::Wontfix:
1653      return std::nullopt;
1654    case FixitStrategy::Kind::Iterator:
1655    case FixitStrategy::Kind::Array:
1656      return std::nullopt;
1657    case FixitStrategy::Kind::Vector:
1658      llvm_unreachable("unsupported strategies for FixableGadgets");
1659    }
1660    return std::nullopt;
1661  }
1662  
isNonNegativeIntegerExpr(const Expr * Expr,const VarDecl * VD,const ASTContext & Ctx)1663  static bool isNonNegativeIntegerExpr(const Expr *Expr, const VarDecl *VD,
1664                                       const ASTContext &Ctx) {
1665    if (auto ConstVal = Expr->getIntegerConstantExpr(Ctx)) {
1666      if (ConstVal->isNegative())
1667        return false;
1668    } else if (!Expr->getType()->isUnsignedIntegerType())
1669      return false;
1670    return true;
1671  }
1672  
1673  std::optional<FixItList>
getFixits(const FixitStrategy & S) const1674  ULCArraySubscriptGadget::getFixits(const FixitStrategy &S) const {
1675    if (const auto *DRE =
1676            dyn_cast<DeclRefExpr>(Node->getBase()->IgnoreImpCasts()))
1677      if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1678        switch (S.lookup(VD)) {
1679        case FixitStrategy::Kind::Span: {
1680  
1681          // If the index has a negative constant value, we give up as no valid
1682          // fix-it can be generated:
1683          const ASTContext &Ctx = // FIXME: we need ASTContext to be passed in!
1684              VD->getASTContext();
1685          if (!isNonNegativeIntegerExpr(Node->getIdx(), VD, Ctx))
1686            return std::nullopt;
1687          // no-op is a good fix-it, otherwise
1688          return FixItList{};
1689        }
1690        case FixitStrategy::Kind::Array:
1691          return FixItList{};
1692        case FixitStrategy::Kind::Wontfix:
1693        case FixitStrategy::Kind::Iterator:
1694        case FixitStrategy::Kind::Vector:
1695          llvm_unreachable("unsupported strategies for FixableGadgets");
1696        }
1697      }
1698    return std::nullopt;
1699  }
1700  
1701  static std::optional<FixItList> // forward declaration
1702  fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator *Node);
1703  
1704  std::optional<FixItList>
getFixits(const FixitStrategy & S) const1705  UPCAddressofArraySubscriptGadget::getFixits(const FixitStrategy &S) const {
1706    auto DREs = getClaimedVarUseSites();
1707    const auto *VD = cast<VarDecl>(DREs.front()->getDecl());
1708  
1709    switch (S.lookup(VD)) {
1710    case FixitStrategy::Kind::Span:
1711      return fixUPCAddressofArraySubscriptWithSpan(Node);
1712    case FixitStrategy::Kind::Wontfix:
1713    case FixitStrategy::Kind::Iterator:
1714    case FixitStrategy::Kind::Array:
1715      return std::nullopt;
1716    case FixitStrategy::Kind::Vector:
1717      llvm_unreachable("unsupported strategies for FixableGadgets");
1718    }
1719    return std::nullopt; // something went wrong, no fix-it
1720  }
1721  
1722  // FIXME: this function should be customizable through format
getEndOfLine()1723  static StringRef getEndOfLine() {
1724    static const char *const EOL = "\n";
1725    return EOL;
1726  }
1727  
1728  // Returns the text indicating that the user needs to provide input there:
getUserFillPlaceHolder(StringRef HintTextToUser="placeholder")1729  std::string getUserFillPlaceHolder(StringRef HintTextToUser = "placeholder") {
1730    std::string s = std::string("<# ");
1731    s += HintTextToUser;
1732    s += " #>";
1733    return s;
1734  }
1735  
1736  // Return the source location of the last character of the AST `Node`.
1737  template <typename NodeTy>
1738  static std::optional<SourceLocation>
getEndCharLoc(const NodeTy * Node,const SourceManager & SM,const LangOptions & LangOpts)1739  getEndCharLoc(const NodeTy *Node, const SourceManager &SM,
1740                const LangOptions &LangOpts) {
1741    unsigned TkLen = Lexer::MeasureTokenLength(Node->getEndLoc(), SM, LangOpts);
1742    SourceLocation Loc = Node->getEndLoc().getLocWithOffset(TkLen - 1);
1743  
1744    if (Loc.isValid())
1745      return Loc;
1746  
1747    return std::nullopt;
1748  }
1749  
1750  // Return the source location just past the last character of the AST `Node`.
1751  template <typename NodeTy>
getPastLoc(const NodeTy * Node,const SourceManager & SM,const LangOptions & LangOpts)1752  static std::optional<SourceLocation> getPastLoc(const NodeTy *Node,
1753                                                  const SourceManager &SM,
1754                                                  const LangOptions &LangOpts) {
1755    SourceLocation Loc =
1756        Lexer::getLocForEndOfToken(Node->getEndLoc(), 0, SM, LangOpts);
1757    if (Loc.isValid())
1758      return Loc;
1759    return std::nullopt;
1760  }
1761  
1762  // Return text representation of an `Expr`.
getExprText(const Expr * E,const SourceManager & SM,const LangOptions & LangOpts)1763  static std::optional<StringRef> getExprText(const Expr *E,
1764                                              const SourceManager &SM,
1765                                              const LangOptions &LangOpts) {
1766    std::optional<SourceLocation> LastCharLoc = getPastLoc(E, SM, LangOpts);
1767  
1768    if (LastCharLoc)
1769      return Lexer::getSourceText(
1770          CharSourceRange::getCharRange(E->getBeginLoc(), *LastCharLoc), SM,
1771          LangOpts);
1772  
1773    return std::nullopt;
1774  }
1775  
1776  // Returns the literal text in `SourceRange SR`, if `SR` is a valid range.
getRangeText(SourceRange SR,const SourceManager & SM,const LangOptions & LangOpts)1777  static std::optional<StringRef> getRangeText(SourceRange SR,
1778                                               const SourceManager &SM,
1779                                               const LangOptions &LangOpts) {
1780    bool Invalid = false;
1781    CharSourceRange CSR = CharSourceRange::getCharRange(SR);
1782    StringRef Text = Lexer::getSourceText(CSR, SM, LangOpts, &Invalid);
1783  
1784    if (!Invalid)
1785      return Text;
1786    return std::nullopt;
1787  }
1788  
1789  // Returns the begin location of the identifier of the given variable
1790  // declaration.
getVarDeclIdentifierLoc(const VarDecl * VD)1791  static SourceLocation getVarDeclIdentifierLoc(const VarDecl *VD) {
1792    // According to the implementation of `VarDecl`, `VD->getLocation()` actually
1793    // returns the begin location of the identifier of the declaration:
1794    return VD->getLocation();
1795  }
1796  
1797  // Returns the literal text of the identifier of the given variable declaration.
1798  static std::optional<StringRef>
getVarDeclIdentifierText(const VarDecl * VD,const SourceManager & SM,const LangOptions & LangOpts)1799  getVarDeclIdentifierText(const VarDecl *VD, const SourceManager &SM,
1800                           const LangOptions &LangOpts) {
1801    SourceLocation ParmIdentBeginLoc = getVarDeclIdentifierLoc(VD);
1802    SourceLocation ParmIdentEndLoc =
1803        Lexer::getLocForEndOfToken(ParmIdentBeginLoc, 0, SM, LangOpts);
1804  
1805    if (ParmIdentEndLoc.isMacroID() &&
1806        !Lexer::isAtEndOfMacroExpansion(ParmIdentEndLoc, SM, LangOpts))
1807      return std::nullopt;
1808    return getRangeText({ParmIdentBeginLoc, ParmIdentEndLoc}, SM, LangOpts);
1809  }
1810  
1811  // We cannot fix a variable declaration if it has some other specifiers than the
1812  // type specifier.  Because the source ranges of those specifiers could overlap
1813  // with the source range that is being replaced using fix-its.  Especially when
1814  // we often cannot obtain accurate source ranges of cv-qualified type
1815  // specifiers.
1816  // FIXME: also deal with type attributes
hasUnsupportedSpecifiers(const VarDecl * VD,const SourceManager & SM)1817  static bool hasUnsupportedSpecifiers(const VarDecl *VD,
1818                                       const SourceManager &SM) {
1819    // AttrRangeOverlapping: true if at least one attribute of `VD` overlaps the
1820    // source range of `VD`:
1821    bool AttrRangeOverlapping = llvm::any_of(VD->attrs(), [&](Attr *At) -> bool {
1822      return !(SM.isBeforeInTranslationUnit(At->getRange().getEnd(),
1823                                            VD->getBeginLoc())) &&
1824             !(SM.isBeforeInTranslationUnit(VD->getEndLoc(),
1825                                            At->getRange().getBegin()));
1826    });
1827    return VD->isInlineSpecified() || VD->isConstexpr() ||
1828           VD->hasConstantInitialization() || !VD->hasLocalStorage() ||
1829           AttrRangeOverlapping;
1830  }
1831  
1832  // Returns the `SourceRange` of `D`.  The reason why this function exists is
1833  // that `D->getSourceRange()` may return a range where the end location is the
1834  // starting location of the last token.  The end location of the source range
1835  // returned by this function is the last location of the last token.
getSourceRangeToTokenEnd(const Decl * D,const SourceManager & SM,const LangOptions & LangOpts)1836  static SourceRange getSourceRangeToTokenEnd(const Decl *D,
1837                                              const SourceManager &SM,
1838                                              const LangOptions &LangOpts) {
1839    SourceLocation Begin = D->getBeginLoc();
1840    SourceLocation
1841        End = // `D->getEndLoc` should always return the starting location of the
1842        // last token, so we should get the end of the token
1843        Lexer::getLocForEndOfToken(D->getEndLoc(), 0, SM, LangOpts);
1844  
1845    return SourceRange(Begin, End);
1846  }
1847  
1848  // Returns the text of the pointee type of `T` from a `VarDecl` of a pointer
1849  // type. The text is obtained through from `TypeLoc`s.  Since `TypeLoc` does not
1850  // have source ranges of qualifiers ( The `QualifiedTypeLoc` looks hacky too me
1851  // :( ), `Qualifiers` of the pointee type is returned separately through the
1852  // output parameter `QualifiersToAppend`.
1853  static std::optional<std::string>
getPointeeTypeText(const VarDecl * VD,const SourceManager & SM,const LangOptions & LangOpts,std::optional<Qualifiers> * QualifiersToAppend)1854  getPointeeTypeText(const VarDecl *VD, const SourceManager &SM,
1855                     const LangOptions &LangOpts,
1856                     std::optional<Qualifiers> *QualifiersToAppend) {
1857    QualType Ty = VD->getType();
1858    QualType PteTy;
1859  
1860    assert(Ty->isPointerType() && !Ty->isFunctionPointerType() &&
1861           "Expecting a VarDecl of type of pointer to object type");
1862    PteTy = Ty->getPointeeType();
1863  
1864    TypeLoc TyLoc = VD->getTypeSourceInfo()->getTypeLoc().getUnqualifiedLoc();
1865    TypeLoc PteTyLoc;
1866  
1867    // We only deal with the cases that we know `TypeLoc::getNextTypeLoc` returns
1868    // the `TypeLoc` of the pointee type:
1869    switch (TyLoc.getTypeLocClass()) {
1870    case TypeLoc::ConstantArray:
1871    case TypeLoc::IncompleteArray:
1872    case TypeLoc::VariableArray:
1873    case TypeLoc::DependentSizedArray:
1874    case TypeLoc::Decayed:
1875      assert(isa<ParmVarDecl>(VD) && "An array type shall not be treated as a "
1876                                     "pointer type unless it decays.");
1877      PteTyLoc = TyLoc.getNextTypeLoc();
1878      break;
1879    case TypeLoc::Pointer:
1880      PteTyLoc = TyLoc.castAs<PointerTypeLoc>().getPointeeLoc();
1881      break;
1882    default:
1883      return std::nullopt;
1884    }
1885    if (PteTyLoc.isNull())
1886      // Sometimes we cannot get a useful `TypeLoc` for the pointee type, e.g.,
1887      // when the pointer type is `auto`.
1888      return std::nullopt;
1889  
1890    SourceLocation IdentLoc = getVarDeclIdentifierLoc(VD);
1891  
1892    if (!(IdentLoc.isValid() && PteTyLoc.getSourceRange().isValid())) {
1893      // We are expecting these locations to be valid. But in some cases, they are
1894      // not all valid. It is a Clang bug to me and we are not responsible for
1895      // fixing it.  So we will just give up for now when it happens.
1896      return std::nullopt;
1897    }
1898  
1899    // Note that TypeLoc.getEndLoc() returns the begin location of the last token:
1900    SourceLocation PteEndOfTokenLoc =
1901        Lexer::getLocForEndOfToken(PteTyLoc.getEndLoc(), 0, SM, LangOpts);
1902  
1903    if (!PteEndOfTokenLoc.isValid())
1904      // Sometimes we cannot get the end location of the pointee type, e.g., when
1905      // there are macros involved.
1906      return std::nullopt;
1907    if (!SM.isBeforeInTranslationUnit(PteEndOfTokenLoc, IdentLoc)) {
1908      // We only deal with the cases where the source text of the pointee type
1909      // appears on the left-hand side of the variable identifier completely,
1910      // including the following forms:
1911      // `T ident`,
1912      // `T ident[]`, where `T` is any type.
1913      // Examples of excluded cases are `T (*ident)[]` or `T ident[][n]`.
1914      return std::nullopt;
1915    }
1916    if (PteTy.hasQualifiers()) {
1917      // TypeLoc does not provide source ranges for qualifiers (it says it's
1918      // intentional but seems fishy to me), so we cannot get the full text
1919      // `PteTy` via source ranges.
1920      *QualifiersToAppend = PteTy.getQualifiers();
1921    }
1922    return getRangeText({PteTyLoc.getBeginLoc(), PteEndOfTokenLoc}, SM, LangOpts)
1923        ->str();
1924  }
1925  
1926  // Returns the text of the name (with qualifiers) of a `FunctionDecl`.
getFunNameText(const FunctionDecl * FD,const SourceManager & SM,const LangOptions & LangOpts)1927  static std::optional<StringRef> getFunNameText(const FunctionDecl *FD,
1928                                                 const SourceManager &SM,
1929                                                 const LangOptions &LangOpts) {
1930    SourceLocation BeginLoc = FD->getQualifier()
1931                                  ? FD->getQualifierLoc().getBeginLoc()
1932                                  : FD->getNameInfo().getBeginLoc();
1933    // Note that `FD->getNameInfo().getEndLoc()` returns the begin location of the
1934    // last token:
1935    SourceLocation EndLoc = Lexer::getLocForEndOfToken(
1936        FD->getNameInfo().getEndLoc(), 0, SM, LangOpts);
1937    SourceRange NameRange{BeginLoc, EndLoc};
1938  
1939    return getRangeText(NameRange, SM, LangOpts);
1940  }
1941  
1942  // Returns the text representing a `std::span` type where the element type is
1943  // represented by `EltTyText`.
1944  //
1945  // Note the optional parameter `Qualifiers`: one needs to pass qualifiers
1946  // explicitly if the element type needs to be qualified.
1947  static std::string
getSpanTypeText(StringRef EltTyText,std::optional<Qualifiers> Quals=std::nullopt)1948  getSpanTypeText(StringRef EltTyText,
1949                  std::optional<Qualifiers> Quals = std::nullopt) {
1950    const char *const SpanOpen = "std::span<";
1951  
1952    if (Quals)
1953      return SpanOpen + EltTyText.str() + ' ' + Quals->getAsString() + '>';
1954    return SpanOpen + EltTyText.str() + '>';
1955  }
1956  
1957  std::optional<FixItList>
getFixits(const FixitStrategy & s) const1958  DerefSimplePtrArithFixableGadget::getFixits(const FixitStrategy &s) const {
1959    const VarDecl *VD = dyn_cast<VarDecl>(BaseDeclRefExpr->getDecl());
1960  
1961    if (VD && s.lookup(VD) == FixitStrategy::Kind::Span) {
1962      ASTContext &Ctx = VD->getASTContext();
1963      // std::span can't represent elements before its begin()
1964      if (auto ConstVal = Offset->getIntegerConstantExpr(Ctx))
1965        if (ConstVal->isNegative())
1966          return std::nullopt;
1967  
1968      // note that the expr may (oddly) has multiple layers of parens
1969      // example:
1970      //   *((..(pointer + 123)..))
1971      // goal:
1972      //   pointer[123]
1973      // Fix-It:
1974      //   remove '*('
1975      //   replace ' + ' with '['
1976      //   replace ')' with ']'
1977  
1978      // example:
1979      //   *((..(123 + pointer)..))
1980      // goal:
1981      //   123[pointer]
1982      // Fix-It:
1983      //   remove '*('
1984      //   replace ' + ' with '['
1985      //   replace ')' with ']'
1986  
1987      const Expr *LHS = AddOp->getLHS(), *RHS = AddOp->getRHS();
1988      const SourceManager &SM = Ctx.getSourceManager();
1989      const LangOptions &LangOpts = Ctx.getLangOpts();
1990      CharSourceRange StarWithTrailWhitespace =
1991          clang::CharSourceRange::getCharRange(DerefOp->getOperatorLoc(),
1992                                               LHS->getBeginLoc());
1993  
1994      std::optional<SourceLocation> LHSLocation = getPastLoc(LHS, SM, LangOpts);
1995      if (!LHSLocation)
1996        return std::nullopt;
1997  
1998      CharSourceRange PlusWithSurroundingWhitespace =
1999          clang::CharSourceRange::getCharRange(*LHSLocation, RHS->getBeginLoc());
2000  
2001      std::optional<SourceLocation> AddOpLocation =
2002          getPastLoc(AddOp, SM, LangOpts);
2003      std::optional<SourceLocation> DerefOpLocation =
2004          getPastLoc(DerefOp, SM, LangOpts);
2005  
2006      if (!AddOpLocation || !DerefOpLocation)
2007        return std::nullopt;
2008  
2009      CharSourceRange ClosingParenWithPrecWhitespace =
2010          clang::CharSourceRange::getCharRange(*AddOpLocation, *DerefOpLocation);
2011  
2012      return FixItList{
2013          {FixItHint::CreateRemoval(StarWithTrailWhitespace),
2014           FixItHint::CreateReplacement(PlusWithSurroundingWhitespace, "["),
2015           FixItHint::CreateReplacement(ClosingParenWithPrecWhitespace, "]")}};
2016    }
2017    return std::nullopt; // something wrong or unsupported, give up
2018  }
2019  
2020  std::optional<FixItList>
getFixits(const FixitStrategy & S) const2021  PointerDereferenceGadget::getFixits(const FixitStrategy &S) const {
2022    const VarDecl *VD = cast<VarDecl>(BaseDeclRefExpr->getDecl());
2023    switch (S.lookup(VD)) {
2024    case FixitStrategy::Kind::Span: {
2025      ASTContext &Ctx = VD->getASTContext();
2026      SourceManager &SM = Ctx.getSourceManager();
2027      // Required changes: *(ptr); => (ptr[0]); and *ptr; => ptr[0]
2028      // Deletes the *operand
2029      CharSourceRange derefRange = clang::CharSourceRange::getCharRange(
2030          Op->getBeginLoc(), Op->getBeginLoc().getLocWithOffset(1));
2031      // Inserts the [0]
2032      if (auto LocPastOperand =
2033              getPastLoc(BaseDeclRefExpr, SM, Ctx.getLangOpts())) {
2034        return FixItList{{FixItHint::CreateRemoval(derefRange),
2035                          FixItHint::CreateInsertion(*LocPastOperand, "[0]")}};
2036      }
2037      break;
2038    }
2039    case FixitStrategy::Kind::Iterator:
2040    case FixitStrategy::Kind::Array:
2041      return std::nullopt;
2042    case FixitStrategy::Kind::Vector:
2043      llvm_unreachable("FixitStrategy not implemented yet!");
2044    case FixitStrategy::Kind::Wontfix:
2045      llvm_unreachable("Invalid strategy!");
2046    }
2047  
2048    return std::nullopt;
2049  }
2050  
createDataFixit(const ASTContext & Ctx,const DeclRefExpr * DRE)2051  static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx,
2052                                                         const DeclRefExpr *DRE) {
2053    const SourceManager &SM = Ctx.getSourceManager();
2054    // Inserts the .data() after the DRE
2055    std::optional<SourceLocation> EndOfOperand =
2056        getPastLoc(DRE, SM, Ctx.getLangOpts());
2057  
2058    if (EndOfOperand)
2059      return FixItList{{FixItHint::CreateInsertion(*EndOfOperand, ".data()")}};
2060  
2061    return std::nullopt;
2062  }
2063  
2064  // Generates fix-its replacing an expression of the form UPC(DRE) with
2065  // `DRE.data()`
2066  std::optional<FixItList>
getFixits(const FixitStrategy & S) const2067  UPCStandalonePointerGadget::getFixits(const FixitStrategy &S) const {
2068    const auto VD = cast<VarDecl>(Node->getDecl());
2069    switch (S.lookup(VD)) {
2070    case FixitStrategy::Kind::Array:
2071    case FixitStrategy::Kind::Span: {
2072      return createDataFixit(VD->getASTContext(), Node);
2073      // FIXME: Points inside a macro expansion.
2074      break;
2075    }
2076    case FixitStrategy::Kind::Wontfix:
2077    case FixitStrategy::Kind::Iterator:
2078      return std::nullopt;
2079    case FixitStrategy::Kind::Vector:
2080      llvm_unreachable("unsupported strategies for FixableGadgets");
2081    }
2082  
2083    return std::nullopt;
2084  }
2085  
2086  // Generates fix-its replacing an expression of the form `&DRE[e]` with
2087  // `&DRE.data()[e]`:
2088  static std::optional<FixItList>
fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator * Node)2089  fixUPCAddressofArraySubscriptWithSpan(const UnaryOperator *Node) {
2090    const auto *ArraySub = cast<ArraySubscriptExpr>(Node->getSubExpr());
2091    const auto *DRE = cast<DeclRefExpr>(ArraySub->getBase()->IgnoreImpCasts());
2092    // FIXME: this `getASTContext` call is costly, we should pass the
2093    // ASTContext in:
2094    const ASTContext &Ctx = DRE->getDecl()->getASTContext();
2095    const Expr *Idx = ArraySub->getIdx();
2096    const SourceManager &SM = Ctx.getSourceManager();
2097    const LangOptions &LangOpts = Ctx.getLangOpts();
2098    std::stringstream SS;
2099    bool IdxIsLitZero = false;
2100  
2101    if (auto ICE = Idx->getIntegerConstantExpr(Ctx))
2102      if ((*ICE).isZero())
2103        IdxIsLitZero = true;
2104    std::optional<StringRef> DreString = getExprText(DRE, SM, LangOpts);
2105    if (!DreString)
2106      return std::nullopt;
2107  
2108    if (IdxIsLitZero) {
2109      // If the index is literal zero, we produce the most concise fix-it:
2110      SS << (*DreString).str() << ".data()";
2111    } else {
2112      std::optional<StringRef> IndexString = getExprText(Idx, SM, LangOpts);
2113      if (!IndexString)
2114        return std::nullopt;
2115  
2116      SS << "&" << (*DreString).str() << ".data()"
2117         << "[" << (*IndexString).str() << "]";
2118    }
2119    return FixItList{
2120        FixItHint::CreateReplacement(Node->getSourceRange(), SS.str())};
2121  }
2122  
2123  std::optional<FixItList>
getFixits(const FixitStrategy & S) const2124  UUCAddAssignGadget::getFixits(const FixitStrategy &S) const {
2125    DeclUseList DREs = getClaimedVarUseSites();
2126  
2127    if (DREs.size() != 1)
2128      return std::nullopt; // In cases of `Ptr += n` where `Ptr` is not a DRE, we
2129                           // give up
2130    if (const VarDecl *VD = dyn_cast<VarDecl>(DREs.front()->getDecl())) {
2131      if (S.lookup(VD) == FixitStrategy::Kind::Span) {
2132        FixItList Fixes;
2133  
2134        const Stmt *AddAssignNode = Node;
2135        StringRef varName = VD->getName();
2136        const ASTContext &Ctx = VD->getASTContext();
2137  
2138        if (!isNonNegativeIntegerExpr(Offset, VD, Ctx))
2139          return std::nullopt;
2140  
2141        // To transform UUC(p += n) to UUC(p = p.subspan(..)):
2142        bool NotParenExpr =
2143            (Offset->IgnoreParens()->getBeginLoc() == Offset->getBeginLoc());
2144        std::string SS = varName.str() + " = " + varName.str() + ".subspan";
2145        if (NotParenExpr)
2146          SS += "(";
2147  
2148        std::optional<SourceLocation> AddAssignLocation = getEndCharLoc(
2149            AddAssignNode, Ctx.getSourceManager(), Ctx.getLangOpts());
2150        if (!AddAssignLocation)
2151          return std::nullopt;
2152  
2153        Fixes.push_back(FixItHint::CreateReplacement(
2154            SourceRange(AddAssignNode->getBeginLoc(), Node->getOperatorLoc()),
2155            SS));
2156        if (NotParenExpr)
2157          Fixes.push_back(FixItHint::CreateInsertion(
2158              Offset->getEndLoc().getLocWithOffset(1), ")"));
2159        return Fixes;
2160      }
2161    }
2162    return std::nullopt; // Not in the cases that we can handle for now, give up.
2163  }
2164  
2165  std::optional<FixItList>
getFixits(const FixitStrategy & S) const2166  UPCPreIncrementGadget::getFixits(const FixitStrategy &S) const {
2167    DeclUseList DREs = getClaimedVarUseSites();
2168  
2169    if (DREs.size() != 1)
2170      return std::nullopt; // In cases of `++Ptr` where `Ptr` is not a DRE, we
2171                           // give up
2172    if (const VarDecl *VD = dyn_cast<VarDecl>(DREs.front()->getDecl())) {
2173      if (S.lookup(VD) == FixitStrategy::Kind::Span) {
2174        FixItList Fixes;
2175        std::stringstream SS;
2176        StringRef varName = VD->getName();
2177        const ASTContext &Ctx = VD->getASTContext();
2178  
2179        // To transform UPC(++p) to UPC((p = p.subspan(1)).data()):
2180        SS << "(" << varName.data() << " = " << varName.data()
2181           << ".subspan(1)).data()";
2182        std::optional<SourceLocation> PreIncLocation =
2183            getEndCharLoc(Node, Ctx.getSourceManager(), Ctx.getLangOpts());
2184        if (!PreIncLocation)
2185          return std::nullopt;
2186  
2187        Fixes.push_back(FixItHint::CreateReplacement(
2188            SourceRange(Node->getBeginLoc(), *PreIncLocation), SS.str()));
2189        return Fixes;
2190      }
2191    }
2192    return std::nullopt; // Not in the cases that we can handle for now, give up.
2193  }
2194  
2195  // For a non-null initializer `Init` of `T *` type, this function returns
2196  // `FixItHint`s producing a list initializer `{Init,  S}` as a part of a fix-it
2197  // to output stream.
2198  // In many cases, this function cannot figure out the actual extent `S`.  It
2199  // then will use a place holder to replace `S` to ask users to fill `S` in.  The
2200  // initializer shall be used to initialize a variable of type `std::span<T>`.
2201  // In some cases (e. g. constant size array) the initializer should remain
2202  // unchanged and the function returns empty list. In case the function can't
2203  // provide the right fixit it will return nullopt.
2204  //
2205  // FIXME: Support multi-level pointers
2206  //
2207  // Parameters:
2208  //   `Init` a pointer to the initializer expression
2209  //   `Ctx` a reference to the ASTContext
2210  static std::optional<FixItList>
FixVarInitializerWithSpan(const Expr * Init,ASTContext & Ctx,const StringRef UserFillPlaceHolder)2211  FixVarInitializerWithSpan(const Expr *Init, ASTContext &Ctx,
2212                            const StringRef UserFillPlaceHolder) {
2213    const SourceManager &SM = Ctx.getSourceManager();
2214    const LangOptions &LangOpts = Ctx.getLangOpts();
2215  
2216    // If `Init` has a constant value that is (or equivalent to) a
2217    // NULL pointer, we use the default constructor to initialize the span
2218    // object, i.e., a `std:span` variable declaration with no initializer.
2219    // So the fix-it is just to remove the initializer.
2220    if (Init->isNullPointerConstant(
2221            Ctx,
2222            // FIXME: Why does this function not ask for `const ASTContext
2223            // &`? It should. Maybe worth an NFC patch later.
2224            Expr::NullPointerConstantValueDependence::
2225                NPC_ValueDependentIsNotNull)) {
2226      std::optional<SourceLocation> InitLocation =
2227          getEndCharLoc(Init, SM, LangOpts);
2228      if (!InitLocation)
2229        return std::nullopt;
2230  
2231      SourceRange SR(Init->getBeginLoc(), *InitLocation);
2232  
2233      return FixItList{FixItHint::CreateRemoval(SR)};
2234    }
2235  
2236    FixItList FixIts{};
2237    std::string ExtentText = UserFillPlaceHolder.data();
2238    StringRef One = "1";
2239  
2240    // Insert `{` before `Init`:
2241    FixIts.push_back(FixItHint::CreateInsertion(Init->getBeginLoc(), "{"));
2242    // Try to get the data extent. Break into different cases:
2243    if (auto CxxNew = dyn_cast<CXXNewExpr>(Init->IgnoreImpCasts())) {
2244      // In cases `Init` is `new T[n]` and there is no explicit cast over
2245      // `Init`, we know that `Init` must evaluates to a pointer to `n` objects
2246      // of `T`. So the extent is `n` unless `n` has side effects.  Similar but
2247      // simpler for the case where `Init` is `new T`.
2248      if (const Expr *Ext = CxxNew->getArraySize().value_or(nullptr)) {
2249        if (!Ext->HasSideEffects(Ctx)) {
2250          std::optional<StringRef> ExtentString = getExprText(Ext, SM, LangOpts);
2251          if (!ExtentString)
2252            return std::nullopt;
2253          ExtentText = *ExtentString;
2254        }
2255      } else if (!CxxNew->isArray())
2256        // Although the initializer is not allocating a buffer, the pointer
2257        // variable could still be used in buffer access operations.
2258        ExtentText = One;
2259    } else if (Ctx.getAsConstantArrayType(Init->IgnoreImpCasts()->getType())) {
2260      // std::span has a single parameter constructor for initialization with
2261      // constant size array. The size is auto-deduced as the constructor is a
2262      // function template. The correct fixit is empty - no changes should happen.
2263      return FixItList{};
2264    } else {
2265      // In cases `Init` is of the form `&Var` after stripping of implicit
2266      // casts, where `&` is the built-in operator, the extent is 1.
2267      if (auto AddrOfExpr = dyn_cast<UnaryOperator>(Init->IgnoreImpCasts()))
2268        if (AddrOfExpr->getOpcode() == UnaryOperatorKind::UO_AddrOf &&
2269            isa_and_present<DeclRefExpr>(AddrOfExpr->getSubExpr()))
2270          ExtentText = One;
2271      // TODO: we can handle more cases, e.g., `&a[0]`, `&a`, `std::addressof`,
2272      // and explicit casting, etc. etc.
2273    }
2274  
2275    SmallString<32> StrBuffer{};
2276    std::optional<SourceLocation> LocPassInit = getPastLoc(Init, SM, LangOpts);
2277  
2278    if (!LocPassInit)
2279      return std::nullopt;
2280  
2281    StrBuffer.append(", ");
2282    StrBuffer.append(ExtentText);
2283    StrBuffer.append("}");
2284    FixIts.push_back(FixItHint::CreateInsertion(*LocPassInit, StrBuffer.str()));
2285    return FixIts;
2286  }
2287  
// DEBUG_NOTE_DECL_FAIL(D, Msg): when assertions are enabled (NDEBUG is not
// defined), reports through a `Handler` object in the enclosing scope that no
// fix-it could be produced for declaration `D`; `Msg` is appended to the note.
// With NDEBUG defined, the macro expands to nothing.
#ifdef NDEBUG
#define DEBUG_NOTE_DECL_FAIL(D, Msg)
#else
#define DEBUG_NOTE_DECL_FAIL(D, Msg)                                           \
  Handler.addDebugNoteForVar((D), (D)->getBeginLoc(),                          \
                             "failed to produce fixit for declaration '" +     \
                                 (D)->getNameAsString() + "'" + (Msg))
#endif
2296  
2297  // For the given variable declaration with a pointer-to-T type, returns the text
2298  // `std::span<T>`.  If it is unable to generate the text, returns
2299  // `std::nullopt`.
2300  static std::optional<std::string>
createSpanTypeForVarDecl(const VarDecl * VD,const ASTContext & Ctx)2301  createSpanTypeForVarDecl(const VarDecl *VD, const ASTContext &Ctx) {
2302    assert(VD->getType()->isPointerType());
2303  
2304    std::optional<Qualifiers> PteTyQualifiers = std::nullopt;
2305    std::optional<std::string> PteTyText = getPointeeTypeText(
2306        VD, Ctx.getSourceManager(), Ctx.getLangOpts(), &PteTyQualifiers);
2307  
2308    if (!PteTyText)
2309      return std::nullopt;
2310  
2311    std::string SpanTyText = "std::span<";
2312  
2313    SpanTyText.append(*PteTyText);
2314    // Append qualifiers to span element type if any:
2315    if (PteTyQualifiers) {
2316      SpanTyText.append(" ");
2317      SpanTyText.append(PteTyQualifiers->getAsString());
2318    }
2319    SpanTyText.append(">");
2320    return SpanTyText;
2321  }
2322  
2323  // For a `VarDecl` of the form `T  * var (= Init)?`, this
2324  // function generates fix-its that
2325  //  1) replace `T * var` with `std::span<T> var`; and
2326  //  2) change `Init` accordingly to a span constructor, if it exists.
2327  //
2328  // FIXME: support Multi-level pointers
2329  //
2330  // Parameters:
2331  //   `D` a pointer the variable declaration node
2332  //   `Ctx` a reference to the ASTContext
2333  //   `UserFillPlaceHolder` the user-input placeholder text
2334  // Returns:
2335  //    the non-empty fix-it list, if fix-its are successfuly generated; empty
2336  //    list otherwise.
fixLocalVarDeclWithSpan(const VarDecl * D,ASTContext & Ctx,const StringRef UserFillPlaceHolder,UnsafeBufferUsageHandler & Handler)2337  static FixItList fixLocalVarDeclWithSpan(const VarDecl *D, ASTContext &Ctx,
2338                                           const StringRef UserFillPlaceHolder,
2339                                           UnsafeBufferUsageHandler &Handler) {
2340    if (hasUnsupportedSpecifiers(D, Ctx.getSourceManager()))
2341      return {};
2342  
2343    FixItList FixIts{};
2344    std::optional<std::string> SpanTyText = createSpanTypeForVarDecl(D, Ctx);
2345  
2346    if (!SpanTyText) {
2347      DEBUG_NOTE_DECL_FAIL(D, " : failed to generate 'std::span' type");
2348      return {};
2349    }
2350  
2351    // Will hold the text for `std::span<T> Ident`:
2352    std::stringstream SS;
2353  
2354    SS << *SpanTyText;
2355    // Fix the initializer if it exists:
2356    if (const Expr *Init = D->getInit()) {
2357      std::optional<FixItList> InitFixIts =
2358          FixVarInitializerWithSpan(Init, Ctx, UserFillPlaceHolder);
2359      if (!InitFixIts)
2360        return {};
2361      FixIts.insert(FixIts.end(), std::make_move_iterator(InitFixIts->begin()),
2362                    std::make_move_iterator(InitFixIts->end()));
2363    }
2364    // For declaration of the form `T * ident = init;`, we want to replace
2365    // `T * ` with `std::span<T>`.
2366    // We ignore CV-qualifiers so for `T * const ident;` we also want to replace
2367    // just `T *` with `std::span<T>`.
2368    const SourceLocation EndLocForReplacement = D->getTypeSpecEndLoc();
2369    if (!EndLocForReplacement.isValid()) {
2370      DEBUG_NOTE_DECL_FAIL(D, " : failed to locate the end of the declaration");
2371      return {};
2372    }
2373    // The only exception is that for `T *ident` we'll add a single space between
2374    // "std::span<T>" and "ident".
2375    // FIXME: The condition is false for identifiers expended from macros.
2376    if (EndLocForReplacement.getLocWithOffset(1) == getVarDeclIdentifierLoc(D))
2377      SS << " ";
2378  
2379    FixIts.push_back(FixItHint::CreateReplacement(
2380        SourceRange(D->getBeginLoc(), EndLocForReplacement), SS.str()));
2381    return FixIts;
2382  }
2383  
hasConflictingOverload(const FunctionDecl * FD)2384  static bool hasConflictingOverload(const FunctionDecl *FD) {
2385    return !FD->getDeclContext()->lookup(FD->getDeclName()).isSingleResult();
2386  }
2387  
2388  // For a `FunctionDecl`, whose `ParmVarDecl`s are being changed to have new
2389  // types, this function produces fix-its to make the change self-contained.  Let
2390  // 'F' be the entity defined by the original `FunctionDecl` and "NewF" be the
2391  // entity defined by the `FunctionDecl` after the change to the parameters.
2392  // Fix-its produced by this function are
2393  //   1. Add the `[[clang::unsafe_buffer_usage]]` attribute to each declaration
2394  //   of 'F';
2395  //   2. Create a declaration of "NewF" next to each declaration of `F`;
2396  //   3. Create a definition of "F" (as its' original definition is now belongs
2397  //      to "NewF") next to its original definition.  The body of the creating
2398  //      definition calls to "NewF".
2399  //
2400  // Example:
2401  //
2402  // void f(int *p);  // original declaration
2403  // void f(int *p) { // original definition
2404  //    p[5];
2405  // }
2406  //
2407  // To change the parameter `p` to be of `std::span<int>` type, we
2408  // also add overloads:
2409  //
2410  // [[clang::unsafe_buffer_usage]] void f(int *p); // original decl
2411  // void f(std::span<int> p);                      // added overload decl
2412  // void f(std::span<int> p) {     // original def where param is changed
2413  //    p[5];
2414  // }
2415  // [[clang::unsafe_buffer_usage]] void f(int *p) {  // added def
2416  //   return f(std::span(p, <# size #>));
2417  // }
2418  //
2419  static std::optional<FixItList>
createOverloadsForFixedParams(const FixitStrategy & S,const FunctionDecl * FD,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2420  createOverloadsForFixedParams(const FixitStrategy &S, const FunctionDecl *FD,
2421                                const ASTContext &Ctx,
2422                                UnsafeBufferUsageHandler &Handler) {
2423    // FIXME: need to make this conflict checking better:
2424    if (hasConflictingOverload(FD))
2425      return std::nullopt;
2426  
2427    const SourceManager &SM = Ctx.getSourceManager();
2428    const LangOptions &LangOpts = Ctx.getLangOpts();
2429    const unsigned NumParms = FD->getNumParams();
2430    std::vector<std::string> NewTysTexts(NumParms);
2431    std::vector<bool> ParmsMask(NumParms, false);
2432    bool AtLeastOneParmToFix = false;
2433  
2434    for (unsigned i = 0; i < NumParms; i++) {
2435      const ParmVarDecl *PVD = FD->getParamDecl(i);
2436  
2437      if (S.lookup(PVD) == FixitStrategy::Kind::Wontfix)
2438        continue;
2439      if (S.lookup(PVD) != FixitStrategy::Kind::Span)
2440        // Not supported, not suppose to happen:
2441        return std::nullopt;
2442  
2443      std::optional<Qualifiers> PteTyQuals = std::nullopt;
2444      std::optional<std::string> PteTyText =
2445          getPointeeTypeText(PVD, SM, LangOpts, &PteTyQuals);
2446  
2447      if (!PteTyText)
2448        // something wrong in obtaining the text of the pointee type, give up
2449        return std::nullopt;
2450      // FIXME: whether we should create std::span type depends on the
2451      // FixitStrategy.
2452      NewTysTexts[i] = getSpanTypeText(*PteTyText, PteTyQuals);
2453      ParmsMask[i] = true;
2454      AtLeastOneParmToFix = true;
2455    }
2456    if (!AtLeastOneParmToFix)
2457      // No need to create function overloads:
2458      return {};
2459    // FIXME Respect indentation of the original code.
2460  
2461    // A lambda that creates the text representation of a function declaration
2462    // with the new type signatures:
2463    const auto NewOverloadSignatureCreator =
2464        [&SM, &LangOpts, &NewTysTexts,
2465         &ParmsMask](const FunctionDecl *FD) -> std::optional<std::string> {
2466      std::stringstream SS;
2467  
2468      SS << ";";
2469      SS << getEndOfLine().str();
2470      // Append: ret-type func-name "("
2471      if (auto Prefix = getRangeText(
2472              SourceRange(FD->getBeginLoc(), (*FD->param_begin())->getBeginLoc()),
2473              SM, LangOpts))
2474        SS << Prefix->str();
2475      else
2476        return std::nullopt; // give up
2477      // Append: parameter-type-list
2478      const unsigned NumParms = FD->getNumParams();
2479  
2480      for (unsigned i = 0; i < NumParms; i++) {
2481        const ParmVarDecl *Parm = FD->getParamDecl(i);
2482  
2483        if (Parm->isImplicit())
2484          continue;
2485        if (ParmsMask[i]) {
2486          // This `i`-th parameter will be fixed with `NewTysTexts[i]` being its
2487          // new type:
2488          SS << NewTysTexts[i];
2489          // print parameter name if provided:
2490          if (IdentifierInfo *II = Parm->getIdentifier())
2491            SS << ' ' << II->getName().str();
2492        } else if (auto ParmTypeText =
2493                       getRangeText(getSourceRangeToTokenEnd(Parm, SM, LangOpts),
2494                                    SM, LangOpts)) {
2495          // print the whole `Parm` without modification:
2496          SS << ParmTypeText->str();
2497        } else
2498          return std::nullopt; // something wrong, give up
2499        if (i != NumParms - 1)
2500          SS << ", ";
2501      }
2502      SS << ")";
2503      return SS.str();
2504    };
2505  
2506    // A lambda that creates the text representation of a function definition with
2507    // the original signature:
2508    const auto OldOverloadDefCreator =
2509        [&Handler, &SM, &LangOpts, &NewTysTexts,
2510         &ParmsMask](const FunctionDecl *FD) -> std::optional<std::string> {
2511      std::stringstream SS;
2512  
2513      SS << getEndOfLine().str();
2514      // Append: attr-name ret-type func-name "(" param-list ")" "{"
2515      if (auto FDPrefix = getRangeText(
2516              SourceRange(FD->getBeginLoc(), FD->getBody()->getBeginLoc()), SM,
2517              LangOpts))
2518        SS << Handler.getUnsafeBufferUsageAttributeTextAt(FD->getBeginLoc(), " ")
2519           << FDPrefix->str() << "{";
2520      else
2521        return std::nullopt;
2522      // Append: "return" func-name "("
2523      if (auto FunQualName = getFunNameText(FD, SM, LangOpts))
2524        SS << "return " << FunQualName->str() << "(";
2525      else
2526        return std::nullopt;
2527  
2528      // Append: arg-list
2529      const unsigned NumParms = FD->getNumParams();
2530      for (unsigned i = 0; i < NumParms; i++) {
2531        const ParmVarDecl *Parm = FD->getParamDecl(i);
2532  
2533        if (Parm->isImplicit())
2534          continue;
2535        // FIXME: If a parameter has no name, it is unused in the
2536        // definition. So we could just leave it as it is.
2537        if (!Parm->getIdentifier())
2538          // If a parameter of a function definition has no name:
2539          return std::nullopt;
2540        if (ParmsMask[i])
2541          // This is our spanified paramter!
2542          SS << NewTysTexts[i] << "(" << Parm->getIdentifier()->getName().str()
2543             << ", " << getUserFillPlaceHolder("size") << ")";
2544        else
2545          SS << Parm->getIdentifier()->getName().str();
2546        if (i != NumParms - 1)
2547          SS << ", ";
2548      }
2549      // finish call and the body
2550      SS << ");}" << getEndOfLine().str();
2551      // FIXME: 80-char line formatting?
2552      return SS.str();
2553    };
2554  
2555    FixItList FixIts{};
2556    for (FunctionDecl *FReDecl : FD->redecls()) {
2557      std::optional<SourceLocation> Loc = getPastLoc(FReDecl, SM, LangOpts);
2558  
2559      if (!Loc)
2560        return {};
2561      if (FReDecl->isThisDeclarationADefinition()) {
2562        assert(FReDecl == FD && "inconsistent function definition");
2563        // Inserts a definition with the old signature to the end of
2564        // `FReDecl`:
2565        if (auto OldOverloadDef = OldOverloadDefCreator(FReDecl))
2566          FixIts.emplace_back(FixItHint::CreateInsertion(*Loc, *OldOverloadDef));
2567        else
2568          return {}; // give up
2569      } else {
2570        // Adds the unsafe-buffer attribute (if not already there) to `FReDecl`:
2571        if (!FReDecl->hasAttr<UnsafeBufferUsageAttr>()) {
2572          FixIts.emplace_back(FixItHint::CreateInsertion(
2573              FReDecl->getBeginLoc(), Handler.getUnsafeBufferUsageAttributeTextAt(
2574                                          FReDecl->getBeginLoc(), " ")));
2575        }
2576        // Inserts a declaration with the new signature to the end of `FReDecl`:
2577        if (auto NewOverloadDecl = NewOverloadSignatureCreator(FReDecl))
2578          FixIts.emplace_back(FixItHint::CreateInsertion(*Loc, *NewOverloadDecl));
2579        else
2580          return {};
2581      }
2582    }
2583    return FixIts;
2584  }
2585  
2586  // To fix a `ParmVarDecl` to be of `std::span` type.
fixParamWithSpan(const ParmVarDecl * PVD,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2587  static FixItList fixParamWithSpan(const ParmVarDecl *PVD, const ASTContext &Ctx,
2588                                    UnsafeBufferUsageHandler &Handler) {
2589    if (hasUnsupportedSpecifiers(PVD, Ctx.getSourceManager())) {
2590      DEBUG_NOTE_DECL_FAIL(PVD, " : has unsupport specifier(s)");
2591      return {};
2592    }
2593    if (PVD->hasDefaultArg()) {
2594      // FIXME: generate fix-its for default values:
2595      DEBUG_NOTE_DECL_FAIL(PVD, " : has default arg");
2596      return {};
2597    }
2598  
2599    std::optional<Qualifiers> PteTyQualifiers = std::nullopt;
2600    std::optional<std::string> PteTyText = getPointeeTypeText(
2601        PVD, Ctx.getSourceManager(), Ctx.getLangOpts(), &PteTyQualifiers);
2602  
2603    if (!PteTyText) {
2604      DEBUG_NOTE_DECL_FAIL(PVD, " : invalid pointee type");
2605      return {};
2606    }
2607  
2608    std::optional<StringRef> PVDNameText = PVD->getIdentifier()->getName();
2609  
2610    if (!PVDNameText) {
2611      DEBUG_NOTE_DECL_FAIL(PVD, " : invalid identifier name");
2612      return {};
2613    }
2614  
2615    std::stringstream SS;
2616    std::optional<std::string> SpanTyText = createSpanTypeForVarDecl(PVD, Ctx);
2617  
2618    if (PteTyQualifiers)
2619      // Append qualifiers if they exist:
2620      SS << getSpanTypeText(*PteTyText, PteTyQualifiers);
2621    else
2622      SS << getSpanTypeText(*PteTyText);
2623    // Append qualifiers to the type of the parameter:
2624    if (PVD->getType().hasQualifiers())
2625      SS << ' ' << PVD->getType().getQualifiers().getAsString();
2626    // Append parameter's name:
2627    SS << ' ' << PVDNameText->str();
2628    // Add replacement fix-it:
2629    return {FixItHint::CreateReplacement(PVD->getSourceRange(), SS.str())};
2630  }
2631  
fixVariableWithSpan(const VarDecl * VD,const DeclUseTracker & Tracker,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2632  static FixItList fixVariableWithSpan(const VarDecl *VD,
2633                                       const DeclUseTracker &Tracker,
2634                                       ASTContext &Ctx,
2635                                       UnsafeBufferUsageHandler &Handler) {
2636    const DeclStmt *DS = Tracker.lookupDecl(VD);
2637    if (!DS) {
2638      DEBUG_NOTE_DECL_FAIL(VD,
2639                           " : variables declared this way not implemented yet");
2640      return {};
2641    }
2642    if (!DS->isSingleDecl()) {
2643      // FIXME: to support handling multiple `VarDecl`s in a single `DeclStmt`
2644      DEBUG_NOTE_DECL_FAIL(VD, " : multiple VarDecls");
2645      return {};
2646    }
2647    // Currently DS is an unused variable but we'll need it when
2648    // non-single decls are implemented, where the pointee type name
2649    // and the '*' are spread around the place.
2650    (void)DS;
2651  
2652    // FIXME: handle cases where DS has multiple declarations
2653    return fixLocalVarDeclWithSpan(VD, Ctx, getUserFillPlaceHolder(), Handler);
2654  }
2655  
fixVarDeclWithArray(const VarDecl * D,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2656  static FixItList fixVarDeclWithArray(const VarDecl *D, const ASTContext &Ctx,
2657                                       UnsafeBufferUsageHandler &Handler) {
2658    FixItList FixIts{};
2659  
2660    // Note: the code below expects the declaration to not use any type sugar like
2661    // typedef.
2662    if (auto CAT = dyn_cast<clang::ConstantArrayType>(D->getType())) {
2663      const QualType &ArrayEltT = CAT->getElementType();
2664      assert(!ArrayEltT.isNull() && "Trying to fix a non-array type variable!");
2665      // FIXME: support multi-dimensional arrays
2666      if (isa<clang::ArrayType>(ArrayEltT.getCanonicalType()))
2667        return {};
2668  
2669      const SourceLocation IdentifierLoc = getVarDeclIdentifierLoc(D);
2670  
2671      // Get the spelling of the element type as written in the source file
2672      // (including macros, etc.).
2673      auto MaybeElemTypeTxt =
2674          getRangeText({D->getBeginLoc(), IdentifierLoc}, Ctx.getSourceManager(),
2675                       Ctx.getLangOpts());
2676      if (!MaybeElemTypeTxt)
2677        return {};
2678      const llvm::StringRef ElemTypeTxt = MaybeElemTypeTxt->trim();
2679  
2680      // Find the '[' token.
2681      std::optional<Token> NextTok = Lexer::findNextToken(
2682          IdentifierLoc, Ctx.getSourceManager(), Ctx.getLangOpts());
2683      while (NextTok && !NextTok->is(tok::l_square) &&
2684             NextTok->getLocation() <= D->getSourceRange().getEnd())
2685        NextTok = Lexer::findNextToken(NextTok->getLocation(),
2686                                       Ctx.getSourceManager(), Ctx.getLangOpts());
2687      if (!NextTok)
2688        return {};
2689      const SourceLocation LSqBracketLoc = NextTok->getLocation();
2690  
2691      // Get the spelling of the array size as written in the source file
2692      // (including macros, etc.).
2693      auto MaybeArraySizeTxt = getRangeText(
2694          {LSqBracketLoc.getLocWithOffset(1), D->getTypeSpecEndLoc()},
2695          Ctx.getSourceManager(), Ctx.getLangOpts());
2696      if (!MaybeArraySizeTxt)
2697        return {};
2698      const llvm::StringRef ArraySizeTxt = MaybeArraySizeTxt->trim();
2699      if (ArraySizeTxt.empty()) {
2700        // FIXME: Support array size getting determined from the initializer.
2701        // Examples:
2702        //    int arr1[] = {0, 1, 2};
2703        //    int arr2{3, 4, 5};
2704        // We might be able to preserve the non-specified size with `auto` and
2705        // `std::to_array`:
2706        //    auto arr1 = std::to_array<int>({0, 1, 2});
2707        return {};
2708      }
2709  
2710      std::optional<StringRef> IdentText =
2711          getVarDeclIdentifierText(D, Ctx.getSourceManager(), Ctx.getLangOpts());
2712  
2713      if (!IdentText) {
2714        DEBUG_NOTE_DECL_FAIL(D, " : failed to locate the identifier");
2715        return {};
2716      }
2717  
2718      SmallString<32> Replacement;
2719      raw_svector_ostream OS(Replacement);
2720      OS << "std::array<" << ElemTypeTxt << ", " << ArraySizeTxt << "> "
2721         << IdentText->str();
2722  
2723      FixIts.push_back(FixItHint::CreateReplacement(
2724          SourceRange{D->getBeginLoc(), D->getTypeSpecEndLoc()}, OS.str()));
2725    }
2726  
2727    return FixIts;
2728  }
2729  
fixVariableWithArray(const VarDecl * VD,const DeclUseTracker & Tracker,const ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2730  static FixItList fixVariableWithArray(const VarDecl *VD,
2731                                        const DeclUseTracker &Tracker,
2732                                        const ASTContext &Ctx,
2733                                        UnsafeBufferUsageHandler &Handler) {
2734    const DeclStmt *DS = Tracker.lookupDecl(VD);
2735    assert(DS && "Fixing non-local variables not implemented yet!");
2736    if (!DS->isSingleDecl()) {
2737      // FIXME: to support handling multiple `VarDecl`s in a single `DeclStmt`
2738      return {};
2739    }
2740    // Currently DS is an unused variable but we'll need it when
2741    // non-single decls are implemented, where the pointee type name
2742    // and the '*' are spread around the place.
2743    (void)DS;
2744  
2745    // FIXME: handle cases where DS has multiple declarations
2746    return fixVarDeclWithArray(VD, Ctx, Handler);
2747  }
2748  
2749  // TODO: we should be consistent to use `std::nullopt` to represent no-fix due
2750  // to any unexpected problem.
2751  static FixItList
fixVariable(const VarDecl * VD,FixitStrategy::Kind K,const Decl * D,const DeclUseTracker & Tracker,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2752  fixVariable(const VarDecl *VD, FixitStrategy::Kind K,
2753              /* The function decl under analysis */ const Decl *D,
2754              const DeclUseTracker &Tracker, ASTContext &Ctx,
2755              UnsafeBufferUsageHandler &Handler) {
2756    if (const auto *PVD = dyn_cast<ParmVarDecl>(VD)) {
2757      auto *FD = dyn_cast<clang::FunctionDecl>(PVD->getDeclContext());
2758      if (!FD || FD != D) {
2759        // `FD != D` means that `PVD` belongs to a function that is not being
2760        // analyzed currently.  Thus `FD` may not be complete.
2761        DEBUG_NOTE_DECL_FAIL(VD, " : function not currently analyzed");
2762        return {};
2763      }
2764  
2765      // TODO If function has a try block we can't change params unless we check
2766      // also its catch block for their use.
2767      // FIXME We might support static class methods, some select methods,
2768      // operators and possibly lamdas.
2769      if (FD->isMain() || FD->isConstexpr() ||
2770          FD->getTemplatedKind() != FunctionDecl::TemplatedKind::TK_NonTemplate ||
2771          FD->isVariadic() ||
2772          // also covers call-operator of lamdas
2773          isa<CXXMethodDecl>(FD) ||
2774          // skip when the function body is a try-block
2775          (FD->hasBody() && isa<CXXTryStmt>(FD->getBody())) ||
2776          FD->isOverloadedOperator()) {
2777        DEBUG_NOTE_DECL_FAIL(VD, " : unsupported function decl");
2778        return {}; // TODO test all these cases
2779      }
2780    }
2781  
2782    switch (K) {
2783    case FixitStrategy::Kind::Span: {
2784      if (VD->getType()->isPointerType()) {
2785        if (const auto *PVD = dyn_cast<ParmVarDecl>(VD))
2786          return fixParamWithSpan(PVD, Ctx, Handler);
2787  
2788        if (VD->isLocalVarDecl())
2789          return fixVariableWithSpan(VD, Tracker, Ctx, Handler);
2790      }
2791      DEBUG_NOTE_DECL_FAIL(VD, " : not a pointer");
2792      return {};
2793    }
2794    case FixitStrategy::Kind::Array: {
2795      if (VD->isLocalVarDecl() &&
2796          isa<clang::ConstantArrayType>(VD->getType().getCanonicalType()))
2797        return fixVariableWithArray(VD, Tracker, Ctx, Handler);
2798  
2799      DEBUG_NOTE_DECL_FAIL(VD, " : not a local const-size array");
2800      return {};
2801    }
2802    case FixitStrategy::Kind::Iterator:
2803    case FixitStrategy::Kind::Vector:
2804      llvm_unreachable("FixitStrategy not implemented yet!");
2805    case FixitStrategy::Kind::Wontfix:
2806      llvm_unreachable("Invalid strategy!");
2807    }
2808    llvm_unreachable("Unknown strategy!");
2809  }
2810  
2811  // Returns true iff there exists a `FixItHint` 'h' in `FixIts` such that the
2812  // `RemoveRange` of 'h' overlaps with a macro use.
overlapWithMacro(const FixItList & FixIts)2813  static bool overlapWithMacro(const FixItList &FixIts) {
2814    // FIXME: For now we only check if the range (or the first token) is (part of)
2815    // a macro expansion.  Ideally, we want to check for all tokens in the range.
2816    return llvm::any_of(FixIts, [](const FixItHint &Hint) {
2817      auto Range = Hint.RemoveRange;
2818      if (Range.getBegin().isMacroID() || Range.getEnd().isMacroID())
2819        // If the range (or the first token) is (part of) a macro expansion:
2820        return true;
2821      return false;
2822    });
2823  }
2824  
2825  // Returns true iff `VD` is a parameter of the declaration `D`:
isParameterOf(const VarDecl * VD,const Decl * D)2826  static bool isParameterOf(const VarDecl *VD, const Decl *D) {
2827    return isa<ParmVarDecl>(VD) &&
2828           VD->getDeclContext() == dyn_cast<DeclContext>(D);
2829  }
2830  
2831  // Erases variables in `FixItsForVariable`, if such a variable has an unfixable
2832  // group mate.  A variable `v` is unfixable iff `FixItsForVariable` does not
2833  // contain `v`.
eraseVarsForUnfixableGroupMates(std::map<const VarDecl *,FixItList> & FixItsForVariable,const VariableGroupsManager & VarGrpMgr)2834  static void eraseVarsForUnfixableGroupMates(
2835      std::map<const VarDecl *, FixItList> &FixItsForVariable,
2836      const VariableGroupsManager &VarGrpMgr) {
2837    // Variables will be removed from `FixItsForVariable`:
2838    SmallVector<const VarDecl *, 8> ToErase;
2839  
2840    for (const auto &[VD, Ignore] : FixItsForVariable) {
2841      VarGrpRef Grp = VarGrpMgr.getGroupOfVar(VD);
2842      if (llvm::any_of(Grp,
2843                       [&FixItsForVariable](const VarDecl *GrpMember) -> bool {
2844                         return !FixItsForVariable.count(GrpMember);
2845                       })) {
2846        // At least one group member cannot be fixed, so we have to erase the
2847        // whole group:
2848        for (const VarDecl *Member : Grp)
2849          ToErase.push_back(Member);
2850      }
2851    }
2852    for (auto *VarToErase : ToErase)
2853      FixItsForVariable.erase(VarToErase);
2854  }
2855  
2856  // Returns the fix-its that create bounds-safe function overloads for the
2857  // function `D`, if `D`'s parameters will be changed to safe-types through
2858  // fix-its in `FixItsForVariable`.
2859  //
2860  // NOTE: In case `D`'s parameters will be changed but bounds-safe function
2861  // overloads cannot created, the whole group that contains the parameters will
2862  // be erased from `FixItsForVariable`.
createFunctionOverloadsForParms(std::map<const VarDecl *,FixItList> & FixItsForVariable,const VariableGroupsManager & VarGrpMgr,const FunctionDecl * FD,const FixitStrategy & S,ASTContext & Ctx,UnsafeBufferUsageHandler & Handler)2863  static FixItList createFunctionOverloadsForParms(
2864      std::map<const VarDecl *, FixItList> &FixItsForVariable /* mutable */,
2865      const VariableGroupsManager &VarGrpMgr, const FunctionDecl *FD,
2866      const FixitStrategy &S, ASTContext &Ctx,
2867      UnsafeBufferUsageHandler &Handler) {
2868    FixItList FixItsSharedByParms{};
2869  
2870    std::optional<FixItList> OverloadFixes =
2871        createOverloadsForFixedParams(S, FD, Ctx, Handler);
2872  
2873    if (OverloadFixes) {
2874      FixItsSharedByParms.append(*OverloadFixes);
2875    } else {
2876      // Something wrong in generating `OverloadFixes`, need to remove the
2877      // whole group, where parameters are in, from `FixItsForVariable` (Note
2878      // that all parameters should be in the same group):
2879      for (auto *Member : VarGrpMgr.getGroupOfParms())
2880        FixItsForVariable.erase(Member);
2881    }
2882    return FixItsSharedByParms;
2883  }
2884  
2885  // Constructs self-contained fix-its for each variable in `FixablesForAllVars`.
2886  static std::map<const VarDecl *, FixItList>
getFixIts(FixableGadgetSets & FixablesForAllVars,const FixitStrategy & S,ASTContext & Ctx,const Decl * D,const DeclUseTracker & Tracker,UnsafeBufferUsageHandler & Handler,const VariableGroupsManager & VarGrpMgr)2887  getFixIts(FixableGadgetSets &FixablesForAllVars, const FixitStrategy &S,
2888            ASTContext &Ctx,
2889            /* The function decl under analysis */ const Decl *D,
2890            const DeclUseTracker &Tracker, UnsafeBufferUsageHandler &Handler,
2891            const VariableGroupsManager &VarGrpMgr) {
2892    // `FixItsForVariable` will map each variable to a set of fix-its directly
2893    // associated to the variable itself.  Fix-its of distinct variables in
2894    // `FixItsForVariable` are disjoint.
2895    std::map<const VarDecl *, FixItList> FixItsForVariable;
2896  
2897    // Populate `FixItsForVariable` with fix-its directly associated with each
2898    // variable.  Fix-its directly associated to a variable 'v' are the ones
2899    // produced by the `FixableGadget`s whose claimed variable is 'v'.
2900    for (const auto &[VD, Fixables] : FixablesForAllVars.byVar) {
2901      FixItsForVariable[VD] =
2902          fixVariable(VD, S.lookup(VD), D, Tracker, Ctx, Handler);
2903      // If we fail to produce Fix-It for the declaration we have to skip the
2904      // variable entirely.
2905      if (FixItsForVariable[VD].empty()) {
2906        FixItsForVariable.erase(VD);
2907        continue;
2908      }
2909      for (const auto &F : Fixables) {
2910        std::optional<FixItList> Fixits = F->getFixits(S);
2911  
2912        if (Fixits) {
2913          FixItsForVariable[VD].insert(FixItsForVariable[VD].end(),
2914                                       Fixits->begin(), Fixits->end());
2915          continue;
2916        }
2917  #ifndef NDEBUG
2918        Handler.addDebugNoteForVar(
2919            VD, F->getSourceLoc(),
2920            ("gadget '" + F->getDebugName() + "' refused to produce a fix")
2921                .str());
2922  #endif
2923        FixItsForVariable.erase(VD);
2924        break;
2925      }
2926    }
2927  
2928    // `FixItsForVariable` now contains only variables that can be
2929    // fixed. A variable can be fixed if its' declaration and all Fixables
2930    // associated to it can all be fixed.
2931  
2932    // To further remove from `FixItsForVariable` variables whose group mates
2933    // cannot be fixed...
2934    eraseVarsForUnfixableGroupMates(FixItsForVariable, VarGrpMgr);
2935    // Now `FixItsForVariable` gets further reduced: a variable is in
2936    // `FixItsForVariable` iff it can be fixed and all its group mates can be
2937    // fixed.
2938  
2939    // Fix-its of bounds-safe overloads of `D` are shared by parameters of `D`.
2940    // That is,  when fixing multiple parameters in one step,  these fix-its will
2941    // be applied only once (instead of being applied per parameter).
2942    FixItList FixItsSharedByParms{};
2943  
2944    if (auto *FD = dyn_cast<FunctionDecl>(D))
2945      FixItsSharedByParms = createFunctionOverloadsForParms(
2946          FixItsForVariable, VarGrpMgr, FD, S, Ctx, Handler);
2947  
2948    // The map that maps each variable `v` to fix-its for the whole group where
2949    // `v` is in:
2950    std::map<const VarDecl *, FixItList> FinalFixItsForVariable{
2951        FixItsForVariable};
2952  
2953    for (auto &[Var, Ignore] : FixItsForVariable) {
2954      bool AnyParm = false;
2955      const auto VarGroupForVD = VarGrpMgr.getGroupOfVar(Var, &AnyParm);
2956  
2957      for (const VarDecl *GrpMate : VarGroupForVD) {
2958        if (Var == GrpMate)
2959          continue;
2960        if (FixItsForVariable.count(GrpMate))
2961          FinalFixItsForVariable[Var].append(FixItsForVariable[GrpMate]);
2962      }
2963      if (AnyParm) {
2964        // This assertion should never fail.  Otherwise we have a bug.
2965        assert(!FixItsSharedByParms.empty() &&
2966               "Should not try to fix a parameter that does not belong to a "
2967               "FunctionDecl");
2968        FinalFixItsForVariable[Var].append(FixItsSharedByParms);
2969      }
2970    }
2971    // Fix-its that will be applied in one step shall NOT:
2972    // 1. overlap with macros or/and templates; or
2973    // 2. conflict with each other.
2974    // Otherwise, the fix-its will be dropped.
2975    for (auto Iter = FinalFixItsForVariable.begin();
2976         Iter != FinalFixItsForVariable.end();)
2977      if (overlapWithMacro(Iter->second) ||
2978          clang::internal::anyConflict(Iter->second, Ctx.getSourceManager())) {
2979        Iter = FinalFixItsForVariable.erase(Iter);
2980      } else
2981        Iter++;
2982    return FinalFixItsForVariable;
2983  }
2984  
2985  template <typename VarDeclIterTy>
2986  static FixitStrategy
getNaiveStrategy(llvm::iterator_range<VarDeclIterTy> UnsafeVars)2987  getNaiveStrategy(llvm::iterator_range<VarDeclIterTy> UnsafeVars) {
2988    FixitStrategy S;
2989    for (const VarDecl *VD : UnsafeVars) {
2990      if (isa<ConstantArrayType>(VD->getType().getCanonicalType()))
2991        S.set(VD, FixitStrategy::Kind::Array);
2992      else
2993        S.set(VD, FixitStrategy::Kind::Span);
2994    }
2995    return S;
2996  }
2997  
2998  //  Manages variable groups:
2999  class VariableGroupsManagerImpl : public VariableGroupsManager {
3000    const std::vector<VarGrpTy> Groups;
3001    const std::map<const VarDecl *, unsigned> &VarGrpMap;
3002    const llvm::SetVector<const VarDecl *> &GrpsUnionForParms;
3003  
3004  public:
VariableGroupsManagerImpl(const std::vector<VarGrpTy> & Groups,const std::map<const VarDecl *,unsigned> & VarGrpMap,const llvm::SetVector<const VarDecl * > & GrpsUnionForParms)3005    VariableGroupsManagerImpl(
3006        const std::vector<VarGrpTy> &Groups,
3007        const std::map<const VarDecl *, unsigned> &VarGrpMap,
3008        const llvm::SetVector<const VarDecl *> &GrpsUnionForParms)
3009        : Groups(Groups), VarGrpMap(VarGrpMap),
3010          GrpsUnionForParms(GrpsUnionForParms) {}
3011  
getGroupOfVar(const VarDecl * Var,bool * HasParm) const3012    VarGrpRef getGroupOfVar(const VarDecl *Var, bool *HasParm) const override {
3013      if (GrpsUnionForParms.contains(Var)) {
3014        if (HasParm)
3015          *HasParm = true;
3016        return GrpsUnionForParms.getArrayRef();
3017      }
3018      if (HasParm)
3019        *HasParm = false;
3020  
3021      auto It = VarGrpMap.find(Var);
3022  
3023      if (It == VarGrpMap.end())
3024        return std::nullopt;
3025      return Groups[It->second];
3026    }
3027  
getGroupOfParms() const3028    VarGrpRef getGroupOfParms() const override {
3029      return GrpsUnionForParms.getArrayRef();
3030    }
3031  };
3032  
checkUnsafeBufferUsage(const Decl * D,UnsafeBufferUsageHandler & Handler,bool EmitSuggestions)3033  void clang::checkUnsafeBufferUsage(const Decl *D,
3034                                     UnsafeBufferUsageHandler &Handler,
3035                                     bool EmitSuggestions) {
3036  #ifndef NDEBUG
3037    Handler.clearDebugNotes();
3038  #endif
3039  
3040    assert(D && D->getBody());
3041    // We do not want to visit a Lambda expression defined inside a method
3042    // independently. Instead, it should be visited along with the outer method.
3043    // FIXME: do we want to do the same thing for `BlockDecl`s?
3044    if (const auto *fd = dyn_cast<CXXMethodDecl>(D)) {
3045      if (fd->getParent()->isLambda() && fd->getParent()->isLocalClass())
3046        return;
3047    }
3048  
3049    // Do not emit fixit suggestions for functions declared in an
3050    // extern "C" block.
3051    if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
3052      for (FunctionDecl *FReDecl : FD->redecls()) {
3053        if (FReDecl->isExternC()) {
3054          EmitSuggestions = false;
3055          break;
3056        }
3057      }
3058    }
3059  
3060    WarningGadgetSets UnsafeOps;
3061    FixableGadgetSets FixablesForAllVars;
3062  
3063    auto [FixableGadgets, WarningGadgets, Tracker] =
3064        findGadgets(D, Handler, EmitSuggestions);
3065  
3066    if (!EmitSuggestions) {
3067      // Our job is very easy without suggestions. Just warn about
3068      // every problematic operation and consider it done. No need to deal
3069      // with fixable gadgets, no need to group operations by variable.
3070      for (const auto &G : WarningGadgets) {
3071        G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/false,
3072                                 D->getASTContext());
3073      }
3074  
3075      // This return guarantees that most of the machine doesn't run when
3076      // suggestions aren't requested.
3077      assert(FixableGadgets.size() == 0 &&
3078             "Fixable gadgets found but suggestions not requested!");
3079      return;
3080    }
3081  
3082    // If no `WarningGadget`s ever matched, there is no unsafe operations in the
3083    //  function under the analysis. No need to fix any Fixables.
3084    if (!WarningGadgets.empty()) {
3085      // Gadgets "claim" variables they're responsible for. Once this loop
3086      // finishes, the tracker will only track DREs that weren't claimed by any
3087      // gadgets, i.e. not understood by the analysis.
3088      for (const auto &G : FixableGadgets) {
3089        for (const auto *DRE : G->getClaimedVarUseSites()) {
3090          Tracker.claimUse(DRE);
3091        }
3092      }
3093    }
3094  
3095    // If no `WarningGadget`s ever matched, there is no unsafe operations in the
3096    // function under the analysis.  Thus, it early returns here as there is
3097    // nothing needs to be fixed.
3098    //
3099    // Note this claim is based on the assumption that there is no unsafe
3100    // variable whose declaration is invisible from the analyzing function.
3101    // Otherwise, we need to consider if the uses of those unsafe varuables needs
3102    // fix.
3103    // So far, we are not fixing any global variables or class members. And,
3104    // lambdas will be analyzed along with the enclosing function. So this early
3105    // return is correct for now.
3106    if (WarningGadgets.empty())
3107      return;
3108  
3109    UnsafeOps = groupWarningGadgetsByVar(std::move(WarningGadgets));
3110    FixablesForAllVars = groupFixablesByVar(std::move(FixableGadgets));
3111  
3112    std::map<const VarDecl *, FixItList> FixItsForVariableGroup;
3113  
3114    // Filter out non-local vars and vars with unclaimed DeclRefExpr-s.
3115    for (auto it = FixablesForAllVars.byVar.cbegin();
3116         it != FixablesForAllVars.byVar.cend();) {
3117      // FIXME: need to deal with global variables later
3118      if ((!it->first->isLocalVarDecl() && !isa<ParmVarDecl>(it->first))) {
3119  #ifndef NDEBUG
3120        Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
3121                                   ("failed to produce fixit for '" +
3122                                    it->first->getNameAsString() +
3123                                    "' : neither local nor a parameter"));
3124  #endif
3125        it = FixablesForAllVars.byVar.erase(it);
3126      } else if (it->first->getType().getCanonicalType()->isReferenceType()) {
3127  #ifndef NDEBUG
3128        Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
3129                                   ("failed to produce fixit for '" +
3130                                    it->first->getNameAsString() +
3131                                    "' : has a reference type"));
3132  #endif
3133        it = FixablesForAllVars.byVar.erase(it);
3134      } else if (Tracker.hasUnclaimedUses(it->first)) {
3135        it = FixablesForAllVars.byVar.erase(it);
3136      } else if (it->first->isInitCapture()) {
3137  #ifndef NDEBUG
3138        Handler.addDebugNoteForVar(it->first, it->first->getBeginLoc(),
3139                                   ("failed to produce fixit for '" +
3140                                    it->first->getNameAsString() +
3141                                    "' : init capture"));
3142  #endif
3143        it = FixablesForAllVars.byVar.erase(it);
3144      } else {
3145        ++it;
3146      }
3147    }
3148  
3149  #ifndef NDEBUG
3150    for (const auto &it : UnsafeOps.byVar) {
3151      const VarDecl *const UnsafeVD = it.first;
3152      auto UnclaimedDREs = Tracker.getUnclaimedUses(UnsafeVD);
3153      if (UnclaimedDREs.empty())
3154        continue;
3155      const auto UnfixedVDName = UnsafeVD->getNameAsString();
3156      for (const clang::DeclRefExpr *UnclaimedDRE : UnclaimedDREs) {
3157        std::string UnclaimedUseTrace =
3158            getDREAncestorString(UnclaimedDRE, D->getASTContext());
3159  
3160        Handler.addDebugNoteForVar(
3161            UnsafeVD, UnclaimedDRE->getBeginLoc(),
3162            ("failed to produce fixit for '" + UnfixedVDName +
3163             "' : has an unclaimed use\nThe unclaimed DRE trace: " +
3164             UnclaimedUseTrace));
3165      }
3166    }
3167  #endif
3168  
3169    // Fixpoint iteration for pointer assignments
3170    using DepMapTy = DenseMap<const VarDecl *, llvm::SetVector<const VarDecl *>>;
3171    DepMapTy DependenciesMap{};
3172    DepMapTy PtrAssignmentGraph{};
3173  
3174    for (auto it : FixablesForAllVars.byVar) {
3175      for (const FixableGadget *fixable : it.second) {
3176        std::optional<std::pair<const VarDecl *, const VarDecl *>> ImplPair =
3177            fixable->getStrategyImplications();
3178        if (ImplPair) {
3179          std::pair<const VarDecl *, const VarDecl *> Impl = std::move(*ImplPair);
3180          PtrAssignmentGraph[Impl.first].insert(Impl.second);
3181        }
3182      }
3183    }
3184  
3185    /*
3186     The following code does a BFS traversal of the `PtrAssignmentGraph`
3187     considering all unsafe vars as starting nodes and constructs an undirected
3188     graph `DependenciesMap`. Constructing the `DependenciesMap` in this manner
3189     elimiates all variables that are unreachable from any unsafe var. In other
3190     words, this removes all dependencies that don't include any unsafe variable
3191     and consequently don't need any fixit generation.
3192     Note: A careful reader would observe that the code traverses
3193     `PtrAssignmentGraph` using `CurrentVar` but adds edges between `Var` and
3194     `Adj` and not between `CurrentVar` and `Adj`. Both approaches would
3195     achieve the same result but the one used here dramatically cuts the
3196     amount of hoops the second part of the algorithm needs to jump, given that
3197     a lot of these connections become "direct". The reader is advised not to
3198     imagine how the graph is transformed because of using `Var` instead of
3199     `CurrentVar`. The reader can continue reading as if `CurrentVar` was used,
3200     and think about why it's equivalent later.
3201     */
3202    std::set<const VarDecl *> VisitedVarsDirected{};
3203    for (const auto &[Var, ignore] : UnsafeOps.byVar) {
3204      if (VisitedVarsDirected.find(Var) == VisitedVarsDirected.end()) {
3205  
3206        std::queue<const VarDecl *> QueueDirected{};
3207        QueueDirected.push(Var);
3208        while (!QueueDirected.empty()) {
3209          const VarDecl *CurrentVar = QueueDirected.front();
3210          QueueDirected.pop();
3211          VisitedVarsDirected.insert(CurrentVar);
3212          auto AdjacentNodes = PtrAssignmentGraph[CurrentVar];
3213          for (const VarDecl *Adj : AdjacentNodes) {
3214            if (VisitedVarsDirected.find(Adj) == VisitedVarsDirected.end()) {
3215              QueueDirected.push(Adj);
3216            }
3217            DependenciesMap[Var].insert(Adj);
3218            DependenciesMap[Adj].insert(Var);
3219          }
3220        }
3221      }
3222    }
3223  
3224    // `Groups` stores the set of Connected Components in the graph.
3225    std::vector<VarGrpTy> Groups;
3226    // `VarGrpMap` maps variables that need fix to the groups (indexes) that the
3227    // variables belong to.  Group indexes refer to the elements in `Groups`.
3228    // `VarGrpMap` is complete in that every variable that needs fix is in it.
3229    std::map<const VarDecl *, unsigned> VarGrpMap;
3230    // The union group over the ones in "Groups" that contain parameters of `D`:
3231    llvm::SetVector<const VarDecl *>
3232        GrpsUnionForParms; // these variables need to be fixed in one step
3233  
3234    // Group Connected Components for Unsafe Vars
3235    // (Dependencies based on pointer assignments)
3236    std::set<const VarDecl *> VisitedVars{};
3237    for (const auto &[Var, ignore] : UnsafeOps.byVar) {
3238      if (VisitedVars.find(Var) == VisitedVars.end()) {
3239        VarGrpTy &VarGroup = Groups.emplace_back();
3240        std::queue<const VarDecl *> Queue{};
3241  
3242        Queue.push(Var);
3243        while (!Queue.empty()) {
3244          const VarDecl *CurrentVar = Queue.front();
3245          Queue.pop();
3246          VisitedVars.insert(CurrentVar);
3247          VarGroup.push_back(CurrentVar);
3248          auto AdjacentNodes = DependenciesMap[CurrentVar];
3249          for (const VarDecl *Adj : AdjacentNodes) {
3250            if (VisitedVars.find(Adj) == VisitedVars.end()) {
3251              Queue.push(Adj);
3252            }
3253          }
3254        }
3255  
3256        bool HasParm = false;
3257        unsigned GrpIdx = Groups.size() - 1;
3258  
3259        for (const VarDecl *V : VarGroup) {
3260          VarGrpMap[V] = GrpIdx;
3261          if (!HasParm && isParameterOf(V, D))
3262            HasParm = true;
3263        }
3264        if (HasParm)
3265          GrpsUnionForParms.insert(VarGroup.begin(), VarGroup.end());
3266      }
3267    }
3268  
3269    // Remove a `FixableGadget` if the associated variable is not in the graph
3270    // computed above.  We do not want to generate fix-its for such variables,
3271    // since they are neither warned nor reachable from a warned one.
3272    //
3273    // Note a variable is not warned if it is not directly used in any unsafe
3274    // operation. A variable `v` is NOT reachable from an unsafe variable, if it
3275    // does not exist another variable `u` such that `u` is warned and fixing `u`
3276    // (transitively) implicates fixing `v`.
3277    //
3278    // For example,
3279    // ```
3280    // void f(int * p) {
3281    //   int * a = p; *p = 0;
3282    // }
3283    // ```
3284    // `*p = 0` is a fixable gadget associated with a variable `p` that is neither
3285    // warned nor reachable from a warned one.  If we add `a[5] = 0` to the end of
3286    // the function above, `p` becomes reachable from a warned variable.
3287    for (auto I = FixablesForAllVars.byVar.begin();
3288         I != FixablesForAllVars.byVar.end();) {
3289      // Note `VisitedVars` contain all the variables in the graph:
3290      if (!VisitedVars.count((*I).first)) {
3291        // no such var in graph:
3292        I = FixablesForAllVars.byVar.erase(I);
3293      } else
3294        ++I;
3295    }
3296  
3297    // We assign strategies to variables that are 1) in the graph and 2) can be
3298    // fixed. Other variables have the default "Won't fix" strategy.
3299    FixitStrategy NaiveStrategy = getNaiveStrategy(llvm::make_filter_range(
3300        VisitedVars, [&FixablesForAllVars](const VarDecl *V) {
3301          // If a warned variable has no "Fixable", it is considered unfixable:
3302          return FixablesForAllVars.byVar.count(V);
3303        }));
3304    VariableGroupsManagerImpl VarGrpMgr(Groups, VarGrpMap, GrpsUnionForParms);
3305  
3306    if (isa<NamedDecl>(D))
3307      // The only case where `D` is not a `NamedDecl` is when `D` is a
3308      // `BlockDecl`. Let's not fix variables in blocks for now
3309      FixItsForVariableGroup =
3310          getFixIts(FixablesForAllVars, NaiveStrategy, D->getASTContext(), D,
3311                    Tracker, Handler, VarGrpMgr);
3312  
3313    for (const auto &G : UnsafeOps.noVar) {
3314      G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/false,
3315                               D->getASTContext());
3316    }
3317  
3318    for (const auto &[VD, WarningGadgets] : UnsafeOps.byVar) {
3319      auto FixItsIt = FixItsForVariableGroup.find(VD);
3320      Handler.handleUnsafeVariableGroup(VD, VarGrpMgr,
3321                                        FixItsIt != FixItsForVariableGroup.end()
3322                                            ? std::move(FixItsIt->second)
3323                                            : FixItList{},
3324                                        D, NaiveStrategy);
3325      for (const auto &G : WarningGadgets) {
3326        G->handleUnsafeOperation(Handler, /*IsRelatedToDecl=*/true,
3327                                 D->getASTContext());
3328      }
3329    }
3330  }
3331