xref: /freebsd/contrib/llvm-project/clang/lib/Sema/HLSLExternalSemaSource.cpp (revision 7ef62cebc2f965b0f640263e179276928885e33d)
1 //===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #include "clang/Sema/HLSLExternalSemaSource.h"
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Attr.h"
15 #include "clang/AST/DeclCXX.h"
16 #include "clang/Basic/AttrKinds.h"
17 #include "clang/Basic/HLSLRuntime.h"
18 #include "clang/Sema/Lookup.h"
19 #include "clang/Sema/Sema.h"
20 #include "llvm/Frontend/HLSL/HLSLResource.h"
21 
22 #include <functional>
23 
24 using namespace clang;
25 using namespace llvm::hlsl;
26 
27 namespace {
28 
29 struct TemplateParameterListBuilder;
30 
31 struct BuiltinTypeDeclBuilder {
32   CXXRecordDecl *Record = nullptr;
33   ClassTemplateDecl *Template = nullptr;
34   ClassTemplateDecl *PrevTemplate = nullptr;
35   NamespaceDecl *HLSLNamespace = nullptr;
36   llvm::StringMap<FieldDecl *> Fields;
37 
38   BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) {
39     Record->startDefinition();
40     Template = Record->getDescribedClassTemplate();
41   }
42 
43   BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name)
44       : HLSLNamespace(Namespace) {
45     ASTContext &AST = S.getASTContext();
46     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
47 
48     LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName);
49     CXXRecordDecl *PrevDecl = nullptr;
50     if (S.LookupQualifiedName(Result, HLSLNamespace)) {
51       NamedDecl *Found = Result.getFoundDecl();
52       if (auto *TD = dyn_cast<ClassTemplateDecl>(Found)) {
53         PrevDecl = TD->getTemplatedDecl();
54         PrevTemplate = TD;
55       } else
56         PrevDecl = dyn_cast<CXXRecordDecl>(Found);
57       assert(PrevDecl && "Unexpected lookup result type.");
58     }
59 
60     if (PrevDecl && PrevDecl->isCompleteDefinition()) {
61       Record = PrevDecl;
62       return;
63     }
64 
65     Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class,
66                                    HLSLNamespace, SourceLocation(),
67                                    SourceLocation(), &II, PrevDecl, true);
68     Record->setImplicit(true);
69     Record->setLexicalDeclContext(HLSLNamespace);
70     Record->setHasExternalLexicalStorage();
71 
72     // Don't let anyone derive from built-in types.
73     Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(),
74                                               AttributeCommonInfo::AS_Keyword,
75                                               FinalAttr::Keyword_final));
76   }
77 
78   ~BuiltinTypeDeclBuilder() {
79     if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace)
80       HLSLNamespace->addDecl(Record);
81   }
82 
83   BuiltinTypeDeclBuilder &
84   addMemberVariable(StringRef Name, QualType Type,
85                     AccessSpecifier Access = AccessSpecifier::AS_private) {
86     if (Record->isCompleteDefinition())
87       return *this;
88     assert(Record->isBeingDefined() &&
89            "Definition must be started before adding members!");
90     ASTContext &AST = Record->getASTContext();
91 
92     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
93     TypeSourceInfo *MemTySource =
94         AST.getTrivialTypeSourceInfo(Type, SourceLocation());
95     auto *Field = FieldDecl::Create(
96         AST, Record, SourceLocation(), SourceLocation(), &II, Type, MemTySource,
97         nullptr, false, InClassInitStyle::ICIS_NoInit);
98     Field->setAccess(Access);
99     Field->setImplicit(true);
100     Record->addDecl(Field);
101     Fields[Name] = Field;
102     return *this;
103   }
104 
105   BuiltinTypeDeclBuilder &
106   addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
107     if (Record->isCompleteDefinition())
108       return *this;
109     QualType Ty = Record->getASTContext().VoidPtrTy;
110     if (Template) {
111       if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
112               Template->getTemplateParameters()->getParam(0)))
113         Ty = Record->getASTContext().getPointerType(
114             QualType(TTD->getTypeForDecl(), 0));
115     }
116     return addMemberVariable("h", Ty, Access);
117   }
118 
119   BuiltinTypeDeclBuilder &
120   annotateResourceClass(HLSLResourceAttr::ResourceClass RC,
121                         HLSLResourceAttr::ResourceKind RK) {
122     if (Record->isCompleteDefinition())
123       return *this;
124     Record->addAttr(
125         HLSLResourceAttr::CreateImplicit(Record->getASTContext(), RC, RK));
126     return *this;
127   }
128 
129   static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
130                                             StringRef Name) {
131     CXXScopeSpec SS;
132     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
133     DeclarationNameInfo NameInfo =
134         DeclarationNameInfo(DeclarationName(&II), SourceLocation());
135     LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
136     S.LookupParsedName(R, S.getCurScope(), &SS, false);
137     assert(R.isSingleResult() &&
138            "Since this is a builtin it should always resolve!");
139     auto *VD = cast<ValueDecl>(R.getFoundDecl());
140     QualType Ty = VD->getType();
141     return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
142                                VD, false, NameInfo, Ty, VK_PRValue);
143   }
144 
145   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
146     return IntegerLiteral::Create(
147         AST,
148         llvm::APInt(AST.getIntWidth(AST.UnsignedCharTy),
149                     static_cast<uint8_t>(RC)),
150         AST.UnsignedCharTy, SourceLocation());
151   }
152 
153   BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S,
154                                                       ResourceClass RC) {
155     if (Record->isCompleteDefinition())
156       return *this;
157     ASTContext &AST = Record->getASTContext();
158 
159     QualType ConstructorType =
160         AST.getFunctionType(AST.VoidTy, {}, FunctionProtoType::ExtProtoInfo());
161 
162     CanQualType CanTy = Record->getTypeForDecl()->getCanonicalTypeUnqualified();
163     DeclarationName Name = AST.DeclarationNames.getCXXConstructorName(CanTy);
164     CXXConstructorDecl *Constructor = CXXConstructorDecl::Create(
165         AST, Record, SourceLocation(),
166         DeclarationNameInfo(Name, SourceLocation()), ConstructorType,
167         AST.getTrivialTypeSourceInfo(ConstructorType, SourceLocation()),
168         ExplicitSpecifier(), false, true, false,
169         ConstexprSpecKind::Unspecified);
170 
171     DeclRefExpr *Fn =
172         lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
173 
174     Expr *RCExpr = emitResourceClassExpr(AST, RC);
175     Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
176                                   SourceLocation(), FPOptionsOverride());
177 
178     CXXThisExpr *This = new (AST) CXXThisExpr(
179         SourceLocation(),
180         Constructor->getThisType().getTypePtr()->getPointeeType(), true);
181     This->setValueKind(ExprValueKind::VK_LValue);
182     Expr *Handle = MemberExpr::CreateImplicit(AST, This, false, Fields["h"],
183                                               Fields["h"]->getType(), VK_LValue,
184                                               OK_Ordinary);
185 
186     // If the handle isn't a void pointer, cast the builtin result to the
187     // correct type.
188     if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
189       Call = CXXStaticCastExpr::Create(
190           AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
191           AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
192           FPOptionsOverride(), SourceLocation(), SourceLocation(),
193           SourceRange());
194     }
195 
196     BinaryOperator *Assign = BinaryOperator::Create(
197         AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
198         SourceLocation(), FPOptionsOverride());
199 
200     Constructor->setBody(
201         CompoundStmt::Create(AST, {Assign}, FPOptionsOverride(),
202                              SourceLocation(), SourceLocation()));
203     Constructor->setAccess(AccessSpecifier::AS_public);
204     Record->addDecl(Constructor);
205     return *this;
206   }
207 
208   BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
209     if (Record->isCompleteDefinition())
210       return *this;
211     addArraySubscriptOperator(true);
212     addArraySubscriptOperator(false);
213     return *this;
214   }
215 
216   BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
217     if (Record->isCompleteDefinition())
218       return *this;
219     assert(Fields.count("h") > 0 &&
220            "Subscript operator must be added after the handle.");
221 
222     FieldDecl *Handle = Fields["h"];
223     ASTContext &AST = Record->getASTContext();
224 
225     assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
226            "Not yet supported for void pointer handles.");
227 
228     QualType ElemTy =
229         QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
230     QualType ReturnTy = ElemTy;
231 
232     FunctionProtoType::ExtProtoInfo ExtInfo;
233 
234     // Subscript operators return references to elements, const makes the
235     // reference and method const so that the underlying data is not mutable.
236     ReturnTy = AST.getLValueReferenceType(ReturnTy);
237     if (IsConst) {
238       ExtInfo.TypeQuals.addConst();
239       ReturnTy.addConst();
240     }
241 
242     QualType MethodTy =
243         AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
244     auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
245     auto *MethodDecl = CXXMethodDecl::Create(
246         AST, Record, SourceLocation(),
247         DeclarationNameInfo(
248             AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
249             SourceLocation()),
250         MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
251         SourceLocation());
252 
253     IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
254     auto *IdxParam = ParmVarDecl::Create(
255         AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
256         &II, AST.UnsignedIntTy,
257         AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
258         SC_None, nullptr);
259     MethodDecl->setParams({IdxParam});
260 
261     // Also add the parameter to the function prototype.
262     auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
263     FnProtoLoc.setParam(0, IdxParam);
264 
265     auto *This = new (AST) CXXThisExpr(
266         SourceLocation(),
267         MethodDecl->getThisType().getTypePtr()->getPointeeType(), true);
268     This->setValueKind(ExprValueKind::VK_LValue);
269     auto *HandleAccess = MemberExpr::CreateImplicit(
270         AST, This, false, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
271 
272     auto *IndexExpr = DeclRefExpr::Create(
273         AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
274         DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
275         AST.UnsignedIntTy, VK_PRValue);
276 
277     auto *Array =
278         new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
279                                      OK_Ordinary, SourceLocation());
280 
281     auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
282 
283     MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
284                                              SourceLocation(),
285                                              SourceLocation()));
286     MethodDecl->setLexicalDeclContext(Record);
287     MethodDecl->setAccess(AccessSpecifier::AS_public);
288     MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
289         AST, SourceRange(), AttributeCommonInfo::AS_Keyword,
290         AlwaysInlineAttr::CXX11_clang_always_inline));
291     Record->addDecl(MethodDecl);
292 
293     return *this;
294   }
295 
296   BuiltinTypeDeclBuilder &startDefinition() {
297     if (Record->isCompleteDefinition())
298       return *this;
299     Record->startDefinition();
300     return *this;
301   }
302 
303   BuiltinTypeDeclBuilder &completeDefinition() {
304     if (Record->isCompleteDefinition())
305       return *this;
306     assert(Record->isBeingDefined() &&
307            "Definition must be started before completing it.");
308 
309     Record->completeDefinition();
310     return *this;
311   }
312 
313   TemplateParameterListBuilder addTemplateArgumentList();
314 };
315 
316 struct TemplateParameterListBuilder {
317   BuiltinTypeDeclBuilder &Builder;
318   ASTContext &AST;
319   llvm::SmallVector<NamedDecl *> Params;
320 
321   TemplateParameterListBuilder(BuiltinTypeDeclBuilder &RB)
322       : Builder(RB), AST(RB.Record->getASTContext()) {}
323 
324   ~TemplateParameterListBuilder() { finalizeTemplateArgs(); }
325 
326   TemplateParameterListBuilder &
327   addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) {
328     if (Builder.Record->isCompleteDefinition())
329       return *this;
330     unsigned Position = static_cast<unsigned>(Params.size());
331     auto *Decl = TemplateTypeParmDecl::Create(
332         AST, Builder.Record->getDeclContext(), SourceLocation(),
333         SourceLocation(), /* TemplateDepth */ 0, Position,
334         &AST.Idents.get(Name, tok::TokenKind::identifier), /* Typename */ false,
335         /* ParameterPack */ false);
336     if (!DefaultValue.isNull())
337       Decl->setDefaultArgument(AST.getTrivialTypeSourceInfo(DefaultValue));
338 
339     Params.emplace_back(Decl);
340     return *this;
341   }
342 
343   BuiltinTypeDeclBuilder &finalizeTemplateArgs() {
344     if (Params.empty())
345       return Builder;
346     auto *ParamList =
347         TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
348                                       Params, SourceLocation(), nullptr);
349     Builder.Template = ClassTemplateDecl::Create(
350         AST, Builder.Record->getDeclContext(), SourceLocation(),
351         DeclarationName(Builder.Record->getIdentifier()), ParamList,
352         Builder.Record);
353     Builder.Record->setDescribedClassTemplate(Builder.Template);
354     Builder.Template->setImplicit(true);
355     Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext());
356     // NOTE: setPreviousDecl before addDecl so new decl replace old decl when
357     // make visible.
358     Builder.Template->setPreviousDecl(Builder.PrevTemplate);
359     Builder.Record->getDeclContext()->addDecl(Builder.Template);
360     Params.clear();
361 
362     QualType T = Builder.Template->getInjectedClassNameSpecialization();
363     T = AST.getInjectedClassNameType(Builder.Record, T);
364 
365     return Builder;
366   }
367 };
368 
369 TemplateParameterListBuilder BuiltinTypeDeclBuilder::addTemplateArgumentList() {
370   return TemplateParameterListBuilder(*this);
371 }
372 } // namespace
373 
374 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
375 
376 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
377   SemaPtr = &S;
378   ASTContext &AST = SemaPtr->getASTContext();
379   // If the translation unit has external storage force external decls to load.
380   if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
381     (void)AST.getTranslationUnitDecl()->decls_begin();
382 
383   IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
384   LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
385   NamespaceDecl *PrevDecl = nullptr;
386   if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl()))
387     PrevDecl = Result.getAsSingle<NamespaceDecl>();
388   HLSLNamespace = NamespaceDecl::Create(
389       AST, AST.getTranslationUnitDecl(), /*Inline=*/false, SourceLocation(),
390       SourceLocation(), &HLSL, PrevDecl, /*Nested=*/false);
391   HLSLNamespace->setImplicit(true);
392   HLSLNamespace->setHasExternalLexicalStorage();
393   AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
394 
395   // Force external decls in the HLSL namespace to load from the PCH.
396   (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
397   defineTrivialHLSLTypes();
398   forwardDeclareHLSLTypes();
399 
400   // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
401   // built in types inside a namespace, but we are planning to change that in
402   // the near future. In order to be source compatible older versions of HLSL
403   // will need to implicitly use the hlsl namespace. For now in clang everything
404   // will get added to the namespace, and we can remove the using directive for
405   // future language versions to match HLSL's evolution.
406   auto *UsingDecl = UsingDirectiveDecl::Create(
407       AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
408       NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
409       AST.getTranslationUnitDecl());
410 
411   AST.getTranslationUnitDecl()->addDecl(UsingDecl);
412 }
413 
414 void HLSLExternalSemaSource::defineHLSLVectorAlias() {
415   ASTContext &AST = SemaPtr->getASTContext();
416 
417   llvm::SmallVector<NamedDecl *> TemplateParams;
418 
419   auto *TypeParam = TemplateTypeParmDecl::Create(
420       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
421       &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
422   TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
423 
424   TemplateParams.emplace_back(TypeParam);
425 
426   auto *SizeParam = NonTypeTemplateParmDecl::Create(
427       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
428       &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
429       false, AST.getTrivialTypeSourceInfo(AST.IntTy));
430   Expr *LiteralExpr =
431       IntegerLiteral::Create(AST, llvm::APInt(AST.getIntWidth(AST.IntTy), 4),
432                              AST.IntTy, SourceLocation());
433   SizeParam->setDefaultArgument(LiteralExpr);
434   TemplateParams.emplace_back(SizeParam);
435 
436   auto *ParamList =
437       TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
438                                     TemplateParams, SourceLocation(), nullptr);
439 
440   IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
441 
442   QualType AliasType = AST.getDependentSizedExtVectorType(
443       AST.getTemplateTypeParmType(0, 0, false, TypeParam),
444       DeclRefExpr::Create(
445           AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
446           DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
447           AST.IntTy, VK_LValue),
448       SourceLocation());
449 
450   auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
451                                        SourceLocation(), &II,
452                                        AST.getTrivialTypeSourceInfo(AliasType));
453   Record->setImplicit(true);
454 
455   auto *Template =
456       TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
457                                     Record->getIdentifier(), ParamList, Record);
458 
459   Record->setDescribedAliasTemplate(Template);
460   Template->setImplicit(true);
461   Template->setLexicalDeclContext(Record->getDeclContext());
462   HLSLNamespace->addDecl(Template);
463 }
464 
465 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
466   defineHLSLVectorAlias();
467 
468   ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource")
469                      .startDefinition()
470                      .addHandleMember(AccessSpecifier::AS_public)
471                      .completeDefinition()
472                      .Record;
473 }
474 
475 void HLSLExternalSemaSource::forwardDeclareHLSLTypes() {
476   CXXRecordDecl *Decl;
477   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
478              .addTemplateArgumentList()
479              .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy)
480              .finalizeTemplateArgs()
481              .Record;
482   if (!Decl->isCompleteDefinition())
483     Completions.insert(
484         std::make_pair(Decl->getCanonicalDecl(),
485                        std::bind(&HLSLExternalSemaSource::completeBufferType,
486                                  this, std::placeholders::_1)));
487 }
488 
489 void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
490   if (!isa<CXXRecordDecl>(Tag))
491     return;
492   auto Record = cast<CXXRecordDecl>(Tag);
493 
494   // If this is a specialization, we need to get the underlying templated
495   // declaration and complete that.
496   if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
497     Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
498   Record = Record->getCanonicalDecl();
499   auto It = Completions.find(Record);
500   if (It == Completions.end())
501     return;
502   It->second(Record);
503 }
504 
505 void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) {
506   BuiltinTypeDeclBuilder(Record)
507       .addHandleMember()
508       .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
509       .addArraySubscriptOperators()
510       .annotateResourceClass(HLSLResourceAttr::UAV,
511                              HLSLResourceAttr::TypedBuffer)
512       .completeDefinition();
513 }
514