xref: /freebsd/contrib/llvm-project/clang/lib/Analysis/Consumed.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- Consumed.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// An intra-procedural analysis for checking consumed properties.  This is based,
// in part, on research on linear types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/Analysis/Analyses/Consumed.h"
15 #include "clang/AST/Attr.h"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/Expr.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/Stmt.h"
21 #include "clang/AST/StmtVisitor.h"
22 #include "clang/AST/Type.h"
23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
24 #include "clang/Analysis/AnalysisDeclContext.h"
25 #include "clang/Analysis/CFG.h"
26 #include "clang/Basic/LLVM.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include <cassert>
32 #include <memory>
33 #include <optional>
34 #include <utility>
35 
36 // TODO: Adjust states of args to constructors in the same way that arguments to
37 //       function calls are handled.
38 // TODO: Use information from tests in for- and while-loop conditional.
39 // TODO: Add notes about the actual and expected state for
40 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
41 // TODO: Adjust the parser and AttributesList class to support lists of
42 //       identifiers.
43 // TODO: Warn about unreachable code.
44 // TODO: Switch to using a bitmap to track unreachable blocks.
45 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
46 //       if (valid) ...; (Deferred)
47 // TODO: Take notes on state transitions to provide better warning messages.
48 //       (Deferred)
// TODO: Test nested conditionals: 1) checking the same value multiple times,
//       and 2) checking different values. (Deferred)
51 
52 using namespace clang;
53 using namespace consumed;
54 
55 // Key method definition
56 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() = default;
57 
getFirstStmtLoc(const CFGBlock * Block)58 static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {
59   // Find the source location of the first statement in the block, if the block
60   // is not empty.
61   for (const auto &B : *Block)
62     if (std::optional<CFGStmt> CS = B.getAs<CFGStmt>())
63       return CS->getStmt()->getBeginLoc();
64 
65   // Block is empty.
66   // If we have one successor, return the first statement in that block
67   if (Block->succ_size() == 1 && *Block->succ_begin())
68     return getFirstStmtLoc(*Block->succ_begin());
69 
70   return {};
71 }
72 
getLastStmtLoc(const CFGBlock * Block)73 static SourceLocation getLastStmtLoc(const CFGBlock *Block) {
74   // Find the source location of the last statement in the block, if the block
75   // is not empty.
76   if (const Stmt *StmtNode = Block->getTerminatorStmt()) {
77     return StmtNode->getBeginLoc();
78   } else {
79     for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),
80          BE = Block->rend(); BI != BE; ++BI) {
81       if (std::optional<CFGStmt> CS = BI->getAs<CFGStmt>())
82         return CS->getStmt()->getBeginLoc();
83     }
84   }
85 
86   // If we have one successor, return the first statement in that block
87   SourceLocation Loc;
88   if (Block->succ_size() == 1 && *Block->succ_begin())
89     Loc = getFirstStmtLoc(*Block->succ_begin());
90   if (Loc.isValid())
91     return Loc;
92 
93   // If we have one predecessor, return the last statement in that block
94   if (Block->pred_size() == 1 && *Block->pred_begin())
95     return getLastStmtLoc(*Block->pred_begin());
96 
97   return Loc;
98 }
99 
invertConsumedUnconsumed(ConsumedState State)100 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
101   switch (State) {
102   case CS_Unconsumed:
103     return CS_Consumed;
104   case CS_Consumed:
105     return CS_Unconsumed;
106   case CS_None:
107     return CS_None;
108   case CS_Unknown:
109     return CS_Unknown;
110   }
111   llvm_unreachable("invalid enum");
112 }
113 
isCallableInState(const CallableWhenAttr * CWAttr,ConsumedState State)114 static bool isCallableInState(const CallableWhenAttr *CWAttr,
115                               ConsumedState State) {
116   for (const auto &S : CWAttr->callableStates()) {
117     ConsumedState MappedAttrState = CS_None;
118 
119     switch (S) {
120     case CallableWhenAttr::Unknown:
121       MappedAttrState = CS_Unknown;
122       break;
123 
124     case CallableWhenAttr::Unconsumed:
125       MappedAttrState = CS_Unconsumed;
126       break;
127 
128     case CallableWhenAttr::Consumed:
129       MappedAttrState = CS_Consumed;
130       break;
131     }
132 
133     if (MappedAttrState == State)
134       return true;
135   }
136 
137   return false;
138 }
139 
isConsumableType(const QualType & QT)140 static bool isConsumableType(const QualType &QT) {
141   if (QT->isPointerOrReferenceType())
142     return false;
143 
144   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
145     return RD->hasAttr<ConsumableAttr>();
146 
147   return false;
148 }
149 
isAutoCastType(const QualType & QT)150 static bool isAutoCastType(const QualType &QT) {
151   if (QT->isPointerOrReferenceType())
152     return false;
153 
154   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
155     return RD->hasAttr<ConsumableAutoCastAttr>();
156 
157   return false;
158 }
159 
isSetOnReadPtrType(const QualType & QT)160 static bool isSetOnReadPtrType(const QualType &QT) {
161   if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())
162     return RD->hasAttr<ConsumableSetOnReadAttr>();
163   return false;
164 }
165 
isKnownState(ConsumedState State)166 static bool isKnownState(ConsumedState State) {
167   switch (State) {
168   case CS_Unconsumed:
169   case CS_Consumed:
170     return true;
171   case CS_None:
172   case CS_Unknown:
173     return false;
174   }
175   llvm_unreachable("invalid enum");
176 }
177 
isRValueRef(QualType ParamType)178 static bool isRValueRef(QualType ParamType) {
179   return ParamType->isRValueReferenceType();
180 }
181 
isTestingFunction(const FunctionDecl * FunDecl)182 static bool isTestingFunction(const FunctionDecl *FunDecl) {
183   return FunDecl->hasAttr<TestTypestateAttr>();
184 }
185 
mapConsumableAttrState(const QualType QT)186 static ConsumedState mapConsumableAttrState(const QualType QT) {
187   assert(isConsumableType(QT));
188 
189   const ConsumableAttr *CAttr =
190       QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
191 
192   switch (CAttr->getDefaultState()) {
193   case ConsumableAttr::Unknown:
194     return CS_Unknown;
195   case ConsumableAttr::Unconsumed:
196     return CS_Unconsumed;
197   case ConsumableAttr::Consumed:
198     return CS_Consumed;
199   }
200   llvm_unreachable("invalid enum");
201 }
202 
203 static ConsumedState
mapParamTypestateAttrState(const ParamTypestateAttr * PTAttr)204 mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {
205   switch (PTAttr->getParamState()) {
206   case ParamTypestateAttr::Unknown:
207     return CS_Unknown;
208   case ParamTypestateAttr::Unconsumed:
209     return CS_Unconsumed;
210   case ParamTypestateAttr::Consumed:
211     return CS_Consumed;
212   }
213   llvm_unreachable("invalid_enum");
214 }
215 
216 static ConsumedState
mapReturnTypestateAttrState(const ReturnTypestateAttr * RTSAttr)217 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
218   switch (RTSAttr->getState()) {
219   case ReturnTypestateAttr::Unknown:
220     return CS_Unknown;
221   case ReturnTypestateAttr::Unconsumed:
222     return CS_Unconsumed;
223   case ReturnTypestateAttr::Consumed:
224     return CS_Consumed;
225   }
226   llvm_unreachable("invalid enum");
227 }
228 
mapSetTypestateAttrState(const SetTypestateAttr * STAttr)229 static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {
230   switch (STAttr->getNewState()) {
231   case SetTypestateAttr::Unknown:
232     return CS_Unknown;
233   case SetTypestateAttr::Unconsumed:
234     return CS_Unconsumed;
235   case SetTypestateAttr::Consumed:
236     return CS_Consumed;
237   }
238   llvm_unreachable("invalid_enum");
239 }
240 
stateToString(ConsumedState State)241 static StringRef stateToString(ConsumedState State) {
242   switch (State) {
243   case consumed::CS_None:
244     return "none";
245 
246   case consumed::CS_Unknown:
247     return "unknown";
248 
249   case consumed::CS_Unconsumed:
250     return "unconsumed";
251 
252   case consumed::CS_Consumed:
253     return "consumed";
254   }
255   llvm_unreachable("invalid enum");
256 }
257 
testsFor(const FunctionDecl * FunDecl)258 static ConsumedState testsFor(const FunctionDecl *FunDecl) {
259   assert(isTestingFunction(FunDecl));
260   switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {
261   case TestTypestateAttr::Unconsumed:
262     return CS_Unconsumed;
263   case TestTypestateAttr::Consumed:
264     return CS_Consumed;
265   }
266   llvm_unreachable("invalid enum");
267 }
268 
269 namespace {
270 
271 struct VarTestResult {
272   const VarDecl *Var;
273   ConsumedState TestsFor;
274 };
275 
276 } // namespace
277 
278 namespace clang {
279 namespace consumed {
280 
281 enum EffectiveOp {
282   EO_And,
283   EO_Or
284 };
285 
286 class PropagationInfo {
287   enum {
288     IT_None,
289     IT_State,
290     IT_VarTest,
291     IT_BinTest,
292     IT_Var,
293     IT_Tmp
294   } InfoType = IT_None;
295 
296   struct BinTestTy {
297     const BinaryOperator *Source;
298     EffectiveOp EOp;
299     VarTestResult LTest;
300     VarTestResult RTest;
301   };
302 
303   union {
304     ConsumedState State;
305     VarTestResult VarTest;
306     const VarDecl *Var;
307     const CXXBindTemporaryExpr *Tmp;
308     BinTestTy BinTest;
309   };
310 
311 public:
312   PropagationInfo() = default;
PropagationInfo(const VarTestResult & VarTest)313   PropagationInfo(const VarTestResult &VarTest)
314       : InfoType(IT_VarTest), VarTest(VarTest) {}
315 
PropagationInfo(const VarDecl * Var,ConsumedState TestsFor)316   PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
317       : InfoType(IT_VarTest) {
318     VarTest.Var      = Var;
319     VarTest.TestsFor = TestsFor;
320   }
321 
PropagationInfo(const BinaryOperator * Source,EffectiveOp EOp,const VarTestResult & LTest,const VarTestResult & RTest)322   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
323                   const VarTestResult &LTest, const VarTestResult &RTest)
324       : InfoType(IT_BinTest) {
325     BinTest.Source  = Source;
326     BinTest.EOp     = EOp;
327     BinTest.LTest   = LTest;
328     BinTest.RTest   = RTest;
329   }
330 
PropagationInfo(const BinaryOperator * Source,EffectiveOp EOp,const VarDecl * LVar,ConsumedState LTestsFor,const VarDecl * RVar,ConsumedState RTestsFor)331   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
332                   const VarDecl *LVar, ConsumedState LTestsFor,
333                   const VarDecl *RVar, ConsumedState RTestsFor)
334       : InfoType(IT_BinTest) {
335     BinTest.Source         = Source;
336     BinTest.EOp            = EOp;
337     BinTest.LTest.Var      = LVar;
338     BinTest.LTest.TestsFor = LTestsFor;
339     BinTest.RTest.Var      = RVar;
340     BinTest.RTest.TestsFor = RTestsFor;
341   }
342 
PropagationInfo(ConsumedState State)343   PropagationInfo(ConsumedState State)
344       : InfoType(IT_State), State(State) {}
PropagationInfo(const VarDecl * Var)345   PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
PropagationInfo(const CXXBindTemporaryExpr * Tmp)346   PropagationInfo(const CXXBindTemporaryExpr *Tmp)
347       : InfoType(IT_Tmp), Tmp(Tmp) {}
348 
getState() const349   const ConsumedState &getState() const {
350     assert(InfoType == IT_State);
351     return State;
352   }
353 
getVarTest() const354   const VarTestResult &getVarTest() const {
355     assert(InfoType == IT_VarTest);
356     return VarTest;
357   }
358 
getLTest() const359   const VarTestResult &getLTest() const {
360     assert(InfoType == IT_BinTest);
361     return BinTest.LTest;
362   }
363 
getRTest() const364   const VarTestResult &getRTest() const {
365     assert(InfoType == IT_BinTest);
366     return BinTest.RTest;
367   }
368 
getVar() const369   const VarDecl *getVar() const {
370     assert(InfoType == IT_Var);
371     return Var;
372   }
373 
getTmp() const374   const CXXBindTemporaryExpr *getTmp() const {
375     assert(InfoType == IT_Tmp);
376     return Tmp;
377   }
378 
getAsState(const ConsumedStateMap * StateMap) const379   ConsumedState getAsState(const ConsumedStateMap *StateMap) const {
380     assert(isVar() || isTmp() || isState());
381 
382     if (isVar())
383       return StateMap->getState(Var);
384     else if (isTmp())
385       return StateMap->getState(Tmp);
386     else if (isState())
387       return State;
388     else
389       return CS_None;
390   }
391 
testEffectiveOp() const392   EffectiveOp testEffectiveOp() const {
393     assert(InfoType == IT_BinTest);
394     return BinTest.EOp;
395   }
396 
testSourceNode() const397   const BinaryOperator * testSourceNode() const {
398     assert(InfoType == IT_BinTest);
399     return BinTest.Source;
400   }
401 
isValid() const402   bool isValid() const { return InfoType != IT_None; }
isState() const403   bool isState() const { return InfoType == IT_State; }
isVarTest() const404   bool isVarTest() const { return InfoType == IT_VarTest; }
isBinTest() const405   bool isBinTest() const { return InfoType == IT_BinTest; }
isVar() const406   bool isVar() const { return InfoType == IT_Var; }
isTmp() const407   bool isTmp() const { return InfoType == IT_Tmp; }
408 
isTest() const409   bool isTest() const {
410     return InfoType == IT_VarTest || InfoType == IT_BinTest;
411   }
412 
isPointerToValue() const413   bool isPointerToValue() const {
414     return InfoType == IT_Var || InfoType == IT_Tmp;
415   }
416 
invertTest() const417   PropagationInfo invertTest() const {
418     assert(InfoType == IT_VarTest || InfoType == IT_BinTest);
419 
420     if (InfoType == IT_VarTest) {
421       return PropagationInfo(VarTest.Var,
422                              invertConsumedUnconsumed(VarTest.TestsFor));
423 
424     } else if (InfoType == IT_BinTest) {
425       return PropagationInfo(BinTest.Source,
426         BinTest.EOp == EO_And ? EO_Or : EO_And,
427         BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
428         BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
429     } else {
430       return {};
431     }
432   }
433 };
434 
435 } // namespace consumed
436 } // namespace clang
437 
438 static void
setStateForVarOrTmp(ConsumedStateMap * StateMap,const PropagationInfo & PInfo,ConsumedState State)439 setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,
440                     ConsumedState State) {
441   assert(PInfo.isVar() || PInfo.isTmp());
442 
443   if (PInfo.isVar())
444     StateMap->setState(PInfo.getVar(), State);
445   else
446     StateMap->setState(PInfo.getTmp(), State);
447 }
448 
449 namespace clang {
450 namespace consumed {
451 
452 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
453   using MapType = llvm::DenseMap<const Stmt *, PropagationInfo>;
454   using PairType= std::pair<const Stmt *, PropagationInfo>;
455   using InfoEntry = MapType::iterator;
456   using ConstInfoEntry = MapType::const_iterator;
457 
458   ConsumedAnalyzer &Analyzer;
459   ConsumedStateMap *StateMap;
460   MapType PropagationMap;
461 
findInfo(const Expr * E)462   InfoEntry findInfo(const Expr *E) {
463     if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
464       if (!Cleanups->cleanupsHaveSideEffects())
465         E = Cleanups->getSubExpr();
466     return PropagationMap.find(E->IgnoreParens());
467   }
468 
findInfo(const Expr * E) const469   ConstInfoEntry findInfo(const Expr *E) const {
470     if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
471       if (!Cleanups->cleanupsHaveSideEffects())
472         E = Cleanups->getSubExpr();
473     return PropagationMap.find(E->IgnoreParens());
474   }
475 
insertInfo(const Expr * E,const PropagationInfo & PI)476   void insertInfo(const Expr *E, const PropagationInfo &PI) {
477     PropagationMap.insert(PairType(E->IgnoreParens(), PI));
478   }
479 
480   void forwardInfo(const Expr *From, const Expr *To);
481   void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);
482   ConsumedState getInfo(const Expr *From);
483   void setInfo(const Expr *To, ConsumedState NS);
484   void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);
485 
486 public:
487   void checkCallability(const PropagationInfo &PInfo,
488                         const FunctionDecl *FunDecl,
489                         SourceLocation BlameLoc);
490   bool handleCall(const CallExpr *Call, const Expr *ObjArg,
491                   const FunctionDecl *FunD);
492 
493   void VisitBinaryOperator(const BinaryOperator *BinOp);
494   void VisitCallExpr(const CallExpr *Call);
495   void VisitCastExpr(const CastExpr *Cast);
496   void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);
497   void VisitCXXConstructExpr(const CXXConstructExpr *Call);
498   void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
499   void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
500   void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
501   void VisitDeclStmt(const DeclStmt *DelcS);
502   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
503   void VisitMemberExpr(const MemberExpr *MExpr);
504   void VisitParmVarDecl(const ParmVarDecl *Param);
505   void VisitReturnStmt(const ReturnStmt *Ret);
506   void VisitUnaryOperator(const UnaryOperator *UOp);
507   void VisitVarDecl(const VarDecl *Var);
508 
ConsumedStmtVisitor(ConsumedAnalyzer & Analyzer,ConsumedStateMap * StateMap)509   ConsumedStmtVisitor(ConsumedAnalyzer &Analyzer, ConsumedStateMap *StateMap)
510       : Analyzer(Analyzer), StateMap(StateMap) {}
511 
getInfo(const Expr * StmtNode) const512   PropagationInfo getInfo(const Expr *StmtNode) const {
513     ConstInfoEntry Entry = findInfo(StmtNode);
514 
515     if (Entry != PropagationMap.end())
516       return Entry->second;
517     else
518       return {};
519   }
520 
reset(ConsumedStateMap * NewStateMap)521   void reset(ConsumedStateMap *NewStateMap) {
522     StateMap = NewStateMap;
523   }
524 };
525 
526 } // namespace consumed
527 } // namespace clang
528 
forwardInfo(const Expr * From,const Expr * To)529 void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {
530   InfoEntry Entry = findInfo(From);
531   if (Entry != PropagationMap.end())
532     insertInfo(To, Entry->second);
533 }
534 
535 // Create a new state for To, which is initialized to the state of From.
536 // If NS is not CS_None, sets the state of From to NS.
copyInfo(const Expr * From,const Expr * To,ConsumedState NS)537 void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,
538                                    ConsumedState NS) {
539   InfoEntry Entry = findInfo(From);
540   if (Entry != PropagationMap.end()) {
541     PropagationInfo& PInfo = Entry->second;
542     ConsumedState CS = PInfo.getAsState(StateMap);
543     if (CS != CS_None)
544       insertInfo(To, PropagationInfo(CS));
545     if (NS != CS_None && PInfo.isPointerToValue())
546       setStateForVarOrTmp(StateMap, PInfo, NS);
547   }
548 }
549 
550 // Get the ConsumedState for From
getInfo(const Expr * From)551 ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {
552   InfoEntry Entry = findInfo(From);
553   if (Entry != PropagationMap.end()) {
554     PropagationInfo& PInfo = Entry->second;
555     return PInfo.getAsState(StateMap);
556   }
557   return CS_None;
558 }
559 
560 // If we already have info for To then update it, otherwise create a new entry.
setInfo(const Expr * To,ConsumedState NS)561 void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {
562   InfoEntry Entry = findInfo(To);
563   if (Entry != PropagationMap.end()) {
564     PropagationInfo& PInfo = Entry->second;
565     if (PInfo.isPointerToValue())
566       setStateForVarOrTmp(StateMap, PInfo, NS);
567   } else if (NS != CS_None) {
568      insertInfo(To, PropagationInfo(NS));
569   }
570 }
571 
checkCallability(const PropagationInfo & PInfo,const FunctionDecl * FunDecl,SourceLocation BlameLoc)572 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
573                                            const FunctionDecl *FunDecl,
574                                            SourceLocation BlameLoc) {
575   assert(!PInfo.isTest());
576 
577   const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
578   if (!CWAttr)
579     return;
580 
581   if (PInfo.isVar()) {
582     ConsumedState VarState = StateMap->getState(PInfo.getVar());
583 
584     if (VarState == CS_None || isCallableInState(CWAttr, VarState))
585       return;
586 
587     Analyzer.WarningsHandler.warnUseInInvalidState(
588       FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),
589       stateToString(VarState), BlameLoc);
590   } else {
591     ConsumedState TmpState = PInfo.getAsState(StateMap);
592 
593     if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))
594       return;
595 
596     Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
597       FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);
598   }
599 }
600 
601 // Factors out common behavior for function, method, and operator calls.
602 // Check parameters and set parameter state if necessary.
603 // Returns true if the state of ObjArg is set, or false otherwise.
handleCall(const CallExpr * Call,const Expr * ObjArg,const FunctionDecl * FunD)604 bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
605                                      const FunctionDecl *FunD) {
606   unsigned Offset = 0;
607   if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
608     Offset = 1;  // first argument is 'this'
609 
610   // check explicit parameters
611   for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
612     // Skip variable argument lists.
613     if (Index - Offset >= FunD->getNumParams())
614       break;
615 
616     const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);
617     QualType ParamType = Param->getType();
618 
619     InfoEntry Entry = findInfo(Call->getArg(Index));
620 
621     if (Entry == PropagationMap.end() || Entry->second.isTest())
622       continue;
623     PropagationInfo PInfo = Entry->second;
624 
625     // Check that the parameter is in the correct state.
626     if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {
627       ConsumedState ParamState = PInfo.getAsState(StateMap);
628       ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);
629 
630       if (ParamState != ExpectedState)
631         Analyzer.WarningsHandler.warnParamTypestateMismatch(
632           Call->getArg(Index)->getExprLoc(),
633           stateToString(ExpectedState), stateToString(ParamState));
634     }
635 
636     if (!(Entry->second.isVar() || Entry->second.isTmp()))
637       continue;
638 
639     // Adjust state on the caller side.
640     if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())
641       setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
642     else if (isRValueRef(ParamType) || isConsumableType(ParamType))
643       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
644     else if (ParamType->isPointerOrReferenceType() &&
645              (!ParamType->getPointeeType().isConstQualified() ||
646               isSetOnReadPtrType(ParamType)))
647       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
648   }
649 
650   if (!ObjArg)
651     return false;
652 
653   // check implicit 'self' parameter, if present
654   InfoEntry Entry = findInfo(ObjArg);
655   if (Entry != PropagationMap.end()) {
656     PropagationInfo PInfo = Entry->second;
657     checkCallability(PInfo, FunD, Call->getExprLoc());
658 
659     if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {
660       if (PInfo.isVar()) {
661         StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));
662         return true;
663       }
664       else if (PInfo.isTmp()) {
665         StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));
666         return true;
667       }
668     }
669     else if (isTestingFunction(FunD) && PInfo.isVar()) {
670       PropagationMap.insert(PairType(Call,
671         PropagationInfo(PInfo.getVar(), testsFor(FunD))));
672     }
673   }
674   return false;
675 }
676 
propagateReturnType(const Expr * Call,const FunctionDecl * Fun)677 void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,
678                                               const FunctionDecl *Fun) {
679   QualType RetType = Fun->getCallResultType();
680   if (RetType->isReferenceType())
681     RetType = RetType->getPointeeType();
682 
683   if (isConsumableType(RetType)) {
684     ConsumedState ReturnState;
685     if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())
686       ReturnState = mapReturnTypestateAttrState(RTA);
687     else
688       ReturnState = mapConsumableAttrState(RetType);
689 
690     PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
691   }
692 }
693 
VisitBinaryOperator(const BinaryOperator * BinOp)694 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
695   switch (BinOp->getOpcode()) {
696   case BO_LAnd:
697   case BO_LOr : {
698     InfoEntry LEntry = findInfo(BinOp->getLHS()),
699               REntry = findInfo(BinOp->getRHS());
700 
701     VarTestResult LTest, RTest;
702 
703     if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
704       LTest = LEntry->second.getVarTest();
705     } else {
706       LTest.Var      = nullptr;
707       LTest.TestsFor = CS_None;
708     }
709 
710     if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
711       RTest = REntry->second.getVarTest();
712     } else {
713       RTest.Var      = nullptr;
714       RTest.TestsFor = CS_None;
715     }
716 
717     if (!(LTest.Var == nullptr && RTest.Var == nullptr))
718       PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
719         static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
720     break;
721   }
722 
723   case BO_PtrMemD:
724   case BO_PtrMemI:
725     forwardInfo(BinOp->getLHS(), BinOp);
726     break;
727 
728   default:
729     break;
730   }
731 }
732 
VisitCallExpr(const CallExpr * Call)733 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
734   const FunctionDecl *FunDecl = Call->getDirectCallee();
735   if (!FunDecl)
736     return;
737 
738   // Special case for the std::move function.
739   // TODO: Make this more specific. (Deferred)
740   if (Call->isCallToStdMove()) {
741     copyInfo(Call->getArg(0), Call, CS_Consumed);
742     return;
743   }
744 
745   handleCall(Call, nullptr, FunDecl);
746   propagateReturnType(Call, FunDecl);
747 }
748 
VisitCastExpr(const CastExpr * Cast)749 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
750   forwardInfo(Cast->getSubExpr(), Cast);
751 }
752 
VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr * Temp)753 void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(
754   const CXXBindTemporaryExpr *Temp) {
755 
756   InfoEntry Entry = findInfo(Temp->getSubExpr());
757 
758   if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
759     StateMap->setState(Temp, Entry->second.getAsState(StateMap));
760     PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));
761   }
762 }
763 
VisitCXXConstructExpr(const CXXConstructExpr * Call)764 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
765   CXXConstructorDecl *Constructor = Call->getConstructor();
766 
767   QualType ThisType = Constructor->getFunctionObjectParameterType();
768 
769   if (!isConsumableType(ThisType))
770     return;
771 
772   // FIXME: What should happen if someone annotates the move constructor?
773   if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
774     // TODO: Adjust state of args appropriately.
775     ConsumedState RetState = mapReturnTypestateAttrState(RTA);
776     PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
777   } else if (Constructor->isDefaultConstructor()) {
778     PropagationMap.insert(PairType(Call,
779       PropagationInfo(consumed::CS_Consumed)));
780   } else if (Constructor->isMoveConstructor()) {
781     copyInfo(Call->getArg(0), Call, CS_Consumed);
782   } else if (Constructor->isCopyConstructor()) {
783     // Copy state from arg.  If setStateOnRead then set arg to CS_Unknown.
784     ConsumedState NS =
785       isSetOnReadPtrType(Constructor->getThisType()) ?
786       CS_Unknown : CS_None;
787     copyInfo(Call->getArg(0), Call, NS);
788   } else {
789     // TODO: Adjust state of args appropriately.
790     ConsumedState RetState = mapConsumableAttrState(ThisType);
791     PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
792   }
793 }
794 
VisitCXXMemberCallExpr(const CXXMemberCallExpr * Call)795 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
796     const CXXMemberCallExpr *Call) {
797   CXXMethodDecl* MD = Call->getMethodDecl();
798   if (!MD)
799     return;
800 
801   handleCall(Call, Call->getImplicitObjectArgument(), MD);
802   propagateReturnType(Call, MD);
803 }
804 
VisitCXXOperatorCallExpr(const CXXOperatorCallExpr * Call)805 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
806     const CXXOperatorCallExpr *Call) {
807   const auto *FunDecl = dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
808   if (!FunDecl) return;
809 
810   if (Call->getOperator() == OO_Equal) {
811     ConsumedState CS = getInfo(Call->getArg(1));
812     if (!handleCall(Call, Call->getArg(0), FunDecl))
813       setInfo(Call->getArg(0), CS);
814     return;
815   }
816 
817   if (const auto *MCall = dyn_cast<CXXMemberCallExpr>(Call))
818     handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
819   else
820     handleCall(Call, Call->getArg(0), FunDecl);
821 
822   propagateReturnType(Call, FunDecl);
823 }
824 
VisitDeclRefExpr(const DeclRefExpr * DeclRef)825 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
826   if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
827     if (StateMap->getState(Var) != consumed::CS_None)
828       PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
829 }
830 
VisitDeclStmt(const DeclStmt * DeclS)831 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
832   for (const auto *DI : DeclS->decls())
833     if (isa<VarDecl>(DI))
834       VisitVarDecl(cast<VarDecl>(DI));
835 
836   if (DeclS->isSingleDecl())
837     if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
838       PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
839 }
840 
VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr * Temp)841 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
842   const MaterializeTemporaryExpr *Temp) {
843   forwardInfo(Temp->getSubExpr(), Temp);
844 }
845 
VisitMemberExpr(const MemberExpr * MExpr)846 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
847   forwardInfo(MExpr->getBase(), MExpr);
848 }
849 
VisitParmVarDecl(const ParmVarDecl * Param)850 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
851   QualType ParamType = Param->getType();
852   ConsumedState ParamState = consumed::CS_None;
853 
854   if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())
855     ParamState = mapParamTypestateAttrState(PTA);
856   else if (isConsumableType(ParamType))
857     ParamState = mapConsumableAttrState(ParamType);
858   else if (isRValueRef(ParamType) &&
859            isConsumableType(ParamType->getPointeeType()))
860     ParamState = mapConsumableAttrState(ParamType->getPointeeType());
861   else if (ParamType->isReferenceType() &&
862            isConsumableType(ParamType->getPointeeType()))
863     ParamState = consumed::CS_Unknown;
864 
865   if (ParamState != CS_None)
866     StateMap->setState(Param, ParamState);
867 }
868 
VisitReturnStmt(const ReturnStmt * Ret)869 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
870   ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
871 
872   if (ExpectedState != CS_None) {
873     InfoEntry Entry = findInfo(Ret->getRetValue());
874 
875     if (Entry != PropagationMap.end()) {
876       ConsumedState RetState = Entry->second.getAsState(StateMap);
877 
878       if (RetState != ExpectedState)
879         Analyzer.WarningsHandler.warnReturnTypestateMismatch(
880           Ret->getReturnLoc(), stateToString(ExpectedState),
881           stateToString(RetState));
882     }
883   }
884 
885   StateMap->checkParamsForReturnTypestate(Ret->getBeginLoc(),
886                                           Analyzer.WarningsHandler);
887 }
888 
VisitUnaryOperator(const UnaryOperator * UOp)889 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
890   InfoEntry Entry = findInfo(UOp->getSubExpr());
891   if (Entry == PropagationMap.end()) return;
892 
893   switch (UOp->getOpcode()) {
894   case UO_AddrOf:
895     PropagationMap.insert(PairType(UOp, Entry->second));
896     break;
897 
898   case UO_LNot:
899     if (Entry->second.isTest())
900       PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
901     break;
902 
903   default:
904     break;
905   }
906 }
907 
908 // TODO: See if I need to check for reference types here.
VisitVarDecl(const VarDecl * Var)909 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
910   if (isConsumableType(Var->getType())) {
911     if (Var->hasInit()) {
912       MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());
913       if (VIT != PropagationMap.end()) {
914         PropagationInfo PInfo = VIT->second;
915         ConsumedState St = PInfo.getAsState(StateMap);
916 
917         if (St != consumed::CS_None) {
918           StateMap->setState(Var, St);
919           return;
920         }
921       }
922     }
923     // Otherwise
924     StateMap->setState(Var, consumed::CS_Unknown);
925   }
926 }
927 
splitVarStateForIf(const IfStmt * IfNode,const VarTestResult & Test,ConsumedStateMap * ThenStates,ConsumedStateMap * ElseStates)928 static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,
929                                ConsumedStateMap *ThenStates,
930                                ConsumedStateMap *ElseStates) {
931   ConsumedState VarState = ThenStates->getState(Test.Var);
932 
933   if (VarState == CS_Unknown) {
934     ThenStates->setState(Test.Var, Test.TestsFor);
935     ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
936   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
937     ThenStates->markUnreachable();
938   } else if (VarState == Test.TestsFor) {
939     ElseStates->markUnreachable();
940   }
941 }
942 
splitVarStateForIfBinOp(const PropagationInfo & PInfo,ConsumedStateMap * ThenStates,ConsumedStateMap * ElseStates)943 static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
944                                     ConsumedStateMap *ThenStates,
945                                     ConsumedStateMap *ElseStates) {
946   const VarTestResult &LTest = PInfo.getLTest(),
947                       &RTest = PInfo.getRTest();
948 
949   ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
950                 RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
951 
952   if (LTest.Var) {
953     if (PInfo.testEffectiveOp() == EO_And) {
954       if (LState == CS_Unknown) {
955         ThenStates->setState(LTest.Var, LTest.TestsFor);
956       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
957         ThenStates->markUnreachable();
958       } else if (LState == LTest.TestsFor && isKnownState(RState)) {
959         if (RState == RTest.TestsFor)
960           ElseStates->markUnreachable();
961         else
962           ThenStates->markUnreachable();
963       }
964     } else {
965       if (LState == CS_Unknown) {
966         ElseStates->setState(LTest.Var,
967                              invertConsumedUnconsumed(LTest.TestsFor));
968       } else if (LState == LTest.TestsFor) {
969         ElseStates->markUnreachable();
970       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
971                  isKnownState(RState)) {
972         if (RState == RTest.TestsFor)
973           ElseStates->markUnreachable();
974         else
975           ThenStates->markUnreachable();
976       }
977     }
978   }
979 
980   if (RTest.Var) {
981     if (PInfo.testEffectiveOp() == EO_And) {
982       if (RState == CS_Unknown)
983         ThenStates->setState(RTest.Var, RTest.TestsFor);
984       else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
985         ThenStates->markUnreachable();
986     } else {
987       if (RState == CS_Unknown)
988         ElseStates->setState(RTest.Var,
989                              invertConsumedUnconsumed(RTest.TestsFor));
990       else if (RState == RTest.TestsFor)
991         ElseStates->markUnreachable();
992     }
993   }
994 }
995 
allBackEdgesVisited(const CFGBlock * CurrBlock,const CFGBlock * TargetBlock)996 bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,
997                                             const CFGBlock *TargetBlock) {
998   assert(CurrBlock && "Block pointer must not be NULL");
999   assert(TargetBlock && "TargetBlock pointer must not be NULL");
1000 
1001   unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];
1002   for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),
1003        PE = TargetBlock->pred_end(); PI != PE; ++PI) {
1004     if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
1005       return false;
1006   }
1007   return true;
1008 }
1009 
addInfo(const CFGBlock * Block,ConsumedStateMap * StateMap,std::unique_ptr<ConsumedStateMap> & OwnedStateMap)1010 void ConsumedBlockInfo::addInfo(
1011     const CFGBlock *Block, ConsumedStateMap *StateMap,
1012     std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
1013   assert(Block && "Block pointer must not be NULL");
1014 
1015   auto &Entry = StateMapsArray[Block->getBlockID()];
1016 
1017   if (Entry) {
1018     Entry->intersect(*StateMap);
1019   } else if (OwnedStateMap)
1020     Entry = std::move(OwnedStateMap);
1021   else
1022     Entry = std::make_unique<ConsumedStateMap>(*StateMap);
1023 }
1024 
addInfo(const CFGBlock * Block,std::unique_ptr<ConsumedStateMap> StateMap)1025 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
1026                                 std::unique_ptr<ConsumedStateMap> StateMap) {
1027   assert(Block && "Block pointer must not be NULL");
1028 
1029   auto &Entry = StateMapsArray[Block->getBlockID()];
1030 
1031   if (Entry) {
1032     Entry->intersect(*StateMap);
1033   } else {
1034     Entry = std::move(StateMap);
1035   }
1036 }
1037 
borrowInfo(const CFGBlock * Block)1038 ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {
1039   assert(Block && "Block pointer must not be NULL");
1040   assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");
1041 
1042   return StateMapsArray[Block->getBlockID()].get();
1043 }
1044 
discardInfo(const CFGBlock * Block)1045 void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {
1046   StateMapsArray[Block->getBlockID()] = nullptr;
1047 }
1048 
1049 std::unique_ptr<ConsumedStateMap>
getInfo(const CFGBlock * Block)1050 ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
1051   assert(Block && "Block pointer must not be NULL");
1052 
1053   auto &Entry = StateMapsArray[Block->getBlockID()];
1054   return isBackEdgeTarget(Block) ? std::make_unique<ConsumedStateMap>(*Entry)
1055                                  : std::move(Entry);
1056 }
1057 
isBackEdge(const CFGBlock * From,const CFGBlock * To)1058 bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {
1059   assert(From && "From block must not be NULL");
1060   assert(To   && "From block must not be NULL");
1061 
1062   return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];
1063 }
1064 
isBackEdgeTarget(const CFGBlock * Block)1065 bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {
1066   assert(Block && "Block pointer must not be NULL");
1067 
1068   // Anything with less than two predecessors can't be the target of a back
1069   // edge.
1070   if (Block->pred_size() < 2)
1071     return false;
1072 
1073   unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];
1074   for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),
1075        PE = Block->pred_end(); PI != PE; ++PI) {
1076     if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
1077       return true;
1078   }
1079   return false;
1080 }
1081 
checkParamsForReturnTypestate(SourceLocation BlameLoc,ConsumedWarningsHandlerBase & WarningsHandler) const1082 void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,
1083   ConsumedWarningsHandlerBase &WarningsHandler) const {
1084 
1085   for (const auto &DM : VarMap) {
1086     if (isa<ParmVarDecl>(DM.first)) {
1087       const auto *Param = cast<ParmVarDecl>(DM.first);
1088       const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
1089 
1090       if (!RTA)
1091         continue;
1092 
1093       ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);
1094       if (DM.second != ExpectedState)
1095         WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,
1096           Param->getNameAsString(), stateToString(ExpectedState),
1097           stateToString(DM.second));
1098     }
1099   }
1100 }
1101 
clearTemporaries()1102 void ConsumedStateMap::clearTemporaries() {
1103   TmpMap.clear();
1104 }
1105 
getState(const VarDecl * Var) const1106 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {
1107   VarMapType::const_iterator Entry = VarMap.find(Var);
1108 
1109   if (Entry != VarMap.end())
1110     return Entry->second;
1111 
1112   return CS_None;
1113 }
1114 
1115 ConsumedState
getState(const CXXBindTemporaryExpr * Tmp) const1116 ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {
1117   TmpMapType::const_iterator Entry = TmpMap.find(Tmp);
1118 
1119   if (Entry != TmpMap.end())
1120     return Entry->second;
1121 
1122   return CS_None;
1123 }
1124 
intersect(const ConsumedStateMap & Other)1125 void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {
1126   ConsumedState LocalState;
1127 
1128   if (this->From && this->From == Other.From && !Other.Reachable) {
1129     this->markUnreachable();
1130     return;
1131   }
1132 
1133   for (const auto &DM : Other.VarMap) {
1134     LocalState = this->getState(DM.first);
1135 
1136     if (LocalState == CS_None)
1137       continue;
1138 
1139     if (LocalState != DM.second)
1140      VarMap[DM.first] = CS_Unknown;
1141   }
1142 }
1143 
intersectAtLoopHead(const CFGBlock * LoopHead,const CFGBlock * LoopBack,const ConsumedStateMap * LoopBackStates,ConsumedWarningsHandlerBase & WarningsHandler)1144 void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,
1145   const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,
1146   ConsumedWarningsHandlerBase &WarningsHandler) {
1147 
1148   ConsumedState LocalState;
1149   SourceLocation BlameLoc = getLastStmtLoc(LoopBack);
1150 
1151   for (const auto &DM : LoopBackStates->VarMap) {
1152     LocalState = this->getState(DM.first);
1153 
1154     if (LocalState == CS_None)
1155       continue;
1156 
1157     if (LocalState != DM.second) {
1158       VarMap[DM.first] = CS_Unknown;
1159       WarningsHandler.warnLoopStateMismatch(BlameLoc,
1160                                             DM.first->getNameAsString());
1161     }
1162   }
1163 }
1164 
markUnreachable()1165 void ConsumedStateMap::markUnreachable() {
1166   this->Reachable = false;
1167   VarMap.clear();
1168   TmpMap.clear();
1169 }
1170 
setState(const VarDecl * Var,ConsumedState State)1171 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
1172   VarMap[Var] = State;
1173 }
1174 
setState(const CXXBindTemporaryExpr * Tmp,ConsumedState State)1175 void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,
1176                                 ConsumedState State) {
1177   TmpMap[Tmp] = State;
1178 }
1179 
remove(const CXXBindTemporaryExpr * Tmp)1180 void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {
1181   TmpMap.erase(Tmp);
1182 }
1183 
operator !=(const ConsumedStateMap * Other) const1184 bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {
1185   for (const auto &DM : Other->VarMap)
1186     if (this->getState(DM.first) != DM.second)
1187       return true;
1188   return false;
1189 }
1190 
determineExpectedReturnState(AnalysisDeclContext & AC,const FunctionDecl * D)1191 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1192                                                     const FunctionDecl *D) {
1193   QualType ReturnType;
1194   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1195     ReturnType = Constructor->getFunctionObjectParameterType();
1196   } else
1197     ReturnType = D->getCallResultType();
1198 
1199   if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {
1200     const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1201     if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1202       // FIXME: This should be removed when template instantiation propagates
1203       //        attributes at template specialization definition, not
1204       //        declaration. When it is removed the test needs to be enabled
1205       //        in SemaDeclAttr.cpp.
1206       WarningsHandler.warnReturnTypestateForUnconsumableType(
1207           RTSAttr->getLocation(), ReturnType.getAsString());
1208       ExpectedReturnState = CS_None;
1209     } else
1210       ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1211   } else if (isConsumableType(ReturnType)) {
1212     if (isAutoCastType(ReturnType))   // We can auto-cast the state to the
1213       ExpectedReturnState = CS_None;  // expected state.
1214     else
1215       ExpectedReturnState = mapConsumableAttrState(ReturnType);
1216   }
1217   else
1218     ExpectedReturnState = CS_None;
1219 }
1220 
splitState(const CFGBlock * CurrBlock,const ConsumedStmtVisitor & Visitor)1221 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1222                                   const ConsumedStmtVisitor &Visitor) {
1223   std::unique_ptr<ConsumedStateMap> FalseStates(
1224       new ConsumedStateMap(*CurrStates));
1225   PropagationInfo PInfo;
1226 
1227   if (const auto *IfNode =
1228           dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1229     if (IfNode->isConsteval())
1230       return false;
1231 
1232     const Expr *Cond = IfNode->getCond();
1233 
1234     PInfo = Visitor.getInfo(Cond);
1235     if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1236       PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1237 
1238     if (PInfo.isVarTest()) {
1239       CurrStates->setSource(Cond);
1240       FalseStates->setSource(Cond);
1241       splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),
1242                          FalseStates.get());
1243     } else if (PInfo.isBinTest()) {
1244       CurrStates->setSource(PInfo.testSourceNode());
1245       FalseStates->setSource(PInfo.testSourceNode());
1246       splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());
1247     } else {
1248       return false;
1249     }
1250   } else if (const auto *BinOp =
1251        dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1252     PInfo = Visitor.getInfo(BinOp->getLHS());
1253     if (!PInfo.isVarTest()) {
1254       if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1255         PInfo = Visitor.getInfo(BinOp->getRHS());
1256 
1257         if (!PInfo.isVarTest())
1258           return false;
1259       } else {
1260         return false;
1261       }
1262     }
1263 
1264     CurrStates->setSource(BinOp);
1265     FalseStates->setSource(BinOp);
1266 
1267     const VarTestResult &Test = PInfo.getVarTest();
1268     ConsumedState VarState = CurrStates->getState(Test.Var);
1269 
1270     if (BinOp->getOpcode() == BO_LAnd) {
1271       if (VarState == CS_Unknown)
1272         CurrStates->setState(Test.Var, Test.TestsFor);
1273       else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1274         CurrStates->markUnreachable();
1275 
1276     } else if (BinOp->getOpcode() == BO_LOr) {
1277       if (VarState == CS_Unknown)
1278         FalseStates->setState(Test.Var,
1279                               invertConsumedUnconsumed(Test.TestsFor));
1280       else if (VarState == Test.TestsFor)
1281         FalseStates->markUnreachable();
1282     }
1283   } else {
1284     return false;
1285   }
1286 
1287   CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1288 
1289   if (*SI)
1290     BlockInfo.addInfo(*SI, std::move(CurrStates));
1291   else
1292     CurrStates = nullptr;
1293 
1294   if (*++SI)
1295     BlockInfo.addInfo(*SI, std::move(FalseStates));
1296 
1297   return true;
1298 }
1299 
run(AnalysisDeclContext & AC)1300 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1301   const auto *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1302   if (!D)
1303     return;
1304 
1305   CFG *CFGraph = AC.getCFG();
1306   if (!CFGraph)
1307     return;
1308 
1309   determineExpectedReturnState(AC, D);
1310 
1311   PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1312   // AC.getCFG()->viewCFG(LangOptions());
1313 
1314   BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);
1315 
1316   CurrStates = std::make_unique<ConsumedStateMap>();
1317   ConsumedStmtVisitor Visitor(*this, CurrStates.get());
1318 
1319   // Add all trackable parameters to the state map.
1320   for (const auto *PI : D->parameters())
1321     Visitor.VisitParmVarDecl(PI);
1322 
1323   // Visit all of the function's basic blocks.
1324   for (const auto *CurrBlock : *SortedGraph) {
1325     if (!CurrStates)
1326       CurrStates = BlockInfo.getInfo(CurrBlock);
1327 
1328     if (!CurrStates) {
1329       continue;
1330     } else if (!CurrStates->isReachable()) {
1331       CurrStates = nullptr;
1332       continue;
1333     }
1334 
1335     Visitor.reset(CurrStates.get());
1336 
1337     // Visit all of the basic block's statements.
1338     for (const auto &B : *CurrBlock) {
1339       switch (B.getKind()) {
1340       case CFGElement::Statement:
1341         Visitor.Visit(B.castAs<CFGStmt>().getStmt());
1342         break;
1343 
1344       case CFGElement::TemporaryDtor: {
1345         const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();
1346         const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();
1347 
1348         Visitor.checkCallability(PropagationInfo(BTE),
1349                                  DTor.getDestructorDecl(AC.getASTContext()),
1350                                  BTE->getExprLoc());
1351         CurrStates->remove(BTE);
1352         break;
1353       }
1354 
1355       case CFGElement::AutomaticObjectDtor: {
1356         const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();
1357         SourceLocation Loc = DTor.getTriggerStmt()->getEndLoc();
1358         const VarDecl *Var = DTor.getVarDecl();
1359 
1360         Visitor.checkCallability(PropagationInfo(Var),
1361                                  DTor.getDestructorDecl(AC.getASTContext()),
1362                                  Loc);
1363         break;
1364       }
1365 
1366       default:
1367         break;
1368       }
1369     }
1370 
1371     // TODO: Handle other forms of branching with precision, including while-
1372     //       and for-loops. (Deferred)
1373     if (!splitState(CurrBlock, Visitor)) {
1374       CurrStates->setSource(nullptr);
1375 
1376       if (CurrBlock->succ_size() > 1 ||
1377           (CurrBlock->succ_size() == 1 &&
1378            (*CurrBlock->succ_begin())->pred_size() > 1)) {
1379 
1380         auto *RawState = CurrStates.get();
1381 
1382         for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1383              SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1384           if (*SI == nullptr) continue;
1385 
1386           if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
1387             BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
1388                 *SI, CurrBlock, RawState, WarningsHandler);
1389 
1390             if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
1391               BlockInfo.discardInfo(*SI);
1392           } else {
1393             BlockInfo.addInfo(*SI, RawState, CurrStates);
1394           }
1395         }
1396 
1397         CurrStates = nullptr;
1398       }
1399     }
1400 
1401     if (CurrBlock == &AC.getCFG()->getExit() &&
1402         D->getCallResultType()->isVoidType())
1403       CurrStates->checkParamsForReturnTypestate(D->getLocation(),
1404                                                 WarningsHandler);
1405   } // End of block iterator.
1406 
1407   // Delete the last existing state map.
1408   CurrStates = nullptr;
1409 
1410   WarningsHandler.emitDiagnostics();
1411 }
1412