1 //===- Consumed.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// An intra-procedural analysis for checking consumed properties. This is based,
10 // in part, on research on linear types.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "clang/Analysis/Analyses/Consumed.h"
15 #include "clang/AST/Attr.h"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/Expr.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/Stmt.h"
21 #include "clang/AST/StmtVisitor.h"
22 #include "clang/AST/Type.h"
23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
24 #include "clang/Analysis/AnalysisDeclContext.h"
25 #include "clang/Analysis/CFG.h"
26 #include "clang/Basic/LLVM.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include <cassert>
32 #include <memory>
33 #include <optional>
34 #include <utility>
35
36 // TODO: Adjust states of args to constructors in the same way that arguments to
37 // function calls are handled.
38 // TODO: Use information from tests in for- and while-loop conditional.
39 // TODO: Add notes about the actual and expected state for
40 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
41 // TODO: Adjust the parser and AttributesList class to support lists of
42 // identifiers.
43 // TODO: Warn about unreachable code.
44 // TODO: Switch to using a bitmap to track unreachable blocks.
45 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
46 // if (valid) ...; (Deferred)
47 // TODO: Take notes on state transitions to provide better warning messages.
48 // (Deferred)
// TODO: Test nested conditionals: A) Checking the same value multiple times,
//       and B) Checking different values. (Deferred)
51
52 using namespace clang;
53 using namespace consumed;
54
// Key method definition. Defining the destructor out of line here presumably
// anchors the class's vtable to this translation unit.
// NOTE(review): confirm the destructor is declared virtual in Consumed.h.
ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() = default;
57
getFirstStmtLoc(const CFGBlock * Block)58 static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {
59 // Find the source location of the first statement in the block, if the block
60 // is not empty.
61 for (const auto &B : *Block)
62 if (std::optional<CFGStmt> CS = B.getAs<CFGStmt>())
63 return CS->getStmt()->getBeginLoc();
64
65 // Block is empty.
66 // If we have one successor, return the first statement in that block
67 if (Block->succ_size() == 1 && *Block->succ_begin())
68 return getFirstStmtLoc(*Block->succ_begin());
69
70 return {};
71 }
72
getLastStmtLoc(const CFGBlock * Block)73 static SourceLocation getLastStmtLoc(const CFGBlock *Block) {
74 // Find the source location of the last statement in the block, if the block
75 // is not empty.
76 if (const Stmt *StmtNode = Block->getTerminatorStmt()) {
77 return StmtNode->getBeginLoc();
78 } else {
79 for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),
80 BE = Block->rend(); BI != BE; ++BI) {
81 if (std::optional<CFGStmt> CS = BI->getAs<CFGStmt>())
82 return CS->getStmt()->getBeginLoc();
83 }
84 }
85
86 // If we have one successor, return the first statement in that block
87 SourceLocation Loc;
88 if (Block->succ_size() == 1 && *Block->succ_begin())
89 Loc = getFirstStmtLoc(*Block->succ_begin());
90 if (Loc.isValid())
91 return Loc;
92
93 // If we have one predecessor, return the last statement in that block
94 if (Block->pred_size() == 1 && *Block->pred_begin())
95 return getLastStmtLoc(*Block->pred_begin());
96
97 return Loc;
98 }
99
invertConsumedUnconsumed(ConsumedState State)100 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
101 switch (State) {
102 case CS_Unconsumed:
103 return CS_Consumed;
104 case CS_Consumed:
105 return CS_Unconsumed;
106 case CS_None:
107 return CS_None;
108 case CS_Unknown:
109 return CS_Unknown;
110 }
111 llvm_unreachable("invalid enum");
112 }
113
isCallableInState(const CallableWhenAttr * CWAttr,ConsumedState State)114 static bool isCallableInState(const CallableWhenAttr *CWAttr,
115 ConsumedState State) {
116 for (const auto &S : CWAttr->callableStates()) {
117 ConsumedState MappedAttrState = CS_None;
118
119 switch (S) {
120 case CallableWhenAttr::Unknown:
121 MappedAttrState = CS_Unknown;
122 break;
123
124 case CallableWhenAttr::Unconsumed:
125 MappedAttrState = CS_Unconsumed;
126 break;
127
128 case CallableWhenAttr::Consumed:
129 MappedAttrState = CS_Consumed;
130 break;
131 }
132
133 if (MappedAttrState == State)
134 return true;
135 }
136
137 return false;
138 }
139
isConsumableType(const QualType & QT)140 static bool isConsumableType(const QualType &QT) {
141 if (QT->isPointerOrReferenceType())
142 return false;
143
144 if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
145 return RD->hasAttr<ConsumableAttr>();
146
147 return false;
148 }
149
isAutoCastType(const QualType & QT)150 static bool isAutoCastType(const QualType &QT) {
151 if (QT->isPointerOrReferenceType())
152 return false;
153
154 if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
155 return RD->hasAttr<ConsumableAutoCastAttr>();
156
157 return false;
158 }
159
isSetOnReadPtrType(const QualType & QT)160 static bool isSetOnReadPtrType(const QualType &QT) {
161 if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())
162 return RD->hasAttr<ConsumableSetOnReadAttr>();
163 return false;
164 }
165
isKnownState(ConsumedState State)166 static bool isKnownState(ConsumedState State) {
167 switch (State) {
168 case CS_Unconsumed:
169 case CS_Consumed:
170 return true;
171 case CS_None:
172 case CS_Unknown:
173 return false;
174 }
175 llvm_unreachable("invalid enum");
176 }
177
isRValueRef(QualType ParamType)178 static bool isRValueRef(QualType ParamType) {
179 return ParamType->isRValueReferenceType();
180 }
181
isTestingFunction(const FunctionDecl * FunDecl)182 static bool isTestingFunction(const FunctionDecl *FunDecl) {
183 return FunDecl->hasAttr<TestTypestateAttr>();
184 }
185
mapConsumableAttrState(const QualType QT)186 static ConsumedState mapConsumableAttrState(const QualType QT) {
187 assert(isConsumableType(QT));
188
189 const ConsumableAttr *CAttr =
190 QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
191
192 switch (CAttr->getDefaultState()) {
193 case ConsumableAttr::Unknown:
194 return CS_Unknown;
195 case ConsumableAttr::Unconsumed:
196 return CS_Unconsumed;
197 case ConsumableAttr::Consumed:
198 return CS_Consumed;
199 }
200 llvm_unreachable("invalid enum");
201 }
202
203 static ConsumedState
mapParamTypestateAttrState(const ParamTypestateAttr * PTAttr)204 mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {
205 switch (PTAttr->getParamState()) {
206 case ParamTypestateAttr::Unknown:
207 return CS_Unknown;
208 case ParamTypestateAttr::Unconsumed:
209 return CS_Unconsumed;
210 case ParamTypestateAttr::Consumed:
211 return CS_Consumed;
212 }
213 llvm_unreachable("invalid_enum");
214 }
215
216 static ConsumedState
mapReturnTypestateAttrState(const ReturnTypestateAttr * RTSAttr)217 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
218 switch (RTSAttr->getState()) {
219 case ReturnTypestateAttr::Unknown:
220 return CS_Unknown;
221 case ReturnTypestateAttr::Unconsumed:
222 return CS_Unconsumed;
223 case ReturnTypestateAttr::Consumed:
224 return CS_Consumed;
225 }
226 llvm_unreachable("invalid enum");
227 }
228
mapSetTypestateAttrState(const SetTypestateAttr * STAttr)229 static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {
230 switch (STAttr->getNewState()) {
231 case SetTypestateAttr::Unknown:
232 return CS_Unknown;
233 case SetTypestateAttr::Unconsumed:
234 return CS_Unconsumed;
235 case SetTypestateAttr::Consumed:
236 return CS_Consumed;
237 }
238 llvm_unreachable("invalid_enum");
239 }
240
stateToString(ConsumedState State)241 static StringRef stateToString(ConsumedState State) {
242 switch (State) {
243 case consumed::CS_None:
244 return "none";
245
246 case consumed::CS_Unknown:
247 return "unknown";
248
249 case consumed::CS_Unconsumed:
250 return "unconsumed";
251
252 case consumed::CS_Consumed:
253 return "consumed";
254 }
255 llvm_unreachable("invalid enum");
256 }
257
testsFor(const FunctionDecl * FunDecl)258 static ConsumedState testsFor(const FunctionDecl *FunDecl) {
259 assert(isTestingFunction(FunDecl));
260 switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {
261 case TestTypestateAttr::Unconsumed:
262 return CS_Unconsumed;
263 case TestTypestateAttr::Consumed:
264 return CS_Consumed;
265 }
266 llvm_unreachable("invalid enum");
267 }
268
269 namespace {
270
// Result of a typestate test applied to a single variable.
struct VarTestResult {
  // Variable inspected by the test. VisitBinaryOperator stores nullptr here
  // for an operand of && / || that is not itself a variable test.
  const VarDecl *Var;
  // State the test checks for; splitVarStateForIf assigns this state to Var in
  // the "then" state map when Var's state was CS_Unknown.
  ConsumedState TestsFor;
};
275
276 } // namespace
277
278 namespace clang {
279 namespace consumed {
280
// Logical connective joining the two operands of a binary test.
// PropagationInfo::invertTest() swaps the two enumerators when it negates a
// binary test.
enum EffectiveOp {
  EO_And,
  EO_Or
};
285
286 class PropagationInfo {
287 enum {
288 IT_None,
289 IT_State,
290 IT_VarTest,
291 IT_BinTest,
292 IT_Var,
293 IT_Tmp
294 } InfoType = IT_None;
295
296 struct BinTestTy {
297 const BinaryOperator *Source;
298 EffectiveOp EOp;
299 VarTestResult LTest;
300 VarTestResult RTest;
301 };
302
303 union {
304 ConsumedState State;
305 VarTestResult VarTest;
306 const VarDecl *Var;
307 const CXXBindTemporaryExpr *Tmp;
308 BinTestTy BinTest;
309 };
310
311 public:
312 PropagationInfo() = default;
PropagationInfo(const VarTestResult & VarTest)313 PropagationInfo(const VarTestResult &VarTest)
314 : InfoType(IT_VarTest), VarTest(VarTest) {}
315
PropagationInfo(const VarDecl * Var,ConsumedState TestsFor)316 PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
317 : InfoType(IT_VarTest) {
318 VarTest.Var = Var;
319 VarTest.TestsFor = TestsFor;
320 }
321
PropagationInfo(const BinaryOperator * Source,EffectiveOp EOp,const VarTestResult & LTest,const VarTestResult & RTest)322 PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
323 const VarTestResult <est, const VarTestResult &RTest)
324 : InfoType(IT_BinTest) {
325 BinTest.Source = Source;
326 BinTest.EOp = EOp;
327 BinTest.LTest = LTest;
328 BinTest.RTest = RTest;
329 }
330
PropagationInfo(const BinaryOperator * Source,EffectiveOp EOp,const VarDecl * LVar,ConsumedState LTestsFor,const VarDecl * RVar,ConsumedState RTestsFor)331 PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
332 const VarDecl *LVar, ConsumedState LTestsFor,
333 const VarDecl *RVar, ConsumedState RTestsFor)
334 : InfoType(IT_BinTest) {
335 BinTest.Source = Source;
336 BinTest.EOp = EOp;
337 BinTest.LTest.Var = LVar;
338 BinTest.LTest.TestsFor = LTestsFor;
339 BinTest.RTest.Var = RVar;
340 BinTest.RTest.TestsFor = RTestsFor;
341 }
342
PropagationInfo(ConsumedState State)343 PropagationInfo(ConsumedState State)
344 : InfoType(IT_State), State(State) {}
PropagationInfo(const VarDecl * Var)345 PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
PropagationInfo(const CXXBindTemporaryExpr * Tmp)346 PropagationInfo(const CXXBindTemporaryExpr *Tmp)
347 : InfoType(IT_Tmp), Tmp(Tmp) {}
348
getState() const349 const ConsumedState &getState() const {
350 assert(InfoType == IT_State);
351 return State;
352 }
353
getVarTest() const354 const VarTestResult &getVarTest() const {
355 assert(InfoType == IT_VarTest);
356 return VarTest;
357 }
358
getLTest() const359 const VarTestResult &getLTest() const {
360 assert(InfoType == IT_BinTest);
361 return BinTest.LTest;
362 }
363
getRTest() const364 const VarTestResult &getRTest() const {
365 assert(InfoType == IT_BinTest);
366 return BinTest.RTest;
367 }
368
getVar() const369 const VarDecl *getVar() const {
370 assert(InfoType == IT_Var);
371 return Var;
372 }
373
getTmp() const374 const CXXBindTemporaryExpr *getTmp() const {
375 assert(InfoType == IT_Tmp);
376 return Tmp;
377 }
378
getAsState(const ConsumedStateMap * StateMap) const379 ConsumedState getAsState(const ConsumedStateMap *StateMap) const {
380 assert(isVar() || isTmp() || isState());
381
382 if (isVar())
383 return StateMap->getState(Var);
384 else if (isTmp())
385 return StateMap->getState(Tmp);
386 else if (isState())
387 return State;
388 else
389 return CS_None;
390 }
391
testEffectiveOp() const392 EffectiveOp testEffectiveOp() const {
393 assert(InfoType == IT_BinTest);
394 return BinTest.EOp;
395 }
396
testSourceNode() const397 const BinaryOperator * testSourceNode() const {
398 assert(InfoType == IT_BinTest);
399 return BinTest.Source;
400 }
401
isValid() const402 bool isValid() const { return InfoType != IT_None; }
isState() const403 bool isState() const { return InfoType == IT_State; }
isVarTest() const404 bool isVarTest() const { return InfoType == IT_VarTest; }
isBinTest() const405 bool isBinTest() const { return InfoType == IT_BinTest; }
isVar() const406 bool isVar() const { return InfoType == IT_Var; }
isTmp() const407 bool isTmp() const { return InfoType == IT_Tmp; }
408
isTest() const409 bool isTest() const {
410 return InfoType == IT_VarTest || InfoType == IT_BinTest;
411 }
412
isPointerToValue() const413 bool isPointerToValue() const {
414 return InfoType == IT_Var || InfoType == IT_Tmp;
415 }
416
invertTest() const417 PropagationInfo invertTest() const {
418 assert(InfoType == IT_VarTest || InfoType == IT_BinTest);
419
420 if (InfoType == IT_VarTest) {
421 return PropagationInfo(VarTest.Var,
422 invertConsumedUnconsumed(VarTest.TestsFor));
423
424 } else if (InfoType == IT_BinTest) {
425 return PropagationInfo(BinTest.Source,
426 BinTest.EOp == EO_And ? EO_Or : EO_And,
427 BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
428 BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
429 } else {
430 return {};
431 }
432 }
433 };
434
435 } // namespace consumed
436 } // namespace clang
437
438 static void
setStateForVarOrTmp(ConsumedStateMap * StateMap,const PropagationInfo & PInfo,ConsumedState State)439 setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,
440 ConsumedState State) {
441 assert(PInfo.isVar() || PInfo.isTmp());
442
443 if (PInfo.isVar())
444 StateMap->setState(PInfo.getVar(), State);
445 else
446 StateMap->setState(PInfo.getTmp(), State);
447 }
448
449 namespace clang {
450 namespace consumed {
451
452 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
453 using MapType = llvm::DenseMap<const Stmt *, PropagationInfo>;
454 using PairType= std::pair<const Stmt *, PropagationInfo>;
455 using InfoEntry = MapType::iterator;
456 using ConstInfoEntry = MapType::const_iterator;
457
458 ConsumedAnalyzer &Analyzer;
459 ConsumedStateMap *StateMap;
460 MapType PropagationMap;
461
findInfo(const Expr * E)462 InfoEntry findInfo(const Expr *E) {
463 if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
464 if (!Cleanups->cleanupsHaveSideEffects())
465 E = Cleanups->getSubExpr();
466 return PropagationMap.find(E->IgnoreParens());
467 }
468
findInfo(const Expr * E) const469 ConstInfoEntry findInfo(const Expr *E) const {
470 if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
471 if (!Cleanups->cleanupsHaveSideEffects())
472 E = Cleanups->getSubExpr();
473 return PropagationMap.find(E->IgnoreParens());
474 }
475
insertInfo(const Expr * E,const PropagationInfo & PI)476 void insertInfo(const Expr *E, const PropagationInfo &PI) {
477 PropagationMap.insert(PairType(E->IgnoreParens(), PI));
478 }
479
480 void forwardInfo(const Expr *From, const Expr *To);
481 void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);
482 ConsumedState getInfo(const Expr *From);
483 void setInfo(const Expr *To, ConsumedState NS);
484 void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);
485
486 public:
487 void checkCallability(const PropagationInfo &PInfo,
488 const FunctionDecl *FunDecl,
489 SourceLocation BlameLoc);
490 bool handleCall(const CallExpr *Call, const Expr *ObjArg,
491 const FunctionDecl *FunD);
492
493 void VisitBinaryOperator(const BinaryOperator *BinOp);
494 void VisitCallExpr(const CallExpr *Call);
495 void VisitCastExpr(const CastExpr *Cast);
496 void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);
497 void VisitCXXConstructExpr(const CXXConstructExpr *Call);
498 void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
499 void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
500 void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
501 void VisitDeclStmt(const DeclStmt *DelcS);
502 void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
503 void VisitMemberExpr(const MemberExpr *MExpr);
504 void VisitParmVarDecl(const ParmVarDecl *Param);
505 void VisitReturnStmt(const ReturnStmt *Ret);
506 void VisitUnaryOperator(const UnaryOperator *UOp);
507 void VisitVarDecl(const VarDecl *Var);
508
ConsumedStmtVisitor(ConsumedAnalyzer & Analyzer,ConsumedStateMap * StateMap)509 ConsumedStmtVisitor(ConsumedAnalyzer &Analyzer, ConsumedStateMap *StateMap)
510 : Analyzer(Analyzer), StateMap(StateMap) {}
511
getInfo(const Expr * StmtNode) const512 PropagationInfo getInfo(const Expr *StmtNode) const {
513 ConstInfoEntry Entry = findInfo(StmtNode);
514
515 if (Entry != PropagationMap.end())
516 return Entry->second;
517 else
518 return {};
519 }
520
reset(ConsumedStateMap * NewStateMap)521 void reset(ConsumedStateMap *NewStateMap) {
522 StateMap = NewStateMap;
523 }
524 };
525
526 } // namespace consumed
527 } // namespace clang
528
forwardInfo(const Expr * From,const Expr * To)529 void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {
530 InfoEntry Entry = findInfo(From);
531 if (Entry != PropagationMap.end())
532 insertInfo(To, Entry->second);
533 }
534
535 // Create a new state for To, which is initialized to the state of From.
536 // If NS is not CS_None, sets the state of From to NS.
copyInfo(const Expr * From,const Expr * To,ConsumedState NS)537 void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,
538 ConsumedState NS) {
539 InfoEntry Entry = findInfo(From);
540 if (Entry != PropagationMap.end()) {
541 PropagationInfo& PInfo = Entry->second;
542 ConsumedState CS = PInfo.getAsState(StateMap);
543 if (CS != CS_None)
544 insertInfo(To, PropagationInfo(CS));
545 if (NS != CS_None && PInfo.isPointerToValue())
546 setStateForVarOrTmp(StateMap, PInfo, NS);
547 }
548 }
549
550 // Get the ConsumedState for From
getInfo(const Expr * From)551 ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {
552 InfoEntry Entry = findInfo(From);
553 if (Entry != PropagationMap.end()) {
554 PropagationInfo& PInfo = Entry->second;
555 return PInfo.getAsState(StateMap);
556 }
557 return CS_None;
558 }
559
560 // If we already have info for To then update it, otherwise create a new entry.
setInfo(const Expr * To,ConsumedState NS)561 void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {
562 InfoEntry Entry = findInfo(To);
563 if (Entry != PropagationMap.end()) {
564 PropagationInfo& PInfo = Entry->second;
565 if (PInfo.isPointerToValue())
566 setStateForVarOrTmp(StateMap, PInfo, NS);
567 } else if (NS != CS_None) {
568 insertInfo(To, PropagationInfo(NS));
569 }
570 }
571
checkCallability(const PropagationInfo & PInfo,const FunctionDecl * FunDecl,SourceLocation BlameLoc)572 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
573 const FunctionDecl *FunDecl,
574 SourceLocation BlameLoc) {
575 assert(!PInfo.isTest());
576
577 const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
578 if (!CWAttr)
579 return;
580
581 if (PInfo.isVar()) {
582 ConsumedState VarState = StateMap->getState(PInfo.getVar());
583
584 if (VarState == CS_None || isCallableInState(CWAttr, VarState))
585 return;
586
587 Analyzer.WarningsHandler.warnUseInInvalidState(
588 FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),
589 stateToString(VarState), BlameLoc);
590 } else {
591 ConsumedState TmpState = PInfo.getAsState(StateMap);
592
593 if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))
594 return;
595
596 Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
597 FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);
598 }
599 }
600
601 // Factors out common behavior for function, method, and operator calls.
602 // Check parameters and set parameter state if necessary.
603 // Returns true if the state of ObjArg is set, or false otherwise.
handleCall(const CallExpr * Call,const Expr * ObjArg,const FunctionDecl * FunD)604 bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
605 const FunctionDecl *FunD) {
606 unsigned Offset = 0;
607 if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
608 Offset = 1; // first argument is 'this'
609
610 // check explicit parameters
611 for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
612 // Skip variable argument lists.
613 if (Index - Offset >= FunD->getNumParams())
614 break;
615
616 const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);
617 QualType ParamType = Param->getType();
618
619 InfoEntry Entry = findInfo(Call->getArg(Index));
620
621 if (Entry == PropagationMap.end() || Entry->second.isTest())
622 continue;
623 PropagationInfo PInfo = Entry->second;
624
625 // Check that the parameter is in the correct state.
626 if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {
627 ConsumedState ParamState = PInfo.getAsState(StateMap);
628 ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);
629
630 if (ParamState != ExpectedState)
631 Analyzer.WarningsHandler.warnParamTypestateMismatch(
632 Call->getArg(Index)->getExprLoc(),
633 stateToString(ExpectedState), stateToString(ParamState));
634 }
635
636 if (!(Entry->second.isVar() || Entry->second.isTmp()))
637 continue;
638
639 // Adjust state on the caller side.
640 if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())
641 setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
642 else if (isRValueRef(ParamType) || isConsumableType(ParamType))
643 setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
644 else if (ParamType->isPointerOrReferenceType() &&
645 (!ParamType->getPointeeType().isConstQualified() ||
646 isSetOnReadPtrType(ParamType)))
647 setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
648 }
649
650 if (!ObjArg)
651 return false;
652
653 // check implicit 'self' parameter, if present
654 InfoEntry Entry = findInfo(ObjArg);
655 if (Entry != PropagationMap.end()) {
656 PropagationInfo PInfo = Entry->second;
657 checkCallability(PInfo, FunD, Call->getExprLoc());
658
659 if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {
660 if (PInfo.isVar()) {
661 StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));
662 return true;
663 }
664 else if (PInfo.isTmp()) {
665 StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));
666 return true;
667 }
668 }
669 else if (isTestingFunction(FunD) && PInfo.isVar()) {
670 PropagationMap.insert(PairType(Call,
671 PropagationInfo(PInfo.getVar(), testsFor(FunD))));
672 }
673 }
674 return false;
675 }
676
propagateReturnType(const Expr * Call,const FunctionDecl * Fun)677 void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,
678 const FunctionDecl *Fun) {
679 QualType RetType = Fun->getCallResultType();
680 if (RetType->isReferenceType())
681 RetType = RetType->getPointeeType();
682
683 if (isConsumableType(RetType)) {
684 ConsumedState ReturnState;
685 if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())
686 ReturnState = mapReturnTypestateAttrState(RTA);
687 else
688 ReturnState = mapConsumableAttrState(RetType);
689
690 PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
691 }
692 }
693
VisitBinaryOperator(const BinaryOperator * BinOp)694 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
695 switch (BinOp->getOpcode()) {
696 case BO_LAnd:
697 case BO_LOr : {
698 InfoEntry LEntry = findInfo(BinOp->getLHS()),
699 REntry = findInfo(BinOp->getRHS());
700
701 VarTestResult LTest, RTest;
702
703 if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
704 LTest = LEntry->second.getVarTest();
705 } else {
706 LTest.Var = nullptr;
707 LTest.TestsFor = CS_None;
708 }
709
710 if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
711 RTest = REntry->second.getVarTest();
712 } else {
713 RTest.Var = nullptr;
714 RTest.TestsFor = CS_None;
715 }
716
717 if (!(LTest.Var == nullptr && RTest.Var == nullptr))
718 PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
719 static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
720 break;
721 }
722
723 case BO_PtrMemD:
724 case BO_PtrMemI:
725 forwardInfo(BinOp->getLHS(), BinOp);
726 break;
727
728 default:
729 break;
730 }
731 }
732
VisitCallExpr(const CallExpr * Call)733 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
734 const FunctionDecl *FunDecl = Call->getDirectCallee();
735 if (!FunDecl)
736 return;
737
738 // Special case for the std::move function.
739 // TODO: Make this more specific. (Deferred)
740 if (Call->isCallToStdMove()) {
741 copyInfo(Call->getArg(0), Call, CS_Consumed);
742 return;
743 }
744
745 handleCall(Call, nullptr, FunDecl);
746 propagateReturnType(Call, FunDecl);
747 }
748
VisitCastExpr(const CastExpr * Cast)749 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
750 forwardInfo(Cast->getSubExpr(), Cast);
751 }
752
VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr * Temp)753 void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(
754 const CXXBindTemporaryExpr *Temp) {
755
756 InfoEntry Entry = findInfo(Temp->getSubExpr());
757
758 if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
759 StateMap->setState(Temp, Entry->second.getAsState(StateMap));
760 PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));
761 }
762 }
763
VisitCXXConstructExpr(const CXXConstructExpr * Call)764 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
765 CXXConstructorDecl *Constructor = Call->getConstructor();
766
767 QualType ThisType = Constructor->getFunctionObjectParameterType();
768
769 if (!isConsumableType(ThisType))
770 return;
771
772 // FIXME: What should happen if someone annotates the move constructor?
773 if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
774 // TODO: Adjust state of args appropriately.
775 ConsumedState RetState = mapReturnTypestateAttrState(RTA);
776 PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
777 } else if (Constructor->isDefaultConstructor()) {
778 PropagationMap.insert(PairType(Call,
779 PropagationInfo(consumed::CS_Consumed)));
780 } else if (Constructor->isMoveConstructor()) {
781 copyInfo(Call->getArg(0), Call, CS_Consumed);
782 } else if (Constructor->isCopyConstructor()) {
783 // Copy state from arg. If setStateOnRead then set arg to CS_Unknown.
784 ConsumedState NS =
785 isSetOnReadPtrType(Constructor->getThisType()) ?
786 CS_Unknown : CS_None;
787 copyInfo(Call->getArg(0), Call, NS);
788 } else {
789 // TODO: Adjust state of args appropriately.
790 ConsumedState RetState = mapConsumableAttrState(ThisType);
791 PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
792 }
793 }
794
VisitCXXMemberCallExpr(const CXXMemberCallExpr * Call)795 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
796 const CXXMemberCallExpr *Call) {
797 CXXMethodDecl* MD = Call->getMethodDecl();
798 if (!MD)
799 return;
800
801 handleCall(Call, Call->getImplicitObjectArgument(), MD);
802 propagateReturnType(Call, MD);
803 }
804
VisitCXXOperatorCallExpr(const CXXOperatorCallExpr * Call)805 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
806 const CXXOperatorCallExpr *Call) {
807 const auto *FunDecl = dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
808 if (!FunDecl) return;
809
810 if (Call->getOperator() == OO_Equal) {
811 ConsumedState CS = getInfo(Call->getArg(1));
812 if (!handleCall(Call, Call->getArg(0), FunDecl))
813 setInfo(Call->getArg(0), CS);
814 return;
815 }
816
817 if (const auto *MCall = dyn_cast<CXXMemberCallExpr>(Call))
818 handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
819 else
820 handleCall(Call, Call->getArg(0), FunDecl);
821
822 propagateReturnType(Call, FunDecl);
823 }
824
VisitDeclRefExpr(const DeclRefExpr * DeclRef)825 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
826 if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
827 if (StateMap->getState(Var) != consumed::CS_None)
828 PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
829 }
830
VisitDeclStmt(const DeclStmt * DeclS)831 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
832 for (const auto *DI : DeclS->decls())
833 if (isa<VarDecl>(DI))
834 VisitVarDecl(cast<VarDecl>(DI));
835
836 if (DeclS->isSingleDecl())
837 if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
838 PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
839 }
840
VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr * Temp)841 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
842 const MaterializeTemporaryExpr *Temp) {
843 forwardInfo(Temp->getSubExpr(), Temp);
844 }
845
VisitMemberExpr(const MemberExpr * MExpr)846 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
847 forwardInfo(MExpr->getBase(), MExpr);
848 }
849
VisitParmVarDecl(const ParmVarDecl * Param)850 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
851 QualType ParamType = Param->getType();
852 ConsumedState ParamState = consumed::CS_None;
853
854 if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())
855 ParamState = mapParamTypestateAttrState(PTA);
856 else if (isConsumableType(ParamType))
857 ParamState = mapConsumableAttrState(ParamType);
858 else if (isRValueRef(ParamType) &&
859 isConsumableType(ParamType->getPointeeType()))
860 ParamState = mapConsumableAttrState(ParamType->getPointeeType());
861 else if (ParamType->isReferenceType() &&
862 isConsumableType(ParamType->getPointeeType()))
863 ParamState = consumed::CS_Unknown;
864
865 if (ParamState != CS_None)
866 StateMap->setState(Param, ParamState);
867 }
868
VisitReturnStmt(const ReturnStmt * Ret)869 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
870 ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
871
872 if (ExpectedState != CS_None) {
873 InfoEntry Entry = findInfo(Ret->getRetValue());
874
875 if (Entry != PropagationMap.end()) {
876 ConsumedState RetState = Entry->second.getAsState(StateMap);
877
878 if (RetState != ExpectedState)
879 Analyzer.WarningsHandler.warnReturnTypestateMismatch(
880 Ret->getReturnLoc(), stateToString(ExpectedState),
881 stateToString(RetState));
882 }
883 }
884
885 StateMap->checkParamsForReturnTypestate(Ret->getBeginLoc(),
886 Analyzer.WarningsHandler);
887 }
888
VisitUnaryOperator(const UnaryOperator * UOp)889 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
890 InfoEntry Entry = findInfo(UOp->getSubExpr());
891 if (Entry == PropagationMap.end()) return;
892
893 switch (UOp->getOpcode()) {
894 case UO_AddrOf:
895 PropagationMap.insert(PairType(UOp, Entry->second));
896 break;
897
898 case UO_LNot:
899 if (Entry->second.isTest())
900 PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
901 break;
902
903 default:
904 break;
905 }
906 }
907
908 // TODO: See if I need to check for reference types here.
VisitVarDecl(const VarDecl * Var)909 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
910 if (isConsumableType(Var->getType())) {
911 if (Var->hasInit()) {
912 MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());
913 if (VIT != PropagationMap.end()) {
914 PropagationInfo PInfo = VIT->second;
915 ConsumedState St = PInfo.getAsState(StateMap);
916
917 if (St != consumed::CS_None) {
918 StateMap->setState(Var, St);
919 return;
920 }
921 }
922 }
923 // Otherwise
924 StateMap->setState(Var, consumed::CS_Unknown);
925 }
926 }
927
splitVarStateForIf(const IfStmt * IfNode,const VarTestResult & Test,ConsumedStateMap * ThenStates,ConsumedStateMap * ElseStates)928 static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,
929 ConsumedStateMap *ThenStates,
930 ConsumedStateMap *ElseStates) {
931 ConsumedState VarState = ThenStates->getState(Test.Var);
932
933 if (VarState == CS_Unknown) {
934 ThenStates->setState(Test.Var, Test.TestsFor);
935 ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
936 } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
937 ThenStates->markUnreachable();
938 } else if (VarState == Test.TestsFor) {
939 ElseStates->markUnreachable();
940 }
941 }
942
splitVarStateForIfBinOp(const PropagationInfo & PInfo,ConsumedStateMap * ThenStates,ConsumedStateMap * ElseStates)943 static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
944 ConsumedStateMap *ThenStates,
945 ConsumedStateMap *ElseStates) {
946 const VarTestResult <est = PInfo.getLTest(),
947 &RTest = PInfo.getRTest();
948
949 ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
950 RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
951
952 if (LTest.Var) {
953 if (PInfo.testEffectiveOp() == EO_And) {
954 if (LState == CS_Unknown) {
955 ThenStates->setState(LTest.Var, LTest.TestsFor);
956 } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
957 ThenStates->markUnreachable();
958 } else if (LState == LTest.TestsFor && isKnownState(RState)) {
959 if (RState == RTest.TestsFor)
960 ElseStates->markUnreachable();
961 else
962 ThenStates->markUnreachable();
963 }
964 } else {
965 if (LState == CS_Unknown) {
966 ElseStates->setState(LTest.Var,
967 invertConsumedUnconsumed(LTest.TestsFor));
968 } else if (LState == LTest.TestsFor) {
969 ElseStates->markUnreachable();
970 } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
971 isKnownState(RState)) {
972 if (RState == RTest.TestsFor)
973 ElseStates->markUnreachable();
974 else
975 ThenStates->markUnreachable();
976 }
977 }
978 }
979
980 if (RTest.Var) {
981 if (PInfo.testEffectiveOp() == EO_And) {
982 if (RState == CS_Unknown)
983 ThenStates->setState(RTest.Var, RTest.TestsFor);
984 else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
985 ThenStates->markUnreachable();
986 } else {
987 if (RState == CS_Unknown)
988 ElseStates->setState(RTest.Var,
989 invertConsumedUnconsumed(RTest.TestsFor));
990 else if (RState == RTest.TestsFor)
991 ElseStates->markUnreachable();
992 }
993 }
994 }
995
allBackEdgesVisited(const CFGBlock * CurrBlock,const CFGBlock * TargetBlock)996 bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,
997 const CFGBlock *TargetBlock) {
998 assert(CurrBlock && "Block pointer must not be NULL");
999 assert(TargetBlock && "TargetBlock pointer must not be NULL");
1000
1001 unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];
1002 for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),
1003 PE = TargetBlock->pred_end(); PI != PE; ++PI) {
1004 if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
1005 return false;
1006 }
1007 return true;
1008 }
1009
addInfo(const CFGBlock * Block,ConsumedStateMap * StateMap,std::unique_ptr<ConsumedStateMap> & OwnedStateMap)1010 void ConsumedBlockInfo::addInfo(
1011 const CFGBlock *Block, ConsumedStateMap *StateMap,
1012 std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
1013 assert(Block && "Block pointer must not be NULL");
1014
1015 auto &Entry = StateMapsArray[Block->getBlockID()];
1016
1017 if (Entry) {
1018 Entry->intersect(*StateMap);
1019 } else if (OwnedStateMap)
1020 Entry = std::move(OwnedStateMap);
1021 else
1022 Entry = std::make_unique<ConsumedStateMap>(*StateMap);
1023 }
1024
addInfo(const CFGBlock * Block,std::unique_ptr<ConsumedStateMap> StateMap)1025 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
1026 std::unique_ptr<ConsumedStateMap> StateMap) {
1027 assert(Block && "Block pointer must not be NULL");
1028
1029 auto &Entry = StateMapsArray[Block->getBlockID()];
1030
1031 if (Entry) {
1032 Entry->intersect(*StateMap);
1033 } else {
1034 Entry = std::move(StateMap);
1035 }
1036 }
1037
borrowInfo(const CFGBlock * Block)1038 ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {
1039 assert(Block && "Block pointer must not be NULL");
1040 assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");
1041
1042 return StateMapsArray[Block->getBlockID()].get();
1043 }
1044
discardInfo(const CFGBlock * Block)1045 void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {
1046 StateMapsArray[Block->getBlockID()] = nullptr;
1047 }
1048
1049 std::unique_ptr<ConsumedStateMap>
getInfo(const CFGBlock * Block)1050 ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
1051 assert(Block && "Block pointer must not be NULL");
1052
1053 auto &Entry = StateMapsArray[Block->getBlockID()];
1054 return isBackEdgeTarget(Block) ? std::make_unique<ConsumedStateMap>(*Entry)
1055 : std::move(Entry);
1056 }
1057
isBackEdge(const CFGBlock * From,const CFGBlock * To)1058 bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {
1059 assert(From && "From block must not be NULL");
1060 assert(To && "From block must not be NULL");
1061
1062 return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];
1063 }
1064
isBackEdgeTarget(const CFGBlock * Block)1065 bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {
1066 assert(Block && "Block pointer must not be NULL");
1067
1068 // Anything with less than two predecessors can't be the target of a back
1069 // edge.
1070 if (Block->pred_size() < 2)
1071 return false;
1072
1073 unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];
1074 for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),
1075 PE = Block->pred_end(); PI != PE; ++PI) {
1076 if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
1077 return true;
1078 }
1079 return false;
1080 }
1081
checkParamsForReturnTypestate(SourceLocation BlameLoc,ConsumedWarningsHandlerBase & WarningsHandler) const1082 void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,
1083 ConsumedWarningsHandlerBase &WarningsHandler) const {
1084
1085 for (const auto &DM : VarMap) {
1086 if (isa<ParmVarDecl>(DM.first)) {
1087 const auto *Param = cast<ParmVarDecl>(DM.first);
1088 const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
1089
1090 if (!RTA)
1091 continue;
1092
1093 ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);
1094 if (DM.second != ExpectedState)
1095 WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,
1096 Param->getNameAsString(), stateToString(ExpectedState),
1097 stateToString(DM.second));
1098 }
1099 }
1100 }
1101
clearTemporaries()1102 void ConsumedStateMap::clearTemporaries() {
1103 TmpMap.clear();
1104 }
1105
getState(const VarDecl * Var) const1106 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {
1107 VarMapType::const_iterator Entry = VarMap.find(Var);
1108
1109 if (Entry != VarMap.end())
1110 return Entry->second;
1111
1112 return CS_None;
1113 }
1114
1115 ConsumedState
getState(const CXXBindTemporaryExpr * Tmp) const1116 ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {
1117 TmpMapType::const_iterator Entry = TmpMap.find(Tmp);
1118
1119 if (Entry != TmpMap.end())
1120 return Entry->second;
1121
1122 return CS_None;
1123 }
1124
intersect(const ConsumedStateMap & Other)1125 void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {
1126 ConsumedState LocalState;
1127
1128 if (this->From && this->From == Other.From && !Other.Reachable) {
1129 this->markUnreachable();
1130 return;
1131 }
1132
1133 for (const auto &DM : Other.VarMap) {
1134 LocalState = this->getState(DM.first);
1135
1136 if (LocalState == CS_None)
1137 continue;
1138
1139 if (LocalState != DM.second)
1140 VarMap[DM.first] = CS_Unknown;
1141 }
1142 }
1143
intersectAtLoopHead(const CFGBlock * LoopHead,const CFGBlock * LoopBack,const ConsumedStateMap * LoopBackStates,ConsumedWarningsHandlerBase & WarningsHandler)1144 void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,
1145 const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,
1146 ConsumedWarningsHandlerBase &WarningsHandler) {
1147
1148 ConsumedState LocalState;
1149 SourceLocation BlameLoc = getLastStmtLoc(LoopBack);
1150
1151 for (const auto &DM : LoopBackStates->VarMap) {
1152 LocalState = this->getState(DM.first);
1153
1154 if (LocalState == CS_None)
1155 continue;
1156
1157 if (LocalState != DM.second) {
1158 VarMap[DM.first] = CS_Unknown;
1159 WarningsHandler.warnLoopStateMismatch(BlameLoc,
1160 DM.first->getNameAsString());
1161 }
1162 }
1163 }
1164
markUnreachable()1165 void ConsumedStateMap::markUnreachable() {
1166 this->Reachable = false;
1167 VarMap.clear();
1168 TmpMap.clear();
1169 }
1170
setState(const VarDecl * Var,ConsumedState State)1171 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
1172 VarMap[Var] = State;
1173 }
1174
setState(const CXXBindTemporaryExpr * Tmp,ConsumedState State)1175 void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,
1176 ConsumedState State) {
1177 TmpMap[Tmp] = State;
1178 }
1179
remove(const CXXBindTemporaryExpr * Tmp)1180 void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {
1181 TmpMap.erase(Tmp);
1182 }
1183
operator !=(const ConsumedStateMap * Other) const1184 bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {
1185 for (const auto &DM : Other->VarMap)
1186 if (this->getState(DM.first) != DM.second)
1187 return true;
1188 return false;
1189 }
1190
determineExpectedReturnState(AnalysisDeclContext & AC,const FunctionDecl * D)1191 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1192 const FunctionDecl *D) {
1193 QualType ReturnType;
1194 if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1195 ReturnType = Constructor->getFunctionObjectParameterType();
1196 } else
1197 ReturnType = D->getCallResultType();
1198
1199 if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {
1200 const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1201 if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1202 // FIXME: This should be removed when template instantiation propagates
1203 // attributes at template specialization definition, not
1204 // declaration. When it is removed the test needs to be enabled
1205 // in SemaDeclAttr.cpp.
1206 WarningsHandler.warnReturnTypestateForUnconsumableType(
1207 RTSAttr->getLocation(), ReturnType.getAsString());
1208 ExpectedReturnState = CS_None;
1209 } else
1210 ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1211 } else if (isConsumableType(ReturnType)) {
1212 if (isAutoCastType(ReturnType)) // We can auto-cast the state to the
1213 ExpectedReturnState = CS_None; // expected state.
1214 else
1215 ExpectedReturnState = mapConsumableAttrState(ReturnType);
1216 }
1217 else
1218 ExpectedReturnState = CS_None;
1219 }
1220
splitState(const CFGBlock * CurrBlock,const ConsumedStmtVisitor & Visitor)1221 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1222 const ConsumedStmtVisitor &Visitor) {
1223 std::unique_ptr<ConsumedStateMap> FalseStates(
1224 new ConsumedStateMap(*CurrStates));
1225 PropagationInfo PInfo;
1226
1227 if (const auto *IfNode =
1228 dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1229 if (IfNode->isConsteval())
1230 return false;
1231
1232 const Expr *Cond = IfNode->getCond();
1233
1234 PInfo = Visitor.getInfo(Cond);
1235 if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1236 PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1237
1238 if (PInfo.isVarTest()) {
1239 CurrStates->setSource(Cond);
1240 FalseStates->setSource(Cond);
1241 splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),
1242 FalseStates.get());
1243 } else if (PInfo.isBinTest()) {
1244 CurrStates->setSource(PInfo.testSourceNode());
1245 FalseStates->setSource(PInfo.testSourceNode());
1246 splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());
1247 } else {
1248 return false;
1249 }
1250 } else if (const auto *BinOp =
1251 dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1252 PInfo = Visitor.getInfo(BinOp->getLHS());
1253 if (!PInfo.isVarTest()) {
1254 if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1255 PInfo = Visitor.getInfo(BinOp->getRHS());
1256
1257 if (!PInfo.isVarTest())
1258 return false;
1259 } else {
1260 return false;
1261 }
1262 }
1263
1264 CurrStates->setSource(BinOp);
1265 FalseStates->setSource(BinOp);
1266
1267 const VarTestResult &Test = PInfo.getVarTest();
1268 ConsumedState VarState = CurrStates->getState(Test.Var);
1269
1270 if (BinOp->getOpcode() == BO_LAnd) {
1271 if (VarState == CS_Unknown)
1272 CurrStates->setState(Test.Var, Test.TestsFor);
1273 else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1274 CurrStates->markUnreachable();
1275
1276 } else if (BinOp->getOpcode() == BO_LOr) {
1277 if (VarState == CS_Unknown)
1278 FalseStates->setState(Test.Var,
1279 invertConsumedUnconsumed(Test.TestsFor));
1280 else if (VarState == Test.TestsFor)
1281 FalseStates->markUnreachable();
1282 }
1283 } else {
1284 return false;
1285 }
1286
1287 CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1288
1289 if (*SI)
1290 BlockInfo.addInfo(*SI, std::move(CurrStates));
1291 else
1292 CurrStates = nullptr;
1293
1294 if (*++SI)
1295 BlockInfo.addInfo(*SI, std::move(FalseStates));
1296
1297 return true;
1298 }
1299
run(AnalysisDeclContext & AC)1300 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1301 const auto *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1302 if (!D)
1303 return;
1304
1305 CFG *CFGraph = AC.getCFG();
1306 if (!CFGraph)
1307 return;
1308
1309 determineExpectedReturnState(AC, D);
1310
1311 PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1312 // AC.getCFG()->viewCFG(LangOptions());
1313
1314 BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);
1315
1316 CurrStates = std::make_unique<ConsumedStateMap>();
1317 ConsumedStmtVisitor Visitor(*this, CurrStates.get());
1318
1319 // Add all trackable parameters to the state map.
1320 for (const auto *PI : D->parameters())
1321 Visitor.VisitParmVarDecl(PI);
1322
1323 // Visit all of the function's basic blocks.
1324 for (const auto *CurrBlock : *SortedGraph) {
1325 if (!CurrStates)
1326 CurrStates = BlockInfo.getInfo(CurrBlock);
1327
1328 if (!CurrStates) {
1329 continue;
1330 } else if (!CurrStates->isReachable()) {
1331 CurrStates = nullptr;
1332 continue;
1333 }
1334
1335 Visitor.reset(CurrStates.get());
1336
1337 // Visit all of the basic block's statements.
1338 for (const auto &B : *CurrBlock) {
1339 switch (B.getKind()) {
1340 case CFGElement::Statement:
1341 Visitor.Visit(B.castAs<CFGStmt>().getStmt());
1342 break;
1343
1344 case CFGElement::TemporaryDtor: {
1345 const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();
1346 const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();
1347
1348 Visitor.checkCallability(PropagationInfo(BTE),
1349 DTor.getDestructorDecl(AC.getASTContext()),
1350 BTE->getExprLoc());
1351 CurrStates->remove(BTE);
1352 break;
1353 }
1354
1355 case CFGElement::AutomaticObjectDtor: {
1356 const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();
1357 SourceLocation Loc = DTor.getTriggerStmt()->getEndLoc();
1358 const VarDecl *Var = DTor.getVarDecl();
1359
1360 Visitor.checkCallability(PropagationInfo(Var),
1361 DTor.getDestructorDecl(AC.getASTContext()),
1362 Loc);
1363 break;
1364 }
1365
1366 default:
1367 break;
1368 }
1369 }
1370
1371 // TODO: Handle other forms of branching with precision, including while-
1372 // and for-loops. (Deferred)
1373 if (!splitState(CurrBlock, Visitor)) {
1374 CurrStates->setSource(nullptr);
1375
1376 if (CurrBlock->succ_size() > 1 ||
1377 (CurrBlock->succ_size() == 1 &&
1378 (*CurrBlock->succ_begin())->pred_size() > 1)) {
1379
1380 auto *RawState = CurrStates.get();
1381
1382 for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1383 SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1384 if (*SI == nullptr) continue;
1385
1386 if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
1387 BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
1388 *SI, CurrBlock, RawState, WarningsHandler);
1389
1390 if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
1391 BlockInfo.discardInfo(*SI);
1392 } else {
1393 BlockInfo.addInfo(*SI, RawState, CurrStates);
1394 }
1395 }
1396
1397 CurrStates = nullptr;
1398 }
1399 }
1400
1401 if (CurrBlock == &AC.getCFG()->getExit() &&
1402 D->getCallResultType()->isVoidType())
1403 CurrStates->checkParamsForReturnTypestate(D->getLocation(),
1404 WarningsHandler);
1405 } // End of block iterator.
1406
1407 // Delete the last existing state map.
1408 CurrStates = nullptr;
1409
1410 WarningsHandler.emitDiagnostics();
1411 }
1412