xref: /freebsd/contrib/llvm-project/clang/lib/Sema/SemaSYCL.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for SYCL constructs.
9 //===----------------------------------------------------------------------===//
10 
11 #include "clang/Sema/SemaSYCL.h"
12 #include "clang/AST/Mangle.h"
13 #include "clang/Sema/Attr.h"
14 #include "clang/Sema/ParsedAttr.h"
15 #include "clang/Sema/Sema.h"
16 #include "clang/Sema/SemaDiagnostic.h"
17 
18 using namespace clang;
19 
20 // -----------------------------------------------------------------------------
21 // SYCL device specific diagnostics implementation
22 // -----------------------------------------------------------------------------
23 
SemaSYCL(Sema & S)24 SemaSYCL::SemaSYCL(Sema &S) : SemaBase(S) {}
25 
DiagIfDeviceCode(SourceLocation Loc,unsigned DiagID)26 Sema::SemaDiagnosticBuilder SemaSYCL::DiagIfDeviceCode(SourceLocation Loc,
27                                                        unsigned DiagID) {
28   assert(getLangOpts().SYCLIsDevice &&
29          "Should only be called during SYCL compilation");
30   FunctionDecl *FD = dyn_cast<FunctionDecl>(SemaRef.getCurLexicalContext());
31   SemaDiagnosticBuilder::Kind DiagKind = [this, FD] {
32     if (!FD)
33       return SemaDiagnosticBuilder::K_Nop;
34     if (SemaRef.getEmissionStatus(FD) == Sema::FunctionEmissionStatus::Emitted)
35       return SemaDiagnosticBuilder::K_ImmediateWithCallStack;
36     return SemaDiagnosticBuilder::K_Deferred;
37   }();
38   return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, FD, SemaRef);
39 }
40 
isZeroSizedArray(SemaSYCL & S,QualType Ty)41 static bool isZeroSizedArray(SemaSYCL &S, QualType Ty) {
42   if (const auto *CAT = S.getASTContext().getAsConstantArrayType(Ty))
43     return CAT->isZeroSize();
44   return false;
45 }
46 
deepTypeCheckForDevice(SourceLocation UsedAt,llvm::DenseSet<QualType> Visited,ValueDecl * DeclToCheck)47 void SemaSYCL::deepTypeCheckForDevice(SourceLocation UsedAt,
48                                       llvm::DenseSet<QualType> Visited,
49                                       ValueDecl *DeclToCheck) {
50   assert(getLangOpts().SYCLIsDevice &&
51          "Should only be called during SYCL compilation");
52   // Emit notes only for the first discovered declaration of unsupported type
53   // to avoid mess of notes. This flag is to track that error already happened.
54   bool NeedToEmitNotes = true;
55 
56   auto Check = [&](QualType TypeToCheck, const ValueDecl *D) {
57     bool ErrorFound = false;
58     if (isZeroSizedArray(*this, TypeToCheck)) {
59       DiagIfDeviceCode(UsedAt, diag::err_typecheck_zero_array_size) << 1;
60       ErrorFound = true;
61     }
62     // Checks for other types can also be done here.
63     if (ErrorFound) {
64       if (NeedToEmitNotes) {
65         if (auto *FD = dyn_cast<FieldDecl>(D))
66           DiagIfDeviceCode(FD->getLocation(),
67                            diag::note_illegal_field_declared_here)
68               << FD->getType()->isPointerType() << FD->getType();
69         else
70           DiagIfDeviceCode(D->getLocation(), diag::note_declared_at);
71       }
72     }
73 
74     return ErrorFound;
75   };
76 
77   // In case we have a Record used do the DFS for a bad field.
78   SmallVector<const ValueDecl *, 4> StackForRecursion;
79   StackForRecursion.push_back(DeclToCheck);
80 
81   // While doing DFS save how we get there to emit a nice set of notes.
82   SmallVector<const FieldDecl *, 4> History;
83   History.push_back(nullptr);
84 
85   do {
86     const ValueDecl *Next = StackForRecursion.pop_back_val();
87     if (!Next) {
88       assert(!History.empty());
89       // Found a marker, we have gone up a level.
90       History.pop_back();
91       continue;
92     }
93     QualType NextTy = Next->getType();
94 
95     if (!Visited.insert(NextTy).second)
96       continue;
97 
98     auto EmitHistory = [&]() {
99       // The first element is always nullptr.
100       for (uint64_t Index = 1; Index < History.size(); ++Index) {
101         DiagIfDeviceCode(History[Index]->getLocation(),
102                          diag::note_within_field_of_type)
103             << History[Index]->getType();
104       }
105     };
106 
107     if (Check(NextTy, Next)) {
108       if (NeedToEmitNotes)
109         EmitHistory();
110       NeedToEmitNotes = false;
111     }
112 
113     // In case pointer/array/reference type is met get pointee type, then
114     // proceed with that type.
115     while (NextTy->isAnyPointerType() || NextTy->isArrayType() ||
116            NextTy->isReferenceType()) {
117       if (NextTy->isArrayType())
118         NextTy = QualType{NextTy->getArrayElementTypeNoTypeQual(), 0};
119       else
120         NextTy = NextTy->getPointeeType();
121       if (Check(NextTy, Next)) {
122         if (NeedToEmitNotes)
123           EmitHistory();
124         NeedToEmitNotes = false;
125       }
126     }
127 
128     if (const auto *RecDecl = NextTy->getAsRecordDecl()) {
129       if (auto *NextFD = dyn_cast<FieldDecl>(Next))
130         History.push_back(NextFD);
131       // When nullptr is discovered, this means we've gone back up a level, so
132       // the history should be cleaned.
133       StackForRecursion.push_back(nullptr);
134       llvm::copy(RecDecl->fields(), std::back_inserter(StackForRecursion));
135     }
136   } while (!StackForRecursion.empty());
137 }
138 
BuildUniqueStableNameExpr(SourceLocation OpLoc,SourceLocation LParen,SourceLocation RParen,TypeSourceInfo * TSI)139 ExprResult SemaSYCL::BuildUniqueStableNameExpr(SourceLocation OpLoc,
140                                                SourceLocation LParen,
141                                                SourceLocation RParen,
142                                                TypeSourceInfo *TSI) {
143   return SYCLUniqueStableNameExpr::Create(getASTContext(), OpLoc, LParen,
144                                           RParen, TSI);
145 }
146 
ActOnUniqueStableNameExpr(SourceLocation OpLoc,SourceLocation LParen,SourceLocation RParen,ParsedType ParsedTy)147 ExprResult SemaSYCL::ActOnUniqueStableNameExpr(SourceLocation OpLoc,
148                                                SourceLocation LParen,
149                                                SourceLocation RParen,
150                                                ParsedType ParsedTy) {
151   TypeSourceInfo *TSI = nullptr;
152   QualType Ty = SemaRef.GetTypeFromParser(ParsedTy, &TSI);
153 
154   if (Ty.isNull())
155     return ExprError();
156   if (!TSI)
157     TSI = getASTContext().getTrivialTypeSourceInfo(Ty, LParen);
158 
159   return BuildUniqueStableNameExpr(OpLoc, LParen, RParen, TSI);
160 }
161 
handleKernelAttr(Decl * D,const ParsedAttr & AL)162 void SemaSYCL::handleKernelAttr(Decl *D, const ParsedAttr &AL) {
163   // The 'sycl_kernel' attribute applies only to function templates.
164   const auto *FD = cast<FunctionDecl>(D);
165   const FunctionTemplateDecl *FT = FD->getDescribedFunctionTemplate();
166   assert(FT && "Function template is expected");
167 
168   // Function template must have at least two template parameters.
169   const TemplateParameterList *TL = FT->getTemplateParameters();
170   if (TL->size() < 2) {
171     Diag(FT->getLocation(), diag::warn_sycl_kernel_num_of_template_params);
172     return;
173   }
174 
175   // Template parameters must be typenames.
176   for (unsigned I = 0; I < 2; ++I) {
177     const NamedDecl *TParam = TL->getParam(I);
178     if (isa<NonTypeTemplateParmDecl>(TParam)) {
179       Diag(FT->getLocation(),
180            diag::warn_sycl_kernel_invalid_template_param_type);
181       return;
182     }
183   }
184 
185   // Function must have at least one argument.
186   if (getFunctionOrMethodNumParams(D) != 1) {
187     Diag(FT->getLocation(), diag::warn_sycl_kernel_num_of_function_params);
188     return;
189   }
190 
191   // Function must return void.
192   QualType RetTy = getFunctionOrMethodResultType(D);
193   if (!RetTy->isVoidType()) {
194     Diag(FT->getLocation(), diag::warn_sycl_kernel_return_type);
195     return;
196   }
197 
198   handleSimpleAttribute<SYCLKernelAttr>(*this, D, AL);
199 }
200