xref: /freebsd/contrib/llvm-project/clang/lib/Sema/HLSLExternalSemaSource.cpp (revision b1879975794772ee51f0b4865753364c7d7626c3)
1 //===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
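//
// This file implements HLSLExternalSemaSource, which adds HLSL's built-in
// declarations to the AST: the implicit `hlsl` namespace, the `vector` alias
// template, and resource class templates such as RWBuffer whose definitions
// are completed lazily, when Sema asks for the type to be completed.
//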
10 //===----------------------------------------------------------------------===//
11 
#include "clang/Sema/HLSLExternalSemaSource.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclCXX.h"
#include "clang/Basic/AttrKinds.h"
#include "clang/Basic/HLSLRuntime.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Sema.h"
#include "llvm/Frontend/HLSL/HLSLResource.h"

#include <functional>
#include <utility>
23 
24 using namespace clang;
25 using namespace llvm::hlsl;
26 
27 namespace {
28 
29 struct TemplateParameterListBuilder;
30 
31 struct BuiltinTypeDeclBuilder {
32   CXXRecordDecl *Record = nullptr;
33   ClassTemplateDecl *Template = nullptr;
34   ClassTemplateDecl *PrevTemplate = nullptr;
35   NamespaceDecl *HLSLNamespace = nullptr;
36   llvm::StringMap<FieldDecl *> Fields;
37 
38   BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) {
39     Record->startDefinition();
40     Template = Record->getDescribedClassTemplate();
41   }
42 
43   BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name)
44       : HLSLNamespace(Namespace) {
45     ASTContext &AST = S.getASTContext();
46     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
47 
48     LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName);
49     CXXRecordDecl *PrevDecl = nullptr;
50     if (S.LookupQualifiedName(Result, HLSLNamespace)) {
51       NamedDecl *Found = Result.getFoundDecl();
52       if (auto *TD = dyn_cast<ClassTemplateDecl>(Found)) {
53         PrevDecl = TD->getTemplatedDecl();
54         PrevTemplate = TD;
55       } else
56         PrevDecl = dyn_cast<CXXRecordDecl>(Found);
57       assert(PrevDecl && "Unexpected lookup result type.");
58     }
59 
60     if (PrevDecl && PrevDecl->isCompleteDefinition()) {
61       Record = PrevDecl;
62       return;
63     }
64 
65     Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::Class, HLSLNamespace,
66                                    SourceLocation(), SourceLocation(), &II,
67                                    PrevDecl, true);
68     Record->setImplicit(true);
69     Record->setLexicalDeclContext(HLSLNamespace);
70     Record->setHasExternalLexicalStorage();
71 
72     // Don't let anyone derive from built-in types.
73     Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(),
74                                               FinalAttr::Keyword_final));
75   }
76 
77   ~BuiltinTypeDeclBuilder() {
78     if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace)
79       HLSLNamespace->addDecl(Record);
80   }
81 
82   BuiltinTypeDeclBuilder &
83   addMemberVariable(StringRef Name, QualType Type,
84                     AccessSpecifier Access = AccessSpecifier::AS_private) {
85     if (Record->isCompleteDefinition())
86       return *this;
87     assert(Record->isBeingDefined() &&
88            "Definition must be started before adding members!");
89     ASTContext &AST = Record->getASTContext();
90 
91     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
92     TypeSourceInfo *MemTySource =
93         AST.getTrivialTypeSourceInfo(Type, SourceLocation());
94     auto *Field = FieldDecl::Create(
95         AST, Record, SourceLocation(), SourceLocation(), &II, Type, MemTySource,
96         nullptr, false, InClassInitStyle::ICIS_NoInit);
97     Field->setAccess(Access);
98     Field->setImplicit(true);
99     Record->addDecl(Field);
100     Fields[Name] = Field;
101     return *this;
102   }
103 
104   BuiltinTypeDeclBuilder &
105   addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
106     if (Record->isCompleteDefinition())
107       return *this;
108     QualType Ty = Record->getASTContext().VoidPtrTy;
109     if (Template) {
110       if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
111               Template->getTemplateParameters()->getParam(0)))
112         Ty = Record->getASTContext().getPointerType(
113             QualType(TTD->getTypeForDecl(), 0));
114     }
115     return addMemberVariable("h", Ty, Access);
116   }
117 
118   BuiltinTypeDeclBuilder &annotateHLSLResource(ResourceClass RC,
119                                                ResourceKind RK, bool IsROV) {
120     if (Record->isCompleteDefinition())
121       return *this;
122     Record->addAttr(
123         HLSLResourceClassAttr::CreateImplicit(Record->getASTContext(), RC));
124     Record->addAttr(
125         HLSLResourceAttr::CreateImplicit(Record->getASTContext(), RK, IsROV));
126     return *this;
127   }
128 
129   static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
130                                             StringRef Name) {
131     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
132     DeclarationNameInfo NameInfo =
133         DeclarationNameInfo(DeclarationName(&II), SourceLocation());
134     LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
135     // AllowBuiltinCreation is false but LookupDirect will create
136     // the builtin when searching the global scope anyways...
137     S.LookupName(R, S.getCurScope());
138     // FIXME: If the builtin function was user-declared in global scope,
139     // this assert *will* fail. Should this call LookupBuiltin instead?
140     assert(R.isSingleResult() &&
141            "Since this is a builtin it should always resolve!");
142     auto *VD = cast<ValueDecl>(R.getFoundDecl());
143     QualType Ty = VD->getType();
144     return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
145                                VD, false, NameInfo, Ty, VK_PRValue);
146   }
147 
148   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
149     return IntegerLiteral::Create(
150         AST,
151         llvm::APInt(AST.getIntWidth(AST.UnsignedCharTy),
152                     static_cast<uint8_t>(RC)),
153         AST.UnsignedCharTy, SourceLocation());
154   }
155 
156   BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S,
157                                                       ResourceClass RC) {
158     if (Record->isCompleteDefinition())
159       return *this;
160     ASTContext &AST = Record->getASTContext();
161 
162     QualType ConstructorType =
163         AST.getFunctionType(AST.VoidTy, {}, FunctionProtoType::ExtProtoInfo());
164 
165     CanQualType CanTy = Record->getTypeForDecl()->getCanonicalTypeUnqualified();
166     DeclarationName Name = AST.DeclarationNames.getCXXConstructorName(CanTy);
167     CXXConstructorDecl *Constructor = CXXConstructorDecl::Create(
168         AST, Record, SourceLocation(),
169         DeclarationNameInfo(Name, SourceLocation()), ConstructorType,
170         AST.getTrivialTypeSourceInfo(ConstructorType, SourceLocation()),
171         ExplicitSpecifier(), false, true, false,
172         ConstexprSpecKind::Unspecified);
173 
174     DeclRefExpr *Fn =
175         lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
176     Expr *RCExpr = emitResourceClassExpr(AST, RC);
177     Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
178                                   SourceLocation(), FPOptionsOverride());
179 
180     CXXThisExpr *This = CXXThisExpr::Create(
181         AST, SourceLocation(), Constructor->getFunctionObjectParameterType(),
182         true);
183     Expr *Handle = MemberExpr::CreateImplicit(AST, This, false, Fields["h"],
184                                               Fields["h"]->getType(), VK_LValue,
185                                               OK_Ordinary);
186 
187     // If the handle isn't a void pointer, cast the builtin result to the
188     // correct type.
189     if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
190       Call = CXXStaticCastExpr::Create(
191           AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
192           AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
193           FPOptionsOverride(), SourceLocation(), SourceLocation(),
194           SourceRange());
195     }
196 
197     BinaryOperator *Assign = BinaryOperator::Create(
198         AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
199         SourceLocation(), FPOptionsOverride());
200 
201     Constructor->setBody(
202         CompoundStmt::Create(AST, {Assign}, FPOptionsOverride(),
203                              SourceLocation(), SourceLocation()));
204     Constructor->setAccess(AccessSpecifier::AS_public);
205     Record->addDecl(Constructor);
206     return *this;
207   }
208 
209   BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
210     if (Record->isCompleteDefinition())
211       return *this;
212     addArraySubscriptOperator(true);
213     addArraySubscriptOperator(false);
214     return *this;
215   }
216 
217   BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
218     if (Record->isCompleteDefinition())
219       return *this;
220     assert(Fields.count("h") > 0 &&
221            "Subscript operator must be added after the handle.");
222 
223     FieldDecl *Handle = Fields["h"];
224     ASTContext &AST = Record->getASTContext();
225 
226     assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
227            "Not yet supported for void pointer handles.");
228 
229     QualType ElemTy =
230         QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
231     QualType ReturnTy = ElemTy;
232 
233     FunctionProtoType::ExtProtoInfo ExtInfo;
234 
235     // Subscript operators return references to elements, const makes the
236     // reference and method const so that the underlying data is not mutable.
237     ReturnTy = AST.getLValueReferenceType(ReturnTy);
238     if (IsConst) {
239       ExtInfo.TypeQuals.addConst();
240       ReturnTy.addConst();
241     }
242 
243     QualType MethodTy =
244         AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
245     auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
246     auto *MethodDecl = CXXMethodDecl::Create(
247         AST, Record, SourceLocation(),
248         DeclarationNameInfo(
249             AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
250             SourceLocation()),
251         MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
252         SourceLocation());
253 
254     IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
255     auto *IdxParam = ParmVarDecl::Create(
256         AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
257         &II, AST.UnsignedIntTy,
258         AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
259         SC_None, nullptr);
260     MethodDecl->setParams({IdxParam});
261 
262     // Also add the parameter to the function prototype.
263     auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
264     FnProtoLoc.setParam(0, IdxParam);
265 
266     auto *This =
267         CXXThisExpr::Create(AST, SourceLocation(),
268                             MethodDecl->getFunctionObjectParameterType(), true);
269     auto *HandleAccess = MemberExpr::CreateImplicit(
270         AST, This, false, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
271 
272     auto *IndexExpr = DeclRefExpr::Create(
273         AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
274         DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
275         AST.UnsignedIntTy, VK_PRValue);
276 
277     auto *Array =
278         new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
279                                      OK_Ordinary, SourceLocation());
280 
281     auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
282 
283     MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
284                                              SourceLocation(),
285                                              SourceLocation()));
286     MethodDecl->setLexicalDeclContext(Record);
287     MethodDecl->setAccess(AccessSpecifier::AS_public);
288     MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
289         AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
290     Record->addDecl(MethodDecl);
291 
292     return *this;
293   }
294 
295   BuiltinTypeDeclBuilder &startDefinition() {
296     if (Record->isCompleteDefinition())
297       return *this;
298     Record->startDefinition();
299     return *this;
300   }
301 
302   BuiltinTypeDeclBuilder &completeDefinition() {
303     if (Record->isCompleteDefinition())
304       return *this;
305     assert(Record->isBeingDefined() &&
306            "Definition must be started before completing it.");
307 
308     Record->completeDefinition();
309     return *this;
310   }
311 
312   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
313   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
314                                                   ArrayRef<StringRef> Names);
315 };
316 
317 struct TemplateParameterListBuilder {
318   BuiltinTypeDeclBuilder &Builder;
319   Sema &S;
320   llvm::SmallVector<NamedDecl *> Params;
321 
322   TemplateParameterListBuilder(Sema &S, BuiltinTypeDeclBuilder &RB)
323       : Builder(RB), S(S) {}
324 
325   ~TemplateParameterListBuilder() { finalizeTemplateArgs(); }
326 
327   TemplateParameterListBuilder &
328   addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) {
329     if (Builder.Record->isCompleteDefinition())
330       return *this;
331     unsigned Position = static_cast<unsigned>(Params.size());
332     auto *Decl = TemplateTypeParmDecl::Create(
333         S.Context, Builder.Record->getDeclContext(), SourceLocation(),
334         SourceLocation(), /* TemplateDepth */ 0, Position,
335         &S.Context.Idents.get(Name, tok::TokenKind::identifier),
336         /* Typename */ false,
337         /* ParameterPack */ false);
338     if (!DefaultValue.isNull())
339       Decl->setDefaultArgument(
340           S.Context, S.getTrivialTemplateArgumentLoc(DefaultValue, QualType(),
341                                                      SourceLocation()));
342 
343     Params.emplace_back(Decl);
344     return *this;
345   }
346 
347   BuiltinTypeDeclBuilder &finalizeTemplateArgs() {
348     if (Params.empty())
349       return Builder;
350     auto *ParamList = TemplateParameterList::Create(S.Context, SourceLocation(),
351                                                     SourceLocation(), Params,
352                                                     SourceLocation(), nullptr);
353     Builder.Template = ClassTemplateDecl::Create(
354         S.Context, Builder.Record->getDeclContext(), SourceLocation(),
355         DeclarationName(Builder.Record->getIdentifier()), ParamList,
356         Builder.Record);
357     Builder.Record->setDescribedClassTemplate(Builder.Template);
358     Builder.Template->setImplicit(true);
359     Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext());
360     // NOTE: setPreviousDecl before addDecl so new decl replace old decl when
361     // make visible.
362     Builder.Template->setPreviousDecl(Builder.PrevTemplate);
363     Builder.Record->getDeclContext()->addDecl(Builder.Template);
364     Params.clear();
365 
366     QualType T = Builder.Template->getInjectedClassNameSpecialization();
367     T = S.Context.getInjectedClassNameType(Builder.Record, T);
368 
369     return Builder;
370   }
371 };
372 } // namespace
373 
374 TemplateParameterListBuilder
375 BuiltinTypeDeclBuilder::addTemplateArgumentList(Sema &S) {
376   return TemplateParameterListBuilder(S, *this);
377 }
378 
379 BuiltinTypeDeclBuilder &
380 BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
381                                                 ArrayRef<StringRef> Names) {
382   TemplateParameterListBuilder Builder = this->addTemplateArgumentList(S);
383   for (StringRef Name : Names)
384     Builder.addTypeParameter(Name);
385   return Builder.finalizeTemplateArgs();
386 }
387 
388 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
389 
390 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
391   SemaPtr = &S;
392   ASTContext &AST = SemaPtr->getASTContext();
393   // If the translation unit has external storage force external decls to load.
394   if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
395     (void)AST.getTranslationUnitDecl()->decls_begin();
396 
397   IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
398   LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
399   NamespaceDecl *PrevDecl = nullptr;
400   if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl()))
401     PrevDecl = Result.getAsSingle<NamespaceDecl>();
402   HLSLNamespace = NamespaceDecl::Create(
403       AST, AST.getTranslationUnitDecl(), /*Inline=*/false, SourceLocation(),
404       SourceLocation(), &HLSL, PrevDecl, /*Nested=*/false);
405   HLSLNamespace->setImplicit(true);
406   HLSLNamespace->setHasExternalLexicalStorage();
407   AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
408 
409   // Force external decls in the HLSL namespace to load from the PCH.
410   (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
411   defineTrivialHLSLTypes();
412   defineHLSLTypesWithForwardDeclarations();
413 
414   // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
415   // built in types inside a namespace, but we are planning to change that in
416   // the near future. In order to be source compatible older versions of HLSL
417   // will need to implicitly use the hlsl namespace. For now in clang everything
418   // will get added to the namespace, and we can remove the using directive for
419   // future language versions to match HLSL's evolution.
420   auto *UsingDecl = UsingDirectiveDecl::Create(
421       AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
422       NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
423       AST.getTranslationUnitDecl());
424 
425   AST.getTranslationUnitDecl()->addDecl(UsingDecl);
426 }
427 
428 void HLSLExternalSemaSource::defineHLSLVectorAlias() {
429   ASTContext &AST = SemaPtr->getASTContext();
430 
431   llvm::SmallVector<NamedDecl *> TemplateParams;
432 
433   auto *TypeParam = TemplateTypeParmDecl::Create(
434       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
435       &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
436   TypeParam->setDefaultArgument(
437       AST, SemaPtr->getTrivialTemplateArgumentLoc(
438                TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
439 
440   TemplateParams.emplace_back(TypeParam);
441 
442   auto *SizeParam = NonTypeTemplateParmDecl::Create(
443       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
444       &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
445       false, AST.getTrivialTypeSourceInfo(AST.IntTy));
446   llvm::APInt Val(AST.getIntWidth(AST.IntTy), 4);
447   TemplateArgument Default(AST, llvm::APSInt(std::move(Val)), AST.IntTy,
448                            /*IsDefaulted=*/true);
449   SizeParam->setDefaultArgument(
450       AST, SemaPtr->getTrivialTemplateArgumentLoc(Default, AST.IntTy,
451                                                   SourceLocation(), SizeParam));
452   TemplateParams.emplace_back(SizeParam);
453 
454   auto *ParamList =
455       TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
456                                     TemplateParams, SourceLocation(), nullptr);
457 
458   IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
459 
460   QualType AliasType = AST.getDependentSizedExtVectorType(
461       AST.getTemplateTypeParmType(0, 0, false, TypeParam),
462       DeclRefExpr::Create(
463           AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
464           DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
465           AST.IntTy, VK_LValue),
466       SourceLocation());
467 
468   auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
469                                        SourceLocation(), &II,
470                                        AST.getTrivialTypeSourceInfo(AliasType));
471   Record->setImplicit(true);
472 
473   auto *Template =
474       TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
475                                     Record->getIdentifier(), ParamList, Record);
476 
477   Record->setDescribedAliasTemplate(Template);
478   Template->setImplicit(true);
479   Template->setLexicalDeclContext(Record->getDeclContext());
480   HLSLNamespace->addDecl(Template);
481 }
482 
483 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
484   defineHLSLVectorAlias();
485 }
486 
487 /// Set up common members and attributes for buffer types
488 static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
489                                               ResourceClass RC, ResourceKind RK,
490                                               bool IsROV) {
491   return BuiltinTypeDeclBuilder(Decl)
492       .addHandleMember()
493       .addDefaultHandleConstructor(S, RC)
494       .annotateHLSLResource(RC, RK, IsROV);
495 }
496 
497 void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
498   CXXRecordDecl *Decl;
499   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
500              .addSimpleTemplateParams(*SemaPtr, {"element_type"})
501              .Record;
502   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
503     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
504                     ResourceKind::TypedBuffer, /*IsROV=*/false)
505         .addArraySubscriptOperators()
506         .completeDefinition();
507   });
508 
509   Decl =
510       BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
511           .addSimpleTemplateParams(*SemaPtr, {"element_type"})
512           .Record;
513   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
514     setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
515                     ResourceKind::TypedBuffer, /*IsROV=*/true)
516         .addArraySubscriptOperators()
517         .completeDefinition();
518   });
519 }
520 
521 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
522                                           CompletionFunction Fn) {
523   Completions.insert(std::make_pair(Record->getCanonicalDecl(), Fn));
524 }
525 
526 void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
527   if (!isa<CXXRecordDecl>(Tag))
528     return;
529   auto Record = cast<CXXRecordDecl>(Tag);
530 
531   // If this is a specialization, we need to get the underlying templated
532   // declaration and complete that.
533   if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
534     Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
535   Record = Record->getCanonicalDecl();
536   auto It = Completions.find(Record);
537   if (It == Completions.end())
538     return;
539   It->second(Record);
540 }
541