xref: /freebsd/contrib/llvm-project/clang/lib/ARCMigrate/TransProtectedScope.cpp (revision b077aed33b7b6aefca7b17ddb250cf521f938613)
1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Adds braces to case statements that "contain" the initialization of a
10 // retaining variable, fixing the "switch case is in protected scope" error.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "Internals.h"
15 #include "Transforms.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/Basic/SourceManager.h"
18 #include "clang/Sema/SemaDiagnostic.h"
19 
20 using namespace clang;
21 using namespace arcmt;
22 using namespace trans;
23 
24 namespace {
25 
26 class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
27   SmallVectorImpl<DeclRefExpr *> &Refs;
28 
29 public:
30   LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
31     : Refs(refs) { }
32 
33   bool VisitDeclRefExpr(DeclRefExpr *E) {
34     if (ValueDecl *D = E->getDecl())
35       if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
36         Refs.push_back(E);
37     return true;
38   }
39 };
40 
41 struct CaseInfo {
42   SwitchCase *SC;
43   SourceRange Range;
44   enum {
45     St_Unchecked,
46     St_CannotFix,
47     St_Fixed
48   } State;
49 
50   CaseInfo() : SC(nullptr), State(St_Unchecked) {}
51   CaseInfo(SwitchCase *S, SourceRange Range)
52     : SC(S), Range(Range), State(St_Unchecked) {}
53 };
54 
55 class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
56   ParentMap &PMap;
57   SmallVectorImpl<CaseInfo> &Cases;
58 
59 public:
60   CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
61     : PMap(PMap), Cases(Cases) { }
62 
63   bool VisitSwitchStmt(SwitchStmt *S) {
64     SwitchCase *Curr = S->getSwitchCaseList();
65     if (!Curr)
66       return true;
67     Stmt *Parent = getCaseParent(Curr);
68     Curr = Curr->getNextSwitchCase();
69     // Make sure all case statements are in the same scope.
70     while (Curr) {
71       if (getCaseParent(Curr) != Parent)
72         return true;
73       Curr = Curr->getNextSwitchCase();
74     }
75 
76     SourceLocation NextLoc = S->getEndLoc();
77     Curr = S->getSwitchCaseList();
78     // We iterate over case statements in reverse source-order.
79     while (Curr) {
80       Cases.push_back(
81           CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
82       NextLoc = Curr->getBeginLoc();
83       Curr = Curr->getNextSwitchCase();
84     }
85     return true;
86   }
87 
88   Stmt *getCaseParent(SwitchCase *S) {
89     Stmt *Parent = PMap.getParent(S);
90     while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
91       Parent = PMap.getParent(Parent);
92     return Parent;
93   }
94 };
95 
96 class ProtectedScopeFixer {
97   MigrationPass &Pass;
98   SourceManager &SM;
99   SmallVector<CaseInfo, 16> Cases;
100   SmallVector<DeclRefExpr *, 16> LocalRefs;
101 
102 public:
103   ProtectedScopeFixer(BodyContext &BodyCtx)
104     : Pass(BodyCtx.getMigrationContext().Pass),
105       SM(Pass.Ctx.getSourceManager()) {
106 
107     CaseCollector(BodyCtx.getParentMap(), Cases)
108         .TraverseStmt(BodyCtx.getTopStmt());
109     LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
110 
111     SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
112     const CapturedDiagList &DiagList = Pass.getDiags();
113     // Copy the diagnostics so we don't have to worry about invaliding iterators
114     // from the diagnostic list.
115     SmallVector<StoredDiagnostic, 16> StoredDiags;
116     StoredDiags.append(DiagList.begin(), DiagList.end());
117     SmallVectorImpl<StoredDiagnostic>::iterator
118         I = StoredDiags.begin(), E = StoredDiags.end();
119     while (I != E) {
120       if (I->getID() == diag::err_switch_into_protected_scope &&
121           isInRange(I->getLocation(), BodyRange)) {
122         handleProtectedScopeError(I, E);
123         continue;
124       }
125       ++I;
126     }
127   }
128 
129   void handleProtectedScopeError(
130                              SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
131                              SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
132     Transaction Trans(Pass.TA);
133     assert(DiagI->getID() == diag::err_switch_into_protected_scope);
134     SourceLocation ErrLoc = DiagI->getLocation();
135     bool handledAllNotes = true;
136     ++DiagI;
137     for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
138          ++DiagI) {
139       if (!handleProtectedNote(*DiagI))
140         handledAllNotes = false;
141     }
142 
143     if (handledAllNotes)
144       Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
145   }
146 
147   bool handleProtectedNote(const StoredDiagnostic &Diag) {
148     assert(Diag.getLevel() == DiagnosticsEngine::Note);
149 
150     for (unsigned i = 0; i != Cases.size(); i++) {
151       CaseInfo &info = Cases[i];
152       if (isInRange(Diag.getLocation(), info.Range)) {
153 
154         if (info.State == CaseInfo::St_Unchecked)
155           tryFixing(info);
156         assert(info.State != CaseInfo::St_Unchecked);
157 
158         if (info.State == CaseInfo::St_Fixed) {
159           Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
160           return true;
161         }
162         return false;
163       }
164     }
165 
166     return false;
167   }
168 
169   void tryFixing(CaseInfo &info) {
170     assert(info.State == CaseInfo::St_Unchecked);
171     if (hasVarReferencedOutside(info)) {
172       info.State = CaseInfo::St_CannotFix;
173       return;
174     }
175 
176     Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
177     Pass.TA.insert(info.Range.getEnd(), "}\n");
178     info.State = CaseInfo::St_Fixed;
179   }
180 
181   bool hasVarReferencedOutside(CaseInfo &info) {
182     for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
183       DeclRefExpr *DRE = LocalRefs[i];
184       if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
185           !isInRange(DRE->getLocation(), info.Range))
186         return true;
187     }
188     return false;
189   }
190 
191   bool isInRange(SourceLocation Loc, SourceRange R) {
192     if (Loc.isInvalid())
193       return false;
194     return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
195             SM.isBeforeInTranslationUnit(Loc, R.getEnd());
196   }
197 };
198 
199 } // anonymous namespace
200 
201 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
202   ProtectedScopeFixer Fix(BodyCtx);
203 }
204