xref: /freebsd/contrib/llvm-project/clang/lib/StaticAnalyzer/Core/LoopUnrolling.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/ASTMatchers/ASTMatchers.h"
16 #include "clang/ASTMatchers/ASTMatchFinder.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
20 #include <optional>
21 
22 using namespace clang;
23 using namespace ento;
24 using namespace clang::ast_matchers;
25 
/// Largest accumulated step count (the loop's own steps multiplied by the
/// step bound of the enclosing tracked loop) for which a loop is still marked
/// as unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
27 
28 namespace {
29 struct LoopState {
30 private:
31   enum Kind { Normal, Unrolled } K;
32   const Stmt *LoopStmt;
33   const LocationContext *LCtx;
34   unsigned maxStep;
LoopState__anon87ccfcc30111::LoopState35   LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
36       : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
37 
38 public:
getNormal__anon87ccfcc30111::LoopState39   static LoopState getNormal(const Stmt *S, const LocationContext *L,
40                              unsigned N) {
41     return LoopState(Normal, S, L, N);
42   }
getUnrolled__anon87ccfcc30111::LoopState43   static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
44                                unsigned N) {
45     return LoopState(Unrolled, S, L, N);
46   }
isUnrolled__anon87ccfcc30111::LoopState47   bool isUnrolled() const { return K == Unrolled; }
getMaxStep__anon87ccfcc30111::LoopState48   unsigned getMaxStep() const { return maxStep; }
getLoopStmt__anon87ccfcc30111::LoopState49   const Stmt *getLoopStmt() const { return LoopStmt; }
getLocationContext__anon87ccfcc30111::LoopState50   const LocationContext *getLocationContext() const { return LCtx; }
operator ==__anon87ccfcc30111::LoopState51   bool operator==(const LoopState &X) const {
52     return K == X.K && LoopStmt == X.LoopStmt;
53   }
Profile__anon87ccfcc30111::LoopState54   void Profile(llvm::FoldingSetNodeID &ID) const {
55     ID.AddInteger(K);
56     ID.AddPointer(LoopStmt);
57     ID.AddPointer(LCtx);
58     ID.AddInteger(maxStep);
59   }
60 };
61 } // namespace
62 
// The tracked stack of loops. The stack records which loops contain the
// element that is currently being simulated. Each loop is marked according to
// whether we decided to unroll it.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
70 
71 namespace clang {
72 namespace ento {
73 
isLoopStmt(const Stmt * S)74 static bool isLoopStmt(const Stmt *S) {
75   return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S);
76 }
77 
processLoopEnd(const Stmt * LoopStmt,ProgramStateRef State)78 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
79   auto LS = State->get<LoopStack>();
80   if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
81     State = State->set<LoopStack>(LS.getTail());
82   return State;
83 }
84 
simpleCondition(StringRef BindName,StringRef RefName)85 static internal::Matcher<Stmt> simpleCondition(StringRef BindName,
86                                                StringRef RefName) {
87   return binaryOperator(
88              anyOf(hasOperatorName("<"), hasOperatorName(">"),
89                    hasOperatorName("<="), hasOperatorName(">="),
90                    hasOperatorName("!=")),
91              hasEitherOperand(ignoringParenImpCasts(
92                  declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName)))
93                      .bind(RefName))),
94              hasEitherOperand(
95                  ignoringParenImpCasts(integerLiteral().bind("boundNum"))))
96       .bind("conditionOperator");
97 }
98 
99 static internal::Matcher<Stmt>
changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher)100 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
101   return anyOf(
102       unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
103                     hasUnaryOperand(ignoringParenImpCasts(
104                         declRefExpr(to(varDecl(VarNodeMatcher)))))),
105       binaryOperator(isAssignmentOperator(),
106                      hasLHS(ignoringParenImpCasts(
107                          declRefExpr(to(varDecl(VarNodeMatcher)))))));
108 }
109 
110 static internal::Matcher<Stmt>
callByRef(internal::Matcher<Decl> VarNodeMatcher)111 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
112   return callExpr(forEachArgumentWithParam(
113       declRefExpr(to(varDecl(VarNodeMatcher))),
114       parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
115 }
116 
117 static internal::Matcher<Stmt>
assignedToRef(internal::Matcher<Decl> VarNodeMatcher)118 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
119   return declStmt(hasDescendant(varDecl(
120       allOf(hasType(referenceType()),
121             hasInitializer(anyOf(
122                 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
123                 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
124 }
125 
126 static internal::Matcher<Stmt>
getAddrTo(internal::Matcher<Decl> VarNodeMatcher)127 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
128   return unaryOperator(
129       hasOperatorName("&"),
130       hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
131 }
132 
hasSuspiciousStmt(StringRef NodeName)133 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
134   return hasDescendant(stmt(
135       anyOf(gotoStmt(), switchStmt(), returnStmt(),
136             // Escaping and not known mutation of the loop counter is handled
137             // by exclusion of assigning and address-of operators and
138             // pass-by-ref function calls on the loop counter from the body.
139             changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
140             callByRef(equalsBoundNode(std::string(NodeName))),
141             getAddrTo(equalsBoundNode(std::string(NodeName))),
142             assignedToRef(equalsBoundNode(std::string(NodeName))))));
143 }
144 
forLoopMatcher()145 static internal::Matcher<Stmt> forLoopMatcher() {
146   return forStmt(
147              hasCondition(simpleCondition("initVarName", "initVarRef")),
148              // Initialization should match the form: 'int i = 6' or 'i = 42'.
149              hasLoopInit(
150                  anyOf(declStmt(hasSingleDecl(
151                            varDecl(allOf(hasInitializer(ignoringParenImpCasts(
152                                              integerLiteral().bind("initNum"))),
153                                          equalsBoundNode("initVarName"))))),
154                        binaryOperator(hasLHS(declRefExpr(to(varDecl(
155                                           equalsBoundNode("initVarName"))))),
156                                       hasRHS(ignoringParenImpCasts(
157                                           integerLiteral().bind("initNum")))))),
158              // Incrementation should be a simple increment or decrement
159              // operator call.
160              hasIncrement(unaryOperator(
161                  anyOf(hasOperatorName("++"), hasOperatorName("--")),
162                  hasUnaryOperand(declRefExpr(
163                      to(varDecl(allOf(equalsBoundNode("initVarName"),
164                                       hasType(isInteger())))))))),
165              unless(hasBody(hasSuspiciousStmt("initVarName"))))
166       .bind("forLoop");
167 }
168 
isCapturedByReference(ExplodedNode * N,const DeclRefExpr * DR)169 static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
170 
171   // Get the lambda CXXRecordDecl
172   assert(DR->refersToEnclosingVariableOrCapture());
173   const LocationContext *LocCtxt = N->getLocationContext();
174   const Decl *D = LocCtxt->getDecl();
175   const auto *MD = cast<CXXMethodDecl>(D);
176   assert(MD && MD->getParent()->isLambda() &&
177          "Captured variable should only be seen while evaluating a lambda");
178   const CXXRecordDecl *LambdaCXXRec = MD->getParent();
179 
180   // Lookup the fields of the lambda
181   llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
182   FieldDecl *LambdaThisCaptureField;
183   LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField);
184 
185   // Check if the counter is captured by reference
186   const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
187   assert(VD);
188   const FieldDecl *FD = LambdaCaptureFields[VD];
189   assert(FD && "Captured variable without a corresponding field");
190   return FD->getType()->isReferenceType();
191 }
192 
isFoundInStmt(const Stmt * S,const VarDecl * VD)193 static bool isFoundInStmt(const Stmt *S, const VarDecl *VD) {
194   if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
195     for (const Decl *D : DS->decls()) {
196       // Once we reach the declaration of the VD we can return.
197       if (D->getCanonicalDecl() == VD)
198         return true;
199     }
200   }
201   return false;
202 }
203 
204 // A loop counter is considered escaped if:
205 // case 1: It is a global variable.
206 // case 2: It is a reference parameter or a reference capture.
207 // case 3: It is assigned to a non-const reference variable or parameter.
208 // case 4: Has its address taken.
isPossiblyEscaped(ExplodedNode * N,const DeclRefExpr * DR)209 static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
210   const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
211   assert(VD);
212   // Case 1:
213   if (VD->hasGlobalStorage())
214     return true;
215 
216   const bool IsRefParamOrCapture =
217       isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture();
218   // Case 2:
219   if ((DR->refersToEnclosingVariableOrCapture() &&
220        isCapturedByReference(N, DR)) ||
221       (IsRefParamOrCapture && VD->getType()->isReferenceType()))
222     return true;
223 
224   while (!N->pred_empty()) {
225     // FIXME: getStmtForDiagnostics() does nasty things in order to provide
226     // a valid statement for body farms, do we need this behavior here?
227     const Stmt *S = N->getStmtForDiagnostics();
228     if (!S) {
229       N = N->getFirstPred();
230       continue;
231     }
232 
233     if (isFoundInStmt(S, VD)) {
234       return false;
235     }
236 
237     if (const auto *SS = dyn_cast<SwitchStmt>(S)) {
238       if (const auto *CST = dyn_cast<CompoundStmt>(SS->getBody())) {
239         for (const Stmt *CB : CST->body()) {
240           if (isFoundInStmt(CB, VD))
241             return false;
242         }
243       }
244     }
245 
246     // Check the usage of the pass-by-ref function calls and adress-of operator
247     // on VD and reference initialized by VD.
248     ASTContext &ASTCtx =
249         N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
250     // Case 3 and 4:
251     auto Match =
252         match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
253                          assignedToRef(equalsNode(VD)))),
254               *S, ASTCtx);
255     if (!Match.empty())
256       return true;
257 
258     N = N->getFirstPred();
259   }
260 
261   // Reference parameter and reference capture will not be found.
262   if (IsRefParamOrCapture)
263     return false;
264 
265   llvm_unreachable("Reached root without finding the declaration of VD");
266 }
267 
shouldCompletelyUnroll(const Stmt * LoopStmt,ASTContext & ASTCtx,ExplodedNode * Pred,unsigned & maxStep)268 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
269                             ExplodedNode *Pred, unsigned &maxStep) {
270 
271   if (!isLoopStmt(LoopStmt))
272     return false;
273 
274   // TODO: Match the cases where the bound is not a concrete literal but an
275   // integer with known value
276   auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
277   if (Matches.empty())
278     return false;
279 
280   const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef");
281   llvm::APInt BoundNum =
282       Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
283   llvm::APInt InitNum =
284       Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
285   auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
286   if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
287     InitNum = InitNum.zext(BoundNum.getBitWidth());
288     BoundNum = BoundNum.zext(InitNum.getBitWidth());
289   }
290 
291   if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
292     maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
293   else
294     maxStep = (BoundNum - InitNum).abs().getZExtValue();
295 
296   // Check if the counter of the loop is not escaped before.
297   return !isPossiblyEscaped(Pred, CounterVarRef);
298 }
299 
madeNewBranch(ExplodedNode * N,const Stmt * LoopStmt)300 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
301   const Stmt *S = nullptr;
302   while (!N->pred_empty()) {
303     if (N->succ_size() > 1)
304       return true;
305 
306     ProgramPoint P = N->getLocation();
307     if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
308       S = BE->getBlock()->getTerminatorStmt();
309 
310     if (S == LoopStmt)
311       return false;
312 
313     N = N->getFirstPred();
314   }
315 
316   llvm_unreachable("Reached root without encountering the previous step");
317 }
318 
319 // updateLoopStack is called on every basic block, therefore it needs to be fast
updateLoopStack(const Stmt * LoopStmt,ASTContext & ASTCtx,ExplodedNode * Pred,unsigned maxVisitOnPath)320 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
321                                 ExplodedNode *Pred, unsigned maxVisitOnPath) {
322   auto State = Pred->getState();
323   auto LCtx = Pred->getLocationContext();
324 
325   if (!isLoopStmt(LoopStmt))
326     return State;
327 
328   auto LS = State->get<LoopStack>();
329   if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
330       LCtx == LS.getHead().getLocationContext()) {
331     if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
332       State = State->set<LoopStack>(LS.getTail());
333       State = State->add<LoopStack>(
334           LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
335     }
336     return State;
337   }
338   unsigned maxStep;
339   if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
340     State = State->add<LoopStack>(
341         LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
342     return State;
343   }
344 
345   unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
346 
347   unsigned innerMaxStep = maxStep * outerStep;
348   if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
349     State = State->add<LoopStack>(
350         LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
351   else
352     State = State->add<LoopStack>(
353         LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
354   return State;
355 }
356 
isUnrolledState(ProgramStateRef State)357 bool isUnrolledState(ProgramStateRef State) {
358   auto LS = State->get<LoopStack>();
359   if (LS.isEmpty() || !LS.getHead().isUnrolled())
360     return false;
361   return true;
362 }
363 }
364 }
365