xref: /freebsd/contrib/llvm-project/clang/lib/Sema/HLSLExternalSemaSource.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #include "clang/Sema/HLSLExternalSemaSource.h"
13 #include "HLSLBuiltinTypeDeclBuilder.h"
14 #include "clang/AST/ASTContext.h"
15 #include "clang/AST/Attr.h"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/Expr.h"
19 #include "clang/AST/Type.h"
20 #include "clang/Basic/SourceLocation.h"
21 #include "clang/Sema/Lookup.h"
22 #include "clang/Sema/Sema.h"
23 #include "clang/Sema/SemaHLSL.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 using namespace clang;
27 using namespace llvm::hlsl;
28 
29 using clang::hlsl::BuiltinTypeDeclBuilder;
30 
InitializeSema(Sema & S)31 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
32   SemaPtr = &S;
33   ASTContext &AST = SemaPtr->getASTContext();
34   // If the translation unit has external storage force external decls to load.
35   if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
36     (void)AST.getTranslationUnitDecl()->decls_begin();
37 
38   IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
39   LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
40   NamespaceDecl *PrevDecl = nullptr;
41   if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl()))
42     PrevDecl = Result.getAsSingle<NamespaceDecl>();
43   HLSLNamespace = NamespaceDecl::Create(
44       AST, AST.getTranslationUnitDecl(), /*Inline=*/false, SourceLocation(),
45       SourceLocation(), &HLSL, PrevDecl, /*Nested=*/false);
46   HLSLNamespace->setImplicit(true);
47   HLSLNamespace->setHasExternalLexicalStorage();
48   AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
49 
50   // Force external decls in the HLSL namespace to load from the PCH.
51   (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
52   defineTrivialHLSLTypes();
53   defineHLSLTypesWithForwardDeclarations();
54 
55   // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
56   // built in types inside a namespace, but we are planning to change that in
57   // the near future. In order to be source compatible older versions of HLSL
58   // will need to implicitly use the hlsl namespace. For now in clang everything
59   // will get added to the namespace, and we can remove the using directive for
60   // future language versions to match HLSL's evolution.
61   auto *UsingDecl = UsingDirectiveDecl::Create(
62       AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
63       NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
64       AST.getTranslationUnitDecl());
65 
66   AST.getTranslationUnitDecl()->addDecl(UsingDecl);
67 }
68 
defineHLSLVectorAlias()69 void HLSLExternalSemaSource::defineHLSLVectorAlias() {
70   ASTContext &AST = SemaPtr->getASTContext();
71 
72   llvm::SmallVector<NamedDecl *> TemplateParams;
73 
74   auto *TypeParam = TemplateTypeParmDecl::Create(
75       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
76       &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
77   TypeParam->setDefaultArgument(
78       AST, SemaPtr->getTrivialTemplateArgumentLoc(
79                TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
80 
81   TemplateParams.emplace_back(TypeParam);
82 
83   auto *SizeParam = NonTypeTemplateParmDecl::Create(
84       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
85       &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
86       false, AST.getTrivialTypeSourceInfo(AST.IntTy));
87   llvm::APInt Val(AST.getIntWidth(AST.IntTy), 4);
88   TemplateArgument Default(AST, llvm::APSInt(std::move(Val)), AST.IntTy,
89                            /*IsDefaulted=*/true);
90   SizeParam->setDefaultArgument(
91       AST, SemaPtr->getTrivialTemplateArgumentLoc(Default, AST.IntTy,
92                                                   SourceLocation(), SizeParam));
93   TemplateParams.emplace_back(SizeParam);
94 
95   auto *ParamList =
96       TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
97                                     TemplateParams, SourceLocation(), nullptr);
98 
99   IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
100 
101   QualType AliasType = AST.getDependentSizedExtVectorType(
102       AST.getTemplateTypeParmType(0, 0, false, TypeParam),
103       DeclRefExpr::Create(
104           AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
105           DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
106           AST.IntTy, VK_LValue),
107       SourceLocation());
108 
109   auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
110                                        SourceLocation(), &II,
111                                        AST.getTrivialTypeSourceInfo(AliasType));
112   Record->setImplicit(true);
113 
114   auto *Template =
115       TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
116                                     Record->getIdentifier(), ParamList, Record);
117 
118   Record->setDescribedAliasTemplate(Template);
119   Template->setImplicit(true);
120   Template->setLexicalDeclContext(Record->getDeclContext());
121   HLSLNamespace->addDecl(Template);
122 }
123 
defineTrivialHLSLTypes()124 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
125   defineHLSLVectorAlias();
126 }
127 
128 /// Set up common members and attributes for buffer types
setupBufferType(CXXRecordDecl * Decl,Sema & S,ResourceClass RC,bool IsROV,bool RawBuffer)129 static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
130                                               ResourceClass RC, bool IsROV,
131                                               bool RawBuffer) {
132   return BuiltinTypeDeclBuilder(S, Decl)
133       .addHandleMember(RC, IsROV, RawBuffer)
134       .addDefaultHandleConstructor()
135       .addHandleConstructorFromBinding()
136       .addHandleConstructorFromImplicitBinding();
137 }
138 
139 // This function is responsible for constructing the constraint expression for
140 // this concept:
141 // template<typename T> concept is_typed_resource_element_compatible =
142 // __is_typed_resource_element_compatible<T>;
constructTypedBufferConstraintExpr(Sema & S,SourceLocation NameLoc,TemplateTypeParmDecl * T)143 static Expr *constructTypedBufferConstraintExpr(Sema &S, SourceLocation NameLoc,
144                                                 TemplateTypeParmDecl *T) {
145   ASTContext &Context = S.getASTContext();
146 
147   // Obtain the QualType for 'bool'
148   QualType BoolTy = Context.BoolTy;
149 
150   // Create a QualType that points to this TemplateTypeParmDecl
151   QualType TType = Context.getTypeDeclType(T);
152 
153   // Create a TypeSourceInfo for the template type parameter 'T'
154   TypeSourceInfo *TTypeSourceInfo =
155       Context.getTrivialTypeSourceInfo(TType, NameLoc);
156 
157   TypeTraitExpr *TypedResExpr = TypeTraitExpr::Create(
158       Context, BoolTy, NameLoc, UTT_IsTypedResourceElementCompatible,
159       {TTypeSourceInfo}, NameLoc, true);
160 
161   return TypedResExpr;
162 }
163 
164 // This function is responsible for constructing the constraint expression for
165 // this concept:
166 // template<typename T> concept is_structured_resource_element_compatible =
167 // !__is_intangible<T> && sizeof(T) >= 1;
constructStructuredBufferConstraintExpr(Sema & S,SourceLocation NameLoc,TemplateTypeParmDecl * T)168 static Expr *constructStructuredBufferConstraintExpr(Sema &S,
169                                                      SourceLocation NameLoc,
170                                                      TemplateTypeParmDecl *T) {
171   ASTContext &Context = S.getASTContext();
172 
173   // Obtain the QualType for 'bool'
174   QualType BoolTy = Context.BoolTy;
175 
176   // Create a QualType that points to this TemplateTypeParmDecl
177   QualType TType = Context.getTypeDeclType(T);
178 
179   // Create a TypeSourceInfo for the template type parameter 'T'
180   TypeSourceInfo *TTypeSourceInfo =
181       Context.getTrivialTypeSourceInfo(TType, NameLoc);
182 
183   TypeTraitExpr *IsIntangibleExpr =
184       TypeTraitExpr::Create(Context, BoolTy, NameLoc, UTT_IsIntangibleType,
185                             {TTypeSourceInfo}, NameLoc, true);
186 
187   // negate IsIntangibleExpr
188   UnaryOperator *NotIntangibleExpr = UnaryOperator::Create(
189       Context, IsIntangibleExpr, UO_LNot, BoolTy, VK_LValue, OK_Ordinary,
190       NameLoc, false, FPOptionsOverride());
191 
192   // element types also may not be of 0 size
193   UnaryExprOrTypeTraitExpr *SizeOfExpr = new (Context) UnaryExprOrTypeTraitExpr(
194       UETT_SizeOf, TTypeSourceInfo, BoolTy, NameLoc, NameLoc);
195 
196   // Create a BinaryOperator that checks if the size of the type is not equal to
197   // 1 Empty structs have a size of 1 in HLSL, so we need to check for that
198   IntegerLiteral *rhs = IntegerLiteral::Create(
199       Context, llvm::APInt(Context.getTypeSize(Context.getSizeType()), 1, true),
200       Context.getSizeType(), NameLoc);
201 
202   BinaryOperator *SizeGEQOneExpr =
203       BinaryOperator::Create(Context, SizeOfExpr, rhs, BO_GE, BoolTy, VK_LValue,
204                              OK_Ordinary, NameLoc, FPOptionsOverride());
205 
206   // Combine the two constraints
207   BinaryOperator *CombinedExpr = BinaryOperator::Create(
208       Context, NotIntangibleExpr, SizeGEQOneExpr, BO_LAnd, BoolTy, VK_LValue,
209       OK_Ordinary, NameLoc, FPOptionsOverride());
210 
211   return CombinedExpr;
212 }
213 
constructBufferConceptDecl(Sema & S,NamespaceDecl * NSD,bool isTypedBuffer)214 static ConceptDecl *constructBufferConceptDecl(Sema &S, NamespaceDecl *NSD,
215                                                bool isTypedBuffer) {
216   ASTContext &Context = S.getASTContext();
217   DeclContext *DC = NSD->getDeclContext();
218   SourceLocation DeclLoc = SourceLocation();
219 
220   IdentifierInfo &ElementTypeII = Context.Idents.get("element_type");
221   TemplateTypeParmDecl *T = TemplateTypeParmDecl::Create(
222       Context, NSD->getDeclContext(), DeclLoc, DeclLoc,
223       /*D=*/0,
224       /*P=*/0,
225       /*Id=*/&ElementTypeII,
226       /*Typename=*/true,
227       /*ParameterPack=*/false);
228 
229   T->setDeclContext(DC);
230   T->setReferenced();
231 
232   // Create and Attach Template Parameter List to ConceptDecl
233   TemplateParameterList *ConceptParams = TemplateParameterList::Create(
234       Context, DeclLoc, DeclLoc, {T}, DeclLoc, nullptr);
235 
236   DeclarationName DeclName;
237   Expr *ConstraintExpr = nullptr;
238 
239   if (isTypedBuffer) {
240     DeclName = DeclarationName(
241         &Context.Idents.get("__is_typed_resource_element_compatible"));
242     ConstraintExpr = constructTypedBufferConstraintExpr(S, DeclLoc, T);
243   } else {
244     DeclName = DeclarationName(
245         &Context.Idents.get("__is_structured_resource_element_compatible"));
246     ConstraintExpr = constructStructuredBufferConstraintExpr(S, DeclLoc, T);
247   }
248 
249   // Create a ConceptDecl
250   ConceptDecl *CD =
251       ConceptDecl::Create(Context, NSD->getDeclContext(), DeclLoc, DeclName,
252                           ConceptParams, ConstraintExpr);
253 
254   // Attach the template parameter list to the ConceptDecl
255   CD->setTemplateParameters(ConceptParams);
256 
257   // Add the concept declaration to the Translation Unit Decl
258   NSD->getDeclContext()->addDecl(CD);
259 
260   return CD;
261 }
262 
defineHLSLTypesWithForwardDeclarations()263 void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
264   CXXRecordDecl *Decl;
265   ConceptDecl *TypedBufferConcept = constructBufferConceptDecl(
266       *SemaPtr, HLSLNamespace, /*isTypedBuffer*/ true);
267   ConceptDecl *StructuredBufferConcept = constructBufferConceptDecl(
268       *SemaPtr, HLSLNamespace, /*isTypedBuffer*/ false);
269 
270   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Buffer")
271              .addSimpleTemplateParams({"element_type"}, TypedBufferConcept)
272              .finalizeForwardDeclaration();
273 
274   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
275     setupBufferType(Decl, *SemaPtr, ResourceClass::SRV, /*IsROV=*/false,
276                     /*RawBuffer=*/false)
277         .addArraySubscriptOperators()
278         .addLoadMethods()
279         .completeDefinition();
280   });
281 
282   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
283              .addSimpleTemplateParams({"element_type"}, TypedBufferConcept)
284              .finalizeForwardDeclaration();
285 
286   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
287     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/false,
288                     /*RawBuffer=*/false)
289         .addArraySubscriptOperators()
290         .addLoadMethods()
291         .completeDefinition();
292   });
293 
294   Decl =
295       BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
296           .addSimpleTemplateParams({"element_type"}, StructuredBufferConcept)
297           .finalizeForwardDeclaration();
298   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
299     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/true,
300                     /*RawBuffer=*/false)
301         .addArraySubscriptOperators()
302         .addLoadMethods()
303         .completeDefinition();
304   });
305 
306   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "StructuredBuffer")
307              .addSimpleTemplateParams({"element_type"}, StructuredBufferConcept)
308              .finalizeForwardDeclaration();
309   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
310     setupBufferType(Decl, *SemaPtr, ResourceClass::SRV, /*IsROV=*/false,
311                     /*RawBuffer=*/true)
312         .addArraySubscriptOperators()
313         .addLoadMethods()
314         .completeDefinition();
315   });
316 
317   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWStructuredBuffer")
318              .addSimpleTemplateParams({"element_type"}, StructuredBufferConcept)
319              .finalizeForwardDeclaration();
320   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
321     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/false,
322                     /*RawBuffer=*/true)
323         .addArraySubscriptOperators()
324         .addLoadMethods()
325         .addIncrementCounterMethod()
326         .addDecrementCounterMethod()
327         .completeDefinition();
328   });
329 
330   Decl =
331       BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "AppendStructuredBuffer")
332           .addSimpleTemplateParams({"element_type"}, StructuredBufferConcept)
333           .finalizeForwardDeclaration();
334   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
335     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/false,
336                     /*RawBuffer=*/true)
337         .addAppendMethod()
338         .completeDefinition();
339   });
340 
341   Decl =
342       BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ConsumeStructuredBuffer")
343           .addSimpleTemplateParams({"element_type"}, StructuredBufferConcept)
344           .finalizeForwardDeclaration();
345   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
346     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/false,
347                     /*RawBuffer=*/true)
348         .addConsumeMethod()
349         .completeDefinition();
350   });
351 
352   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace,
353                                 "RasterizerOrderedStructuredBuffer")
354              .addSimpleTemplateParams({"element_type"}, StructuredBufferConcept)
355              .finalizeForwardDeclaration();
356   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
357     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/true,
358                     /*RawBuffer=*/true)
359         .addArraySubscriptOperators()
360         .addLoadMethods()
361         .addIncrementCounterMethod()
362         .addDecrementCounterMethod()
363         .completeDefinition();
364   });
365 
366   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "ByteAddressBuffer")
367              .finalizeForwardDeclaration();
368   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
369     setupBufferType(Decl, *SemaPtr, ResourceClass::SRV, /*IsROV=*/false,
370                     /*RawBuffer=*/true)
371         .completeDefinition();
372   });
373   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWByteAddressBuffer")
374              .finalizeForwardDeclaration();
375   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
376     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/false,
377                     /*RawBuffer=*/true)
378         .completeDefinition();
379   });
380   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace,
381                                 "RasterizerOrderedByteAddressBuffer")
382              .finalizeForwardDeclaration();
383   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
384     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, /*IsROV=*/true,
385                     /*RawBuffer=*/true)
386         .completeDefinition();
387   });
388 }
389 
onCompletion(CXXRecordDecl * Record,CompletionFunction Fn)390 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
391                                           CompletionFunction Fn) {
392   if (!Record->isCompleteDefinition())
393     Completions.insert(std::make_pair(Record->getCanonicalDecl(), Fn));
394 }
395 
CompleteType(TagDecl * Tag)396 void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
397   if (!isa<CXXRecordDecl>(Tag))
398     return;
399   auto Record = cast<CXXRecordDecl>(Tag);
400 
401   // If this is a specialization, we need to get the underlying templated
402   // declaration and complete that.
403   if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
404     Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
405   Record = Record->getCanonicalDecl();
406   auto It = Completions.find(Record);
407   if (It == Completions.end())
408     return;
409   It->second(Record);
410 }
411