xref: /freebsd/contrib/llvm-project/clang/lib/ARCMigrate/TransProtectedScope.cpp (revision b4af4f93c682e445bf159f0d1ec90b636296c946)
1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Adds braces around case statements that "contain" the initialization of a
// retaining variable, which would otherwise trigger the "switch case is in
// protected scope" error.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "Transforms.h"
15 #include "Internals.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/Sema/SemaDiagnostic.h"
18 
19 using namespace clang;
20 using namespace arcmt;
21 using namespace trans;
22 
23 namespace {
24 
25 class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
26   SmallVectorImpl<DeclRefExpr *> &Refs;
27 
28 public:
29   LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
30     : Refs(refs) { }
31 
32   bool VisitDeclRefExpr(DeclRefExpr *E) {
33     if (ValueDecl *D = E->getDecl())
34       if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
35         Refs.push_back(E);
36     return true;
37   }
38 };
39 
40 struct CaseInfo {
41   SwitchCase *SC;
42   SourceRange Range;
43   enum {
44     St_Unchecked,
45     St_CannotFix,
46     St_Fixed
47   } State;
48 
49   CaseInfo() : SC(nullptr), State(St_Unchecked) {}
50   CaseInfo(SwitchCase *S, SourceRange Range)
51     : SC(S), Range(Range), State(St_Unchecked) {}
52 };
53 
54 class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
55   ParentMap &PMap;
56   SmallVectorImpl<CaseInfo> &Cases;
57 
58 public:
59   CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
60     : PMap(PMap), Cases(Cases) { }
61 
62   bool VisitSwitchStmt(SwitchStmt *S) {
63     SwitchCase *Curr = S->getSwitchCaseList();
64     if (!Curr)
65       return true;
66     Stmt *Parent = getCaseParent(Curr);
67     Curr = Curr->getNextSwitchCase();
68     // Make sure all case statements are in the same scope.
69     while (Curr) {
70       if (getCaseParent(Curr) != Parent)
71         return true;
72       Curr = Curr->getNextSwitchCase();
73     }
74 
75     SourceLocation NextLoc = S->getEndLoc();
76     Curr = S->getSwitchCaseList();
77     // We iterate over case statements in reverse source-order.
78     while (Curr) {
79       Cases.push_back(
80           CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
81       NextLoc = Curr->getBeginLoc();
82       Curr = Curr->getNextSwitchCase();
83     }
84     return true;
85   }
86 
87   Stmt *getCaseParent(SwitchCase *S) {
88     Stmt *Parent = PMap.getParent(S);
89     while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
90       Parent = PMap.getParent(Parent);
91     return Parent;
92   }
93 };
94 
95 class ProtectedScopeFixer {
96   MigrationPass &Pass;
97   SourceManager &SM;
98   SmallVector<CaseInfo, 16> Cases;
99   SmallVector<DeclRefExpr *, 16> LocalRefs;
100 
101 public:
102   ProtectedScopeFixer(BodyContext &BodyCtx)
103     : Pass(BodyCtx.getMigrationContext().Pass),
104       SM(Pass.Ctx.getSourceManager()) {
105 
106     CaseCollector(BodyCtx.getParentMap(), Cases)
107         .TraverseStmt(BodyCtx.getTopStmt());
108     LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
109 
110     SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
111     const CapturedDiagList &DiagList = Pass.getDiags();
112     // Copy the diagnostics so we don't have to worry about invaliding iterators
113     // from the diagnostic list.
114     SmallVector<StoredDiagnostic, 16> StoredDiags;
115     StoredDiags.append(DiagList.begin(), DiagList.end());
116     SmallVectorImpl<StoredDiagnostic>::iterator
117         I = StoredDiags.begin(), E = StoredDiags.end();
118     while (I != E) {
119       if (I->getID() == diag::err_switch_into_protected_scope &&
120           isInRange(I->getLocation(), BodyRange)) {
121         handleProtectedScopeError(I, E);
122         continue;
123       }
124       ++I;
125     }
126   }
127 
128   void handleProtectedScopeError(
129                              SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
130                              SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
131     Transaction Trans(Pass.TA);
132     assert(DiagI->getID() == diag::err_switch_into_protected_scope);
133     SourceLocation ErrLoc = DiagI->getLocation();
134     bool handledAllNotes = true;
135     ++DiagI;
136     for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
137          ++DiagI) {
138       if (!handleProtectedNote(*DiagI))
139         handledAllNotes = false;
140     }
141 
142     if (handledAllNotes)
143       Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
144   }
145 
146   bool handleProtectedNote(const StoredDiagnostic &Diag) {
147     assert(Diag.getLevel() == DiagnosticsEngine::Note);
148 
149     for (unsigned i = 0; i != Cases.size(); i++) {
150       CaseInfo &info = Cases[i];
151       if (isInRange(Diag.getLocation(), info.Range)) {
152 
153         if (info.State == CaseInfo::St_Unchecked)
154           tryFixing(info);
155         assert(info.State != CaseInfo::St_Unchecked);
156 
157         if (info.State == CaseInfo::St_Fixed) {
158           Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
159           return true;
160         }
161         return false;
162       }
163     }
164 
165     return false;
166   }
167 
168   void tryFixing(CaseInfo &info) {
169     assert(info.State == CaseInfo::St_Unchecked);
170     if (hasVarReferencedOutside(info)) {
171       info.State = CaseInfo::St_CannotFix;
172       return;
173     }
174 
175     Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
176     Pass.TA.insert(info.Range.getEnd(), "}\n");
177     info.State = CaseInfo::St_Fixed;
178   }
179 
180   bool hasVarReferencedOutside(CaseInfo &info) {
181     for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
182       DeclRefExpr *DRE = LocalRefs[i];
183       if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
184           !isInRange(DRE->getLocation(), info.Range))
185         return true;
186     }
187     return false;
188   }
189 
190   bool isInRange(SourceLocation Loc, SourceRange R) {
191     if (Loc.isInvalid())
192       return false;
193     return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
194             SM.isBeforeInTranslationUnit(Loc, R.getEnd());
195   }
196 };
197 
198 } // anonymous namespace
199 
200 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
201   ProtectedScopeFixer Fix(BodyCtx);
202 }
203