xref: /freebsd/contrib/llvm-project/clang/lib/Sema/SemaHLSL.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for HLSL constructs.
9 //===----------------------------------------------------------------------===//
10 
11 #include "clang/Sema/SemaHLSL.h"
12 #include "clang/AST/ASTConsumer.h"
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Attr.h"
15 #include "clang/AST/Attrs.inc"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclBase.h"
18 #include "clang/AST/DeclCXX.h"
19 #include "clang/AST/DeclarationName.h"
20 #include "clang/AST/DynamicRecursiveASTVisitor.h"
21 #include "clang/AST/Expr.h"
22 #include "clang/AST/Type.h"
23 #include "clang/AST/TypeLoc.h"
24 #include "clang/Basic/Builtins.h"
25 #include "clang/Basic/DiagnosticSema.h"
26 #include "clang/Basic/IdentifierTable.h"
27 #include "clang/Basic/LLVM.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "clang/Basic/Specifiers.h"
30 #include "clang/Basic/TargetInfo.h"
31 #include "clang/Sema/Initialization.h"
32 #include "clang/Sema/Lookup.h"
33 #include "clang/Sema/ParsedAttr.h"
34 #include "clang/Sema/Sema.h"
35 #include "clang/Sema/Template.h"
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/ADT/Twine.h"
42 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/DXILABI.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Support/FormatVariadic.h"
47 #include "llvm/TargetParser/Triple.h"
48 #include <cmath>
49 #include <cstddef>
50 #include <iterator>
51 #include <utility>
52 
53 using namespace clang;
54 using RegisterType = HLSLResourceBindingAttr::RegisterType;
55 
56 static CXXRecordDecl *createHostLayoutStruct(Sema &S,
57                                              CXXRecordDecl *StructDecl);
58 
getRegisterType(ResourceClass RC)59 static RegisterType getRegisterType(ResourceClass RC) {
60   switch (RC) {
61   case ResourceClass::SRV:
62     return RegisterType::SRV;
63   case ResourceClass::UAV:
64     return RegisterType::UAV;
65   case ResourceClass::CBuffer:
66     return RegisterType::CBuffer;
67   case ResourceClass::Sampler:
68     return RegisterType::Sampler;
69   }
70   llvm_unreachable("unexpected ResourceClass value");
71 }
72 
73 // Converts the first letter of string Slot to RegisterType.
74 // Returns false if the letter does not correspond to a valid register type.
convertToRegisterType(StringRef Slot,RegisterType * RT)75 static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
76   assert(RT != nullptr);
77   switch (Slot[0]) {
78   case 't':
79   case 'T':
80     *RT = RegisterType::SRV;
81     return true;
82   case 'u':
83   case 'U':
84     *RT = RegisterType::UAV;
85     return true;
86   case 'b':
87   case 'B':
88     *RT = RegisterType::CBuffer;
89     return true;
90   case 's':
91   case 'S':
92     *RT = RegisterType::Sampler;
93     return true;
94   case 'c':
95   case 'C':
96     *RT = RegisterType::C;
97     return true;
98   case 'i':
99   case 'I':
100     *RT = RegisterType::I;
101     return true;
102   default:
103     return false;
104   }
105 }
106 
getResourceClass(RegisterType RT)107 static ResourceClass getResourceClass(RegisterType RT) {
108   switch (RT) {
109   case RegisterType::SRV:
110     return ResourceClass::SRV;
111   case RegisterType::UAV:
112     return ResourceClass::UAV;
113   case RegisterType::CBuffer:
114     return ResourceClass::CBuffer;
115   case RegisterType::Sampler:
116     return ResourceClass::Sampler;
117   case RegisterType::C:
118   case RegisterType::I:
119     // Deliberately falling through to the unreachable below.
120     break;
121   }
122   llvm_unreachable("unexpected RegisterType value");
123 }
124 
getSpecConstBuiltinId(const Type * Type)125 static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
126   const auto *BT = dyn_cast<BuiltinType>(Type);
127   if (!BT) {
128     if (!Type->isEnumeralType())
129       return Builtin::NotBuiltin;
130     return Builtin::BI__builtin_get_spirv_spec_constant_int;
131   }
132 
133   switch (BT->getKind()) {
134   case BuiltinType::Bool:
135     return Builtin::BI__builtin_get_spirv_spec_constant_bool;
136   case BuiltinType::Short:
137     return Builtin::BI__builtin_get_spirv_spec_constant_short;
138   case BuiltinType::Int:
139     return Builtin::BI__builtin_get_spirv_spec_constant_int;
140   case BuiltinType::LongLong:
141     return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
142   case BuiltinType::UShort:
143     return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
144   case BuiltinType::UInt:
145     return Builtin::BI__builtin_get_spirv_spec_constant_uint;
146   case BuiltinType::ULongLong:
147     return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
148   case BuiltinType::Half:
149     return Builtin::BI__builtin_get_spirv_spec_constant_half;
150   case BuiltinType::Float:
151     return Builtin::BI__builtin_get_spirv_spec_constant_float;
152   case BuiltinType::Double:
153     return Builtin::BI__builtin_get_spirv_spec_constant_double;
154   default:
155     return Builtin::NotBuiltin;
156   }
157 }
158 
addDeclBindingInfo(const VarDecl * VD,ResourceClass ResClass)159 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
160                                                       ResourceClass ResClass) {
161   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
162          "DeclBindingInfo already added");
163   assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
164   // VarDecl may have multiple entries for different resource classes.
165   // DeclToBindingListIndex stores the index of the first binding we saw
166   // for this decl. If there are any additional ones then that index
167   // shouldn't be updated.
168   DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
169   return &BindingsList.emplace_back(VD, ResClass);
170 }
171 
getDeclBindingInfo(const VarDecl * VD,ResourceClass ResClass)172 DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
173                                                       ResourceClass ResClass) {
174   auto Entry = DeclToBindingListIndex.find(VD);
175   if (Entry != DeclToBindingListIndex.end()) {
176     for (unsigned Index = Entry->getSecond();
177          Index < BindingsList.size() && BindingsList[Index].Decl == VD;
178          ++Index) {
179       if (BindingsList[Index].ResClass == ResClass)
180         return &BindingsList[Index];
181     }
182   }
183   return nullptr;
184 }
185 
hasBindingInfoForDecl(const VarDecl * VD) const186 bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
187   return DeclToBindingListIndex.contains(VD);
188 }
189 
SemaHLSL(Sema & S)190 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
191 
ActOnStartBuffer(Scope * BufferScope,bool CBuffer,SourceLocation KwLoc,IdentifierInfo * Ident,SourceLocation IdentLoc,SourceLocation LBrace)192 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
193                                  SourceLocation KwLoc, IdentifierInfo *Ident,
194                                  SourceLocation IdentLoc,
195                                  SourceLocation LBrace) {
196   // For anonymous namespace, take the location of the left brace.
197   DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
198   HLSLBufferDecl *Result = HLSLBufferDecl::Create(
199       getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
200 
201   // if CBuffer is false, then it's a TBuffer
202   auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
203                     : llvm::hlsl::ResourceClass::SRV;
204   Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
205 
206   SemaRef.PushOnScopeChains(Result, BufferScope);
207   SemaRef.PushDeclContext(BufferScope, Result);
208 
209   return Result;
210 }
211 
calculateLegacyCbufferFieldAlign(const ASTContext & Context,QualType T)212 static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
213                                                  QualType T) {
214   // Arrays and Structs are always aligned to new buffer rows
215   if (T->isArrayType() || T->isStructureType())
216     return 16;
217 
218   // Vectors are aligned to the type they contain
219   if (const VectorType *VT = T->getAs<VectorType>())
220     return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
221 
222   assert(Context.getTypeSize(T) <= 64 &&
223          "Scalar bit widths larger than 64 not supported");
224 
225   // Scalar types are aligned to their byte width
226   return Context.getTypeSize(T) / 8;
227 }
228 
229 // Calculate the size of a legacy cbuffer type in bytes based on
230 // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
calculateLegacyCbufferSize(const ASTContext & Context,QualType T)231 static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
232                                            QualType T) {
233   constexpr unsigned CBufferAlign = 16;
234   if (const RecordType *RT = T->getAs<RecordType>()) {
235     unsigned Size = 0;
236     const RecordDecl *RD = RT->getDecl();
237     for (const FieldDecl *Field : RD->fields()) {
238       QualType Ty = Field->getType();
239       unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
240       unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
241 
242       // If the field crosses the row boundary after alignment it drops to the
243       // next row
244       unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
245       if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
246         FieldAlign = CBufferAlign;
247       }
248 
249       Size = llvm::alignTo(Size, FieldAlign);
250       Size += FieldSize;
251     }
252     return Size;
253   }
254 
255   if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
256     unsigned ElementCount = AT->getSize().getZExtValue();
257     if (ElementCount == 0)
258       return 0;
259 
260     unsigned ElementSize =
261         calculateLegacyCbufferSize(Context, AT->getElementType());
262     unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
263     return AlignedElementSize * (ElementCount - 1) + ElementSize;
264   }
265 
266   if (const VectorType *VT = T->getAs<VectorType>()) {
267     unsigned ElementCount = VT->getNumElements();
268     unsigned ElementSize =
269         calculateLegacyCbufferSize(Context, VT->getElementType());
270     return ElementSize * ElementCount;
271   }
272 
273   return Context.getTypeSize(T) / 8;
274 }
275 
276 // Validate packoffset:
277 // - if packoffset it used it must be set on all declarations inside the buffer
278 // - packoffset ranges must not overlap
validatePackoffset(Sema & S,HLSLBufferDecl * BufDecl)279 static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
280   llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
281 
282   // Make sure the packoffset annotations are either on all declarations
283   // or on none.
284   bool HasPackOffset = false;
285   bool HasNonPackOffset = false;
286   for (auto *Field : BufDecl->buffer_decls()) {
287     VarDecl *Var = dyn_cast<VarDecl>(Field);
288     if (!Var)
289       continue;
290     if (Field->hasAttr<HLSLPackOffsetAttr>()) {
291       PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
292       HasPackOffset = true;
293     } else {
294       HasNonPackOffset = true;
295     }
296   }
297 
298   if (!HasPackOffset)
299     return;
300 
301   if (HasNonPackOffset)
302     S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
303 
304   // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
305   // and compare adjacent values.
306   bool IsValid = true;
307   ASTContext &Context = S.getASTContext();
308   std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
309             [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
310                const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
311               return LHS.second->getOffsetInBytes() <
312                      RHS.second->getOffsetInBytes();
313             });
314   for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
315     VarDecl *Var = PackOffsetVec[i].first;
316     HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
317     unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
318     unsigned Begin = Attr->getOffsetInBytes();
319     unsigned End = Begin + Size;
320     unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
321     if (End > NextBegin) {
322       VarDecl *NextVar = PackOffsetVec[i + 1].first;
323       S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
324           << NextVar << Var;
325       IsValid = false;
326     }
327   }
328   BufDecl->setHasValidPackoffset(IsValid);
329 }
330 
331 // Returns true if the array has a zero size = if any of the dimensions is 0
isZeroSizedArray(const ConstantArrayType * CAT)332 static bool isZeroSizedArray(const ConstantArrayType *CAT) {
333   while (CAT && !CAT->isZeroSize())
334     CAT = dyn_cast<ConstantArrayType>(
335         CAT->getElementType()->getUnqualifiedDesugaredType());
336   return CAT != nullptr;
337 }
338 
339 // Returns true if the record type is an HLSL resource class or an array of
340 // resource classes
isResourceRecordTypeOrArrayOf(const Type * Ty)341 static bool isResourceRecordTypeOrArrayOf(const Type *Ty) {
342   while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty))
343     Ty = CAT->getArrayElementTypeNoTypeQual();
344   return HLSLAttributedResourceType::findHandleTypeOnResource(Ty) != nullptr;
345 }
346 
isResourceRecordTypeOrArrayOf(VarDecl * VD)347 static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
348   return isResourceRecordTypeOrArrayOf(VD->getType().getTypePtr());
349 }
350 
351 // Returns true if the type is a leaf element type that is not valid to be
352 // included in HLSL Buffer, such as a resource class, empty struct, zero-sized
353 // array, or a builtin intangible type. Returns false it is a valid leaf element
354 // type or if it is a record type that needs to be inspected further.
isInvalidConstantBufferLeafElementType(const Type * Ty)355 static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
356   Ty = Ty->getUnqualifiedDesugaredType();
357   if (isResourceRecordTypeOrArrayOf(Ty))
358     return true;
359   if (Ty->isRecordType())
360     return Ty->getAsCXXRecordDecl()->isEmpty();
361   if (Ty->isConstantArrayType() &&
362       isZeroSizedArray(cast<ConstantArrayType>(Ty)))
363     return true;
364   if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
365     return true;
366   return false;
367 }
368 
369 // Returns true if the struct contains at least one element that prevents it
370 // from being included inside HLSL Buffer as is, such as an intangible type,
371 // empty struct, or zero-sized array. If it does, a new implicit layout struct
372 // needs to be created for HLSL Buffer use that will exclude these unwanted
373 // declarations (see createHostLayoutStruct function).
requiresImplicitBufferLayoutStructure(const CXXRecordDecl * RD)374 static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
375   if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty())
376     return true;
377   // check fields
378   for (const FieldDecl *Field : RD->fields()) {
379     QualType Ty = Field->getType();
380     if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
381       return true;
382     if (Ty->isRecordType() &&
383         requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
384       return true;
385   }
386   // check bases
387   for (const CXXBaseSpecifier &Base : RD->bases())
388     if (requiresImplicitBufferLayoutStructure(
389             Base.getType()->getAsCXXRecordDecl()))
390       return true;
391   return false;
392 }
393 
findRecordDeclInContext(IdentifierInfo * II,DeclContext * DC)394 static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
395                                               DeclContext *DC) {
396   CXXRecordDecl *RD = nullptr;
397   for (NamedDecl *Decl :
398        DC->getNonTransparentContext()->lookup(DeclarationName(II))) {
399     if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
400       assert(RD == nullptr &&
401              "there should be at most 1 record by a given name in a scope");
402       RD = FoundRD;
403     }
404   }
405   return RD;
406 }
407 
408 // Creates a name for buffer layout struct using the provide name base.
409 // If the name must be unique (not previously defined), a suffix is added
410 // until a unique name is found.
getHostLayoutStructName(Sema & S,NamedDecl * BaseDecl,bool MustBeUnique)411 static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
412                                                bool MustBeUnique) {
413   ASTContext &AST = S.getASTContext();
414 
415   IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
416   llvm::SmallString<64> Name("__cblayout_");
417   if (NameBaseII) {
418     Name.append(NameBaseII->getName());
419   } else {
420     // anonymous struct
421     Name.append("anon");
422     MustBeUnique = true;
423   }
424 
425   size_t NameLength = Name.size();
426   IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
427   if (!MustBeUnique)
428     return II;
429 
430   unsigned suffix = 0;
431   while (true) {
432     if (suffix != 0) {
433       Name.append("_");
434       Name.append(llvm::Twine(suffix).str());
435       II = &AST.Idents.get(Name, tok::TokenKind::identifier);
436     }
437     if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
438       return II;
439     // declaration with that name already exists - increment suffix and try
440     // again until unique name is found
441     suffix++;
442     Name.truncate(NameLength);
443   };
444 }
445 
446 // Creates a field declaration of given name and type for HLSL buffer layout
447 // struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
createFieldForHostLayoutStruct(Sema & S,const Type * Ty,IdentifierInfo * II,CXXRecordDecl * LayoutStruct)448 static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
449                                                  IdentifierInfo *II,
450                                                  CXXRecordDecl *LayoutStruct) {
451   if (isInvalidConstantBufferLeafElementType(Ty))
452     return nullptr;
453 
454   if (Ty->isRecordType()) {
455     CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
456     if (requiresImplicitBufferLayoutStructure(RD)) {
457       RD = createHostLayoutStruct(S, RD);
458       if (!RD)
459         return nullptr;
460       Ty = RD->getTypeForDecl();
461     }
462   }
463 
464   QualType QT = QualType(Ty, 0);
465   ASTContext &AST = S.getASTContext();
466   TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
467   auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
468                                   SourceLocation(), II, QT, TSI, nullptr, false,
469                                   InClassInitStyle::ICIS_NoInit);
470   Field->setAccess(AccessSpecifier::AS_public);
471   return Field;
472 }
473 
474 // Creates host layout struct for a struct included in HLSL Buffer.
475 // The layout struct will include only fields that are allowed in HLSL buffer.
476 // These fields will be filtered out:
477 // - resource classes
478 // - empty structs
479 // - zero-sized arrays
480 // Returns nullptr if the resulting layout struct would be empty.
createHostLayoutStruct(Sema & S,CXXRecordDecl * StructDecl)481 static CXXRecordDecl *createHostLayoutStruct(Sema &S,
482                                              CXXRecordDecl *StructDecl) {
483   assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
484          "struct is already HLSL buffer compatible");
485 
486   ASTContext &AST = S.getASTContext();
487   DeclContext *DC = StructDecl->getDeclContext();
488   IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);
489 
490   // reuse existing if the layout struct if it already exists
491   if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
492     return RD;
493 
494   CXXRecordDecl *LS =
495       CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
496                             SourceLocation(), II);
497   LS->setImplicit(true);
498   LS->addAttr(PackedAttr::CreateImplicit(AST));
499   LS->startDefinition();
500 
501   // copy base struct, create HLSL Buffer compatible version if needed
502   if (unsigned NumBases = StructDecl->getNumBases()) {
503     assert(NumBases == 1 && "HLSL supports only one base type");
504     (void)NumBases;
505     CXXBaseSpecifier Base = *StructDecl->bases_begin();
506     CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
507     if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
508       BaseDecl = createHostLayoutStruct(S, BaseDecl);
509       if (BaseDecl) {
510         TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(
511             QualType(BaseDecl->getTypeForDecl(), 0));
512         Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
513                                 AS_none, TSI, SourceLocation());
514       }
515     }
516     if (BaseDecl) {
517       const CXXBaseSpecifier *BasesArray[1] = {&Base};
518       LS->setBases(BasesArray, 1);
519     }
520   }
521 
522   // filter struct fields
523   for (const FieldDecl *FD : StructDecl->fields()) {
524     const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
525     if (FieldDecl *NewFD =
526             createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
527       LS->addDecl(NewFD);
528   }
529   LS->completeDefinition();
530 
531   if (LS->field_empty() && LS->getNumBases() == 0)
532     return nullptr;
533 
534   DC->addDecl(LS);
535   return LS;
536 }
537 
538 // Creates host layout struct for HLSL Buffer. The struct will include only
539 // fields of types that are allowed in HLSL buffer and it will filter out:
540 // - static or groupshared variable declarations
541 // - resource classes
542 // - empty structs
543 // - zero-sized arrays
544 // - non-variable declarations
545 // The layout struct will be added to the HLSLBufferDecl declarations.
createHostLayoutStructForBuffer(Sema & S,HLSLBufferDecl * BufDecl)546 void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
547   ASTContext &AST = S.getASTContext();
548   IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
549 
550   CXXRecordDecl *LS =
551       CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
552                             SourceLocation(), SourceLocation(), II);
553   LS->addAttr(PackedAttr::CreateImplicit(AST));
554   LS->setImplicit(true);
555   LS->startDefinition();
556 
557   for (Decl *D : BufDecl->buffer_decls()) {
558     VarDecl *VD = dyn_cast<VarDecl>(D);
559     if (!VD || VD->getStorageClass() == SC_Static ||
560         VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
561       continue;
562     const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
563     if (FieldDecl *FD =
564             createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) {
565       // add the field decl to the layout struct
566       LS->addDecl(FD);
567       // update address space of the original decl to hlsl_constant
568       QualType NewTy =
569           AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
570       VD->setType(NewTy);
571     }
572   }
573   LS->completeDefinition();
574   BufDecl->addLayoutStruct(LS);
575 }
576 
addImplicitBindingAttrToBuffer(Sema & S,HLSLBufferDecl * BufDecl,uint32_t ImplicitBindingOrderID)577 static void addImplicitBindingAttrToBuffer(Sema &S, HLSLBufferDecl *BufDecl,
578                                            uint32_t ImplicitBindingOrderID) {
579   RegisterType RT =
580       BufDecl->isCBuffer() ? RegisterType::CBuffer : RegisterType::SRV;
581   auto *Attr =
582       HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
583   std::optional<unsigned> RegSlot;
584   Attr->setBinding(RT, RegSlot, 0);
585   Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
586   BufDecl->addAttr(Attr);
587 }
588 
589 // Handle end of cbuffer/tbuffer declaration
ActOnFinishBuffer(Decl * Dcl,SourceLocation RBrace)590 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
591   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
592   BufDecl->setRBraceLoc(RBrace);
593 
594   validatePackoffset(SemaRef, BufDecl);
595 
596   // create buffer layout struct
597   createHostLayoutStructForBuffer(SemaRef, BufDecl);
598 
599   HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>();
600   if (!RBA || !RBA->hasRegisterSlot()) {
601     SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
602     // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
603     // to codegen. If it does not exist, create an implicit attribute.
604     uint32_t OrderID = getNextImplicitBindingOrderID();
605     if (RBA)
606       RBA->setImplicitBindingOrderID(OrderID);
607     else
608       addImplicitBindingAttrToBuffer(SemaRef, BufDecl, OrderID);
609   }
610 
611   SemaRef.PopDeclContext();
612 }
613 
mergeNumThreadsAttr(Decl * D,const AttributeCommonInfo & AL,int X,int Y,int Z)614 HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
615                                                   const AttributeCommonInfo &AL,
616                                                   int X, int Y, int Z) {
617   if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
618     if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
619       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
620       Diag(AL.getLoc(), diag::note_conflicting_attribute);
621     }
622     return nullptr;
623   }
624   return ::new (getASTContext())
625       HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
626 }
627 
mergeWaveSizeAttr(Decl * D,const AttributeCommonInfo & AL,int Min,int Max,int Preferred,int SpelledArgsCount)628 HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
629                                               const AttributeCommonInfo &AL,
630                                               int Min, int Max, int Preferred,
631                                               int SpelledArgsCount) {
632   if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
633     if (WS->getMin() != Min || WS->getMax() != Max ||
634         WS->getPreferred() != Preferred ||
635         WS->getSpelledArgsCount() != SpelledArgsCount) {
636       Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
637       Diag(AL.getLoc(), diag::note_conflicting_attribute);
638     }
639     return nullptr;
640   }
641   HLSLWaveSizeAttr *Result = ::new (getASTContext())
642       HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
643   Result->setSpelledArgsCount(SpelledArgsCount);
644   return Result;
645 }
646 
647 HLSLVkConstantIdAttr *
mergeVkConstantIdAttr(Decl * D,const AttributeCommonInfo & AL,int Id)648 SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
649                                 int Id) {
650 
651   auto &TargetInfo = getASTContext().getTargetInfo();
652   if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
653     Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
654     return nullptr;
655   }
656 
657   auto *VD = cast<VarDecl>(D);
658 
659   if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
660       Builtin::NotBuiltin) {
661     Diag(VD->getLocation(), diag::err_specialization_const);
662     return nullptr;
663   }
664 
665   if (!VD->getType().isConstQualified()) {
666     Diag(VD->getLocation(), diag::err_specialization_const);
667     return nullptr;
668   }
669 
670   if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
671     if (CI->getId() != Id) {
672       Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
673       Diag(AL.getLoc(), diag::note_conflicting_attribute);
674     }
675     return nullptr;
676   }
677 
678   HLSLVkConstantIdAttr *Result =
679       ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
680   return Result;
681 }
682 
683 HLSLShaderAttr *
mergeShaderAttr(Decl * D,const AttributeCommonInfo & AL,llvm::Triple::EnvironmentType ShaderType)684 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
685                           llvm::Triple::EnvironmentType ShaderType) {
686   if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
687     if (NT->getType() != ShaderType) {
688       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
689       Diag(AL.getLoc(), diag::note_conflicting_attribute);
690     }
691     return nullptr;
692   }
693   return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
694 }
695 
696 HLSLParamModifierAttr *
mergeParamModifierAttr(Decl * D,const AttributeCommonInfo & AL,HLSLParamModifierAttr::Spelling Spelling)697 SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
698                                  HLSLParamModifierAttr::Spelling Spelling) {
699   // We can only merge an `in` attribute with an `out` attribute. All other
700   // combinations of duplicated attributes are ill-formed.
701   if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
702     if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
703         (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
704       D->dropAttr<HLSLParamModifierAttr>();
705       SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
706       return HLSLParamModifierAttr::Create(
707           getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
708           HLSLParamModifierAttr::Keyword_inout);
709     }
710     Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
711     Diag(PA->getLocation(), diag::note_conflicting_attribute);
712     return nullptr;
713   }
714   return HLSLParamModifierAttr::Create(getASTContext(), AL);
715 }
716 
ActOnTopLevelFunction(FunctionDecl * FD)717 void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
718   auto &TargetInfo = getASTContext().getTargetInfo();
719 
720   if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
721     return;
722 
723   llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
724   if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
725     if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
726       // The entry point is already annotated - check that it matches the
727       // triple.
728       if (Shader->getType() != Env) {
729         Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
730             << Shader;
731         FD->setInvalidDecl();
732       }
733     } else {
734       // Implicitly add the shader attribute if the entry function isn't
735       // explicitly annotated.
736       FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
737                                                  FD->getBeginLoc()));
738     }
739   } else {
740     switch (Env) {
741     case llvm::Triple::UnknownEnvironment:
742     case llvm::Triple::Library:
743       break;
744     default:
745       llvm_unreachable("Unhandled environment in triple");
746     }
747   }
748 }
749 
CheckEntryPoint(FunctionDecl * FD)750 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
751   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
752   assert(ShaderAttr && "Entry point has no shader attribute");
753   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
754   auto &TargetInfo = getASTContext().getTargetInfo();
755   VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
756   switch (ST) {
757   case llvm::Triple::Pixel:
758   case llvm::Triple::Vertex:
759   case llvm::Triple::Geometry:
760   case llvm::Triple::Hull:
761   case llvm::Triple::Domain:
762   case llvm::Triple::RayGeneration:
763   case llvm::Triple::Intersection:
764   case llvm::Triple::AnyHit:
765   case llvm::Triple::ClosestHit:
766   case llvm::Triple::Miss:
767   case llvm::Triple::Callable:
768     if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
769       DiagnoseAttrStageMismatch(NT, ST,
770                                 {llvm::Triple::Compute,
771                                  llvm::Triple::Amplification,
772                                  llvm::Triple::Mesh});
773       FD->setInvalidDecl();
774     }
775     if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
776       DiagnoseAttrStageMismatch(WS, ST,
777                                 {llvm::Triple::Compute,
778                                  llvm::Triple::Amplification,
779                                  llvm::Triple::Mesh});
780       FD->setInvalidDecl();
781     }
782     break;
783 
784   case llvm::Triple::Compute:
785   case llvm::Triple::Amplification:
786   case llvm::Triple::Mesh:
787     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
788       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
789           << llvm::Triple::getEnvironmentTypeName(ST);
790       FD->setInvalidDecl();
791     }
792     if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
793       if (Ver < VersionTuple(6, 6)) {
794         Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
795             << WS << "6.6";
796         FD->setInvalidDecl();
797       } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
798         Diag(
799             WS->getLocation(),
800             diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
801             << WS << WS->getSpelledArgsCount() << "6.8";
802         FD->setInvalidDecl();
803       }
804     }
805     break;
806   default:
807     llvm_unreachable("Unhandled environment in triple");
808   }
809 
810   for (ParmVarDecl *Param : FD->parameters()) {
811     if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
812       CheckSemanticAnnotation(FD, Param, AnnotationAttr);
813     } else {
814       // FIXME: Handle struct parameters where annotations are on struct fields.
815       // See: https://github.com/llvm/llvm-project/issues/57875
816       Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
817       Diag(Param->getLocation(), diag::note_previous_decl) << Param;
818       FD->setInvalidDecl();
819     }
820   }
821   // FIXME: Verify return type semantic annotation.
822 }
823 
CheckSemanticAnnotation(FunctionDecl * EntryPoint,const Decl * Param,const HLSLAnnotationAttr * AnnotationAttr)824 void SemaHLSL::CheckSemanticAnnotation(
825     FunctionDecl *EntryPoint, const Decl *Param,
826     const HLSLAnnotationAttr *AnnotationAttr) {
827   auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
828   assert(ShaderAttr && "Entry point has no shader attribute");
829   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
830 
831   switch (AnnotationAttr->getKind()) {
832   case attr::HLSLSV_DispatchThreadID:
833   case attr::HLSLSV_GroupIndex:
834   case attr::HLSLSV_GroupThreadID:
835   case attr::HLSLSV_GroupID:
836     if (ST == llvm::Triple::Compute)
837       return;
838     DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
839     break;
840   case attr::HLSLSV_Position:
841     // TODO(#143523): allow use on other shader types & output once the overall
842     // semantic logic is implemented.
843     if (ST == llvm::Triple::Pixel)
844       return;
845     DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Pixel});
846     break;
847   default:
848     llvm_unreachable("Unknown HLSLAnnotationAttr");
849   }
850 }
851 
DiagnoseAttrStageMismatch(const Attr * A,llvm::Triple::EnvironmentType Stage,std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages)852 void SemaHLSL::DiagnoseAttrStageMismatch(
853     const Attr *A, llvm::Triple::EnvironmentType Stage,
854     std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
855   SmallVector<StringRef, 8> StageStrings;
856   llvm::transform(AllowedStages, std::back_inserter(StageStrings),
857                   [](llvm::Triple::EnvironmentType ST) {
858                     return StringRef(
859                         HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
860                   });
861   Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
862       << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
863       << (AllowedStages.size() != 1) << join(StageStrings, ", ");
864 }
865 
866 template <CastKind Kind>
castVector(Sema & S,ExprResult & E,QualType & Ty,unsigned Sz)867 static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
868   if (const auto *VTy = Ty->getAs<VectorType>())
869     Ty = VTy->getElementType();
870   Ty = S.getASTContext().getExtVectorType(Ty, Sz);
871   E = S.ImpCastExprToType(E.get(), Ty, Kind);
872 }
873 
874 template <CastKind Kind>
castElement(Sema & S,ExprResult & E,QualType Ty)875 static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
876   E = S.ImpCastExprToType(E.get(), Ty, Kind);
877   return Ty;
878 }
879 
handleFloatVectorBinOpConversion(Sema & SemaRef,ExprResult & LHS,ExprResult & RHS,QualType LHSType,QualType RHSType,QualType LElTy,QualType RElTy,bool IsCompAssign)880 static QualType handleFloatVectorBinOpConversion(
881     Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
882     QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
883   bool LHSFloat = LElTy->isRealFloatingType();
884   bool RHSFloat = RElTy->isRealFloatingType();
885 
886   if (LHSFloat && RHSFloat) {
887     if (IsCompAssign ||
888         SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
889       return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
890 
891     return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
892   }
893 
894   if (LHSFloat)
895     return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
896 
897   assert(RHSFloat);
898   if (IsCompAssign)
899     return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
900 
901   return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
902 }
903 
handleIntegerVectorBinOpConversion(Sema & SemaRef,ExprResult & LHS,ExprResult & RHS,QualType LHSType,QualType RHSType,QualType LElTy,QualType RElTy,bool IsCompAssign)904 static QualType handleIntegerVectorBinOpConversion(
905     Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
906     QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
907 
908   int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
909   bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
910   bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
911   auto &Ctx = SemaRef.getASTContext();
912 
913   // If both types have the same signedness, use the higher ranked type.
914   if (LHSSigned == RHSSigned) {
915     if (IsCompAssign || IntOrder >= 0)
916       return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
917 
918     return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
919   }
920 
921   // If the unsigned type has greater than or equal rank of the signed type, use
922   // the unsigned type.
923   if (IntOrder != (LHSSigned ? 1 : -1)) {
924     if (IsCompAssign || RHSSigned)
925       return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
926     return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
927   }
928 
929   // At this point the signed type has higher rank than the unsigned type, which
930   // means it will be the same size or bigger. If the signed type is bigger, it
931   // can represent all the values of the unsigned type, so select it.
932   if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
933     if (IsCompAssign || LHSSigned)
934       return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
935     return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
936   }
937 
938   // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
939   // to C/C++ leaking through. The place this happens today is long vs long
940   // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
941   // the long long has higher rank than long even though they are the same size.
942 
943   // If this is a compound assignment cast the right hand side to the left hand
944   // side's type.
945   if (IsCompAssign)
946     return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
947 
948   // If this isn't a compound assignment we convert to unsigned long long.
949   QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
950   QualType NewTy = Ctx.getExtVectorType(
951       ElTy, RHSType->castAs<VectorType>()->getNumElements());
952   (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
953 
954   return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
955 }
956 
getScalarCastKind(ASTContext & Ctx,QualType DestTy,QualType SrcTy)957 static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
958                                   QualType SrcTy) {
959   if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
960     return CK_FloatingCast;
961   if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
962     return CK_IntegralCast;
963   if (DestTy->isRealFloatingType())
964     return CK_IntegralToFloating;
965   assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
966   return CK_FloatingToIntegral;
967 }
968 
/// Computes the common operand type for an HLSL binary operator where at least
/// one operand is a vector, inserting implicit casts into \p LHS and/or \p RHS.
///
/// \param LHS, RHS  operand expressions; replaced in place by any casts added.
/// \param LHSType, RHSType  the operands' types as provided by the caller.
/// \param IsCompAssign  true for a compound assignment, where the left operand
///        keeps its type and only the right operand is converted.
/// \returns the resulting type, or a null QualType after diagnosing a compound
///          assignment whose right operand has fewer elements than the left.
QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                               QualType LHSType,
                                               QualType RHSType,
                                               bool IsCompAssign) {
  const auto *LVecTy = LHSType->getAs<VectorType>();
  const auto *RVecTy = RHSType->getAs<VectorType>();
  auto &Ctx = getASTContext();

  // If the LHS is not a vector and this is a compound assignment, we truncate
  // the argument to a scalar then convert it to the LHS's type.
  // NOTE(review): castAs<VectorType>() asserts RHS is a vector here; presumably
  // callers only reach this function when at least one operand is a vector.
  if (!LVecTy && IsCompAssign) {
    QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
    RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
    RHSType = RHS.get()->getType();
    if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
      return LHSType;
    RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
                                    getScalarCastKind(Ctx, LHSType, RHSType));
    return LHSType;
  }

  // EndSz becomes the length of the only vector operand, or the smaller of the
  // two lengths. LSz stays 0 when the LHS is a scalar.
  unsigned EndSz = std::numeric_limits<unsigned>::max();
  unsigned LSz = 0;
  if (LVecTy)
    LSz = EndSz = LVecTy->getNumElements();
  if (RVecTy)
    EndSz = std::min(RVecTy->getNumElements(), EndSz);
  assert(EndSz != std::numeric_limits<unsigned>::max() &&
         "one of the above should have had a value");

  // In a compound assignment, the left operand does not change type, the right
  // operand is converted to the type of the left operand.
  if (IsCompAssign && LSz != EndSz) {
    Diag(LHS.get()->getBeginLoc(),
         diag::err_hlsl_vector_compound_assignment_truncation)
        << LHSType << RHSType;
    return QualType();
  }

  // Truncate whichever vector operand is longer than EndSz.
  if (RVecTy && RVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
  if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);

  // Splat a scalar operand up to EndSz elements. The scalar-LHS compound
  // assignment case already returned above.
  if (!RVecTy)
    castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
  if (!IsCompAssign && !LVecTy)
    castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);

  // If we're at the same type after resizing we can stop here.
  if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
    return Ctx.getCommonSugaredType(LHSType, RHSType);

  // Both operands are now vectors of EndSz elements with different element
  // types; reconcile the element types.
  QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
  QualType RElTy = RHSType->castAs<VectorType>()->getElementType();

  // Handle conversion for floating point vectors.
  if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
    return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);

  assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
         "HLSL Vectors can only contain integer or floating point types");
  return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);
}
1035 
emitLogicalOperatorFixIt(Expr * LHS,Expr * RHS,BinaryOperatorKind Opc)1036 void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1037                                         BinaryOperatorKind Opc) {
1038   assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1039          "Called with non-logical operator");
1040   llvm::SmallVector<char, 256> Buff;
1041   llvm::raw_svector_ostream OS(Buff);
1042   PrintingPolicy PP(SemaRef.getLangOpts());
1043   StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1044   OS << NewFnName << "(";
1045   LHS->printPretty(OS, nullptr, PP);
1046   OS << ", ";
1047   RHS->printPretty(OS, nullptr, PP);
1048   OS << ")";
1049   SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1050   SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1051       << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1052 }
1053 
1054 std::pair<IdentifierInfo *, bool>
ActOnStartRootSignatureDecl(StringRef Signature)1055 SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1056   llvm::hash_code Hash = llvm::hash_value(Signature);
1057   std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1058   IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1059 
1060   // Check if we have already found a decl of the same name.
1061   LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1062                  Sema::LookupOrdinaryName);
1063   bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1064   return {DeclIdent, Found};
1065 }
1066 
ActOnFinishRootSignatureDecl(SourceLocation Loc,IdentifierInfo * DeclIdent,ArrayRef<hlsl::RootSignatureElement> RootElements)1067 void SemaHLSL::ActOnFinishRootSignatureDecl(
1068     SourceLocation Loc, IdentifierInfo *DeclIdent,
1069     ArrayRef<hlsl::RootSignatureElement> RootElements) {
1070 
1071   if (handleRootSignatureElements(RootElements))
1072     return;
1073 
1074   SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1075   for (auto &RootSigElement : RootElements)
1076     Elements.push_back(RootSigElement.getElement());
1077 
1078   auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1079       SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1080       DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1081 
1082   SignatureDecl->setImplicit();
1083   SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1084 }
1085 
handleRootSignatureElements(ArrayRef<hlsl::RootSignatureElement> Elements)1086 bool SemaHLSL::handleRootSignatureElements(
1087     ArrayRef<hlsl::RootSignatureElement> Elements) {
1088   // Define some common error handling functions
1089   bool HadError = false;
1090   auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1091                                        uint32_t UpperBound) {
1092     HadError = true;
1093     this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1094         << LowerBound << UpperBound;
1095   };
1096 
1097   auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1098                                             float LowerBound,
1099                                             float UpperBound) {
1100     HadError = true;
1101     this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1102         << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1103         << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1104   };
1105 
1106   auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1107     if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1108       ReportError(Loc, 0, 0xfffffffe);
1109   };
1110 
1111   auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1112     if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1113       ReportError(Loc, 0, 0xffffffef);
1114   };
1115 
1116   const uint32_t Version =
1117       llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1118   const uint32_t VersionEnum = Version - 1;
1119   auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1120     HadError = true;
1121     this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1122         << /*version minor*/ VersionEnum;
1123   };
1124 
1125   // Iterate through the elements and do basic validations
1126   for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1127     SourceLocation Loc = RootSigElem.getLocation();
1128     const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1129     if (const auto *Descriptor =
1130             std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1131       VerifyRegister(Loc, Descriptor->Reg.Number);
1132       VerifySpace(Loc, Descriptor->Space);
1133 
1134       if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
1135               Version, llvm::to_underlying(Descriptor->Flags)))
1136         ReportFlagError(Loc);
1137     } else if (const auto *Constants =
1138                    std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1139       VerifyRegister(Loc, Constants->Reg.Number);
1140       VerifySpace(Loc, Constants->Space);
1141     } else if (const auto *Sampler =
1142                    std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1143       VerifyRegister(Loc, Sampler->Reg.Number);
1144       VerifySpace(Loc, Sampler->Space);
1145 
1146       assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1147              "By construction, parseFloatParam can't produce a NaN from a "
1148              "float_literal token");
1149 
1150       if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1151         ReportError(Loc, 0, 16);
1152       if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1153         ReportFloatError(Loc, -16.f, 15.99);
1154     } else if (const auto *Clause =
1155                    std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1156                        &Elem)) {
1157       VerifyRegister(Loc, Clause->Reg.Number);
1158       VerifySpace(Loc, Clause->Space);
1159 
1160       if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1161         // NumDescriptor could techincally be ~0u but that is reserved for
1162         // unbounded, so the diagnostic will not report that as a valid int
1163         // value
1164         ReportError(Loc, 1, 0xfffffffe);
1165       }
1166 
1167       if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
1168               Version, llvm::to_underlying(Clause->Type),
1169               llvm::to_underlying(Clause->Flags)))
1170         ReportFlagError(Loc);
1171     }
1172   }
1173 
1174   using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
1175   using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
1176   using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
1177 
1178   // 1. Collect RangeInfos
1179   llvm::SmallVector<InfoPairT> InfoPairs;
1180   for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1181     const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1182     if (const auto *Descriptor =
1183             std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1184       RangeInfo Info;
1185       Info.LowerBound = Descriptor->Reg.Number;
1186       Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1187 
1188       Info.Class =
1189           llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
1190       Info.Space = Descriptor->Space;
1191       Info.Visibility = Descriptor->Visibility;
1192 
1193       InfoPairs.push_back({Info, &RootSigElem});
1194     } else if (const auto *Constants =
1195                    std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1196       RangeInfo Info;
1197       Info.LowerBound = Constants->Reg.Number;
1198       Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1199 
1200       Info.Class = llvm::dxil::ResourceClass::CBuffer;
1201       Info.Space = Constants->Space;
1202       Info.Visibility = Constants->Visibility;
1203 
1204       InfoPairs.push_back({Info, &RootSigElem});
1205     } else if (const auto *Sampler =
1206                    std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1207       RangeInfo Info;
1208       Info.LowerBound = Sampler->Reg.Number;
1209       Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1210 
1211       Info.Class = llvm::dxil::ResourceClass::Sampler;
1212       Info.Space = Sampler->Space;
1213       Info.Visibility = Sampler->Visibility;
1214 
1215       InfoPairs.push_back({Info, &RootSigElem});
1216     } else if (const auto *Clause =
1217                    std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1218                        &Elem)) {
1219       RangeInfo Info;
1220       Info.LowerBound = Clause->Reg.Number;
1221       // Relevant error will have already been reported above and needs to be
1222       // fixed before we can conduct range analysis, so shortcut error return
1223       if (Clause->NumDescriptors == 0)
1224         return true;
1225       Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
1226                             ? RangeInfo::Unbounded
1227                             : Info.LowerBound + Clause->NumDescriptors -
1228                                   1; // use inclusive ranges []
1229 
1230       Info.Class = Clause->Type;
1231       Info.Space = Clause->Space;
1232 
1233       // Note: Clause does not hold the visibility this will need to
1234       InfoPairs.push_back({Info, &RootSigElem});
1235     } else if (const auto *Table =
1236                    std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1237       // Table holds the Visibility of all owned Clauses in Table, so iterate
1238       // owned Clauses and update their corresponding RangeInfo
1239       assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
1240       // The last Table->NumClauses elements of Infos are the owned Clauses
1241       // generated RangeInfo
1242       auto TableInfos =
1243           MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses);
1244       for (InfoPairT &Pair : TableInfos)
1245         Pair.first.Visibility = Table->Visibility;
1246     }
1247   }
1248 
1249   // 2. Sort with the RangeInfo <operator to prepare it for findOverlapping
1250   llvm::sort(InfoPairs,
1251              [](InfoPairT A, InfoPairT B) { return A.first < B.first; });
1252 
1253   llvm::SmallVector<RangeInfo> Infos;
1254   for (const InfoPairT &Pair : InfoPairs)
1255     Infos.push_back(Pair.first);
1256 
1257   // Helpers to report diagnostics
1258   uint32_t DuplicateCounter = 0;
1259   using ElemPair = std::pair<const hlsl::RootSignatureElement *,
1260                              const hlsl::RootSignatureElement *>;
1261   auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
1262                          OverlappingRanges Overlap) -> ElemPair {
1263     // Given we sorted the InfoPairs (and by implication) Infos, and,
1264     // that Overlap.B is the item retrieved from the ResourceRange. Then it is
1265     // guarenteed that Overlap.B <= Overlap.A.
1266     //
1267     // So we will find Overlap.B first and then continue to find Overlap.A
1268     // after
1269     auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
1270     auto DistB = std::distance(Infos.begin(), InfoB);
1271     auto PairB = InfoPairs.begin();
1272     std::advance(PairB, DistB);
1273 
1274     auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
1275     // Similarily, from the property that we have sorted the RangeInfos,
1276     // all duplicates will be processed one after the other. So
1277     // DuplicateCounter can be re-used for each set of duplicates we
1278     // encounter as we handle incoming errors
1279     DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
1280     auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter;
1281     auto PairA = PairB;
1282     std::advance(PairA, DistA);
1283 
1284     return {PairA->second, PairB->second};
1285   };
1286 
1287   auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
1288     auto Pair = GetElemPair(Overlap);
1289     const RangeInfo *Info = Overlap.A;
1290     const hlsl::RootSignatureElement *Elem = Pair.first;
1291     const RangeInfo *OInfo = Overlap.B;
1292 
1293     auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
1294                          ? OInfo->Visibility
1295                          : Info->Visibility;
1296     this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1297         << llvm::to_underlying(Info->Class) << Info->LowerBound
1298         << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
1299         << Info->UpperBound << llvm::to_underlying(OInfo->Class)
1300         << OInfo->LowerBound
1301         << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
1302         << OInfo->UpperBound << Info->Space << CommonVis;
1303 
1304     const hlsl::RootSignatureElement *OElem = Pair.second;
1305     this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
1306   };
1307 
1308   // 3. Invoke find overlapping ranges
1309   llvm::SmallVector<OverlappingRanges> Overlaps =
1310       llvm::hlsl::rootsig::findOverlappingRanges(Infos);
1311   for (OverlappingRanges Overlap : Overlaps)
1312     ReportOverlap(Overlap);
1313 
1314   return Overlaps.size() != 0;
1315 }
1316 
handleRootSignatureAttr(Decl * D,const ParsedAttr & AL)1317 void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1318   if (AL.getNumArgs() != 1) {
1319     Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1320     return;
1321   }
1322 
1323   IdentifierInfo *Ident = AL.getArgAsIdent(0)->getIdentifierInfo();
1324   if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1325     if (RS->getSignatureIdent() != Ident) {
1326       Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1327       return;
1328     }
1329 
1330     Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1331     return;
1332   }
1333 
1334   LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1335   if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1336     if (auto *SignatureDecl =
1337             dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1338       D->addAttr(::new (getASTContext()) RootSignatureAttr(
1339           getASTContext(), AL, Ident, SignatureDecl));
1340     }
1341 }
1342 
handleNumThreadsAttr(Decl * D,const ParsedAttr & AL)1343 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1344   llvm::VersionTuple SMVersion =
1345       getASTContext().getTargetInfo().getTriple().getOSVersion();
1346   bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1347                 llvm::Triple::dxil;
1348 
1349   uint32_t ZMax = 1024;
1350   uint32_t ThreadMax = 1024;
1351   if (IsDXIL && SMVersion.getMajor() <= 4) {
1352     ZMax = 1;
1353     ThreadMax = 768;
1354   } else if (IsDXIL && SMVersion.getMajor() == 5) {
1355     ZMax = 64;
1356     ThreadMax = 1024;
1357   }
1358 
1359   uint32_t X;
1360   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1361     return;
1362   if (X > 1024) {
1363     Diag(AL.getArgAsExpr(0)->getExprLoc(),
1364          diag::err_hlsl_numthreads_argument_oor)
1365         << 0 << 1024;
1366     return;
1367   }
1368   uint32_t Y;
1369   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1370     return;
1371   if (Y > 1024) {
1372     Diag(AL.getArgAsExpr(1)->getExprLoc(),
1373          diag::err_hlsl_numthreads_argument_oor)
1374         << 1 << 1024;
1375     return;
1376   }
1377   uint32_t Z;
1378   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1379     return;
1380   if (Z > ZMax) {
1381     SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1382                  diag::err_hlsl_numthreads_argument_oor)
1383         << 2 << ZMax;
1384     return;
1385   }
1386 
1387   if (X * Y * Z > ThreadMax) {
1388     Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1389     return;
1390   }
1391 
1392   HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1393   if (NewAttr)
1394     D->addAttr(NewAttr);
1395 }
1396 
isValidWaveSizeValue(unsigned Value)1397 static bool isValidWaveSizeValue(unsigned Value) {
1398   return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1399 }
1400 
handleWaveSizeAttr(Decl * D,const ParsedAttr & AL)1401 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
1402   // validate that the wavesize argument is a power of 2 between 4 and 128
1403   // inclusive
1404   unsigned SpelledArgsCount = AL.getNumArgs();
1405   if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1406     return;
1407 
1408   uint32_t Min;
1409   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1410     return;
1411 
1412   uint32_t Max = 0;
1413   if (SpelledArgsCount > 1 &&
1414       !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1415     return;
1416 
1417   uint32_t Preferred = 0;
1418   if (SpelledArgsCount > 2 &&
1419       !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1420     return;
1421 
1422   if (SpelledArgsCount > 2) {
1423     if (!isValidWaveSizeValue(Preferred)) {
1424       Diag(AL.getArgAsExpr(2)->getExprLoc(),
1425            diag::err_attribute_power_of_two_in_range)
1426           << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1427           << Preferred;
1428       return;
1429     }
1430     // Preferred not in range.
1431     if (Preferred < Min || Preferred > Max) {
1432       Diag(AL.getArgAsExpr(2)->getExprLoc(),
1433            diag::err_attribute_power_of_two_in_range)
1434           << AL << Min << Max << Preferred;
1435       return;
1436     }
1437   } else if (SpelledArgsCount > 1) {
1438     if (!isValidWaveSizeValue(Max)) {
1439       Diag(AL.getArgAsExpr(1)->getExprLoc(),
1440            diag::err_attribute_power_of_two_in_range)
1441           << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1442       return;
1443     }
1444     if (Max < Min) {
1445       Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1446       return;
1447     } else if (Max == Min) {
1448       Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1449     }
1450   } else {
1451     if (!isValidWaveSizeValue(Min)) {
1452       Diag(AL.getArgAsExpr(0)->getExprLoc(),
1453            diag::err_attribute_power_of_two_in_range)
1454           << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1455       return;
1456     }
1457   }
1458 
1459   HLSLWaveSizeAttr *NewAttr =
1460       mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1461   if (NewAttr)
1462     D->addAttr(NewAttr);
1463 }
1464 
handleVkExtBuiltinInputAttr(Decl * D,const ParsedAttr & AL)1465 void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1466   uint32_t ID;
1467   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1468     return;
1469   D->addAttr(::new (getASTContext())
1470                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1471 }
1472 
handleVkConstantIdAttr(Decl * D,const ParsedAttr & AL)1473 void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1474   uint32_t Id;
1475   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1476     return;
1477   HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1478   if (NewAttr)
1479     D->addAttr(NewAttr);
1480 }
1481 
diagnoseInputIDType(QualType T,const ParsedAttr & AL)1482 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1483   const auto *VT = T->getAs<VectorType>();
1484 
1485   if (!T->hasUnsignedIntegerRepresentation() ||
1486       (VT && VT->getNumElements() > 3)) {
1487     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1488         << AL << "uint/uint2/uint3";
1489     return false;
1490   }
1491 
1492   return true;
1493 }
1494 
handleSV_DispatchThreadIDAttr(Decl * D,const ParsedAttr & AL)1495 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1496   auto *VD = cast<ValueDecl>(D);
1497   if (!diagnoseInputIDType(VD->getType(), AL))
1498     return;
1499 
1500   D->addAttr(::new (getASTContext())
1501                  HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
1502 }
1503 
diagnosePositionType(QualType T,const ParsedAttr & AL)1504 bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1505   const auto *VT = T->getAs<VectorType>();
1506 
1507   if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1508     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1509         << AL << "float/float1/float2/float3/float4";
1510     return false;
1511   }
1512 
1513   return true;
1514 }
1515 
handleSV_PositionAttr(Decl * D,const ParsedAttr & AL)1516 void SemaHLSL::handleSV_PositionAttr(Decl *D, const ParsedAttr &AL) {
1517   auto *VD = cast<ValueDecl>(D);
1518   if (!diagnosePositionType(VD->getType(), AL))
1519     return;
1520 
1521   D->addAttr(::new (getASTContext()) HLSLSV_PositionAttr(getASTContext(), AL));
1522 }
1523 
handleSV_GroupThreadIDAttr(Decl * D,const ParsedAttr & AL)1524 void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1525   auto *VD = cast<ValueDecl>(D);
1526   if (!diagnoseInputIDType(VD->getType(), AL))
1527     return;
1528 
1529   D->addAttr(::new (getASTContext())
1530                  HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
1531 }
1532 
handleSV_GroupIDAttr(Decl * D,const ParsedAttr & AL)1533 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
1534   auto *VD = cast<ValueDecl>(D);
1535   if (!diagnoseInputIDType(VD->getType(), AL))
1536     return;
1537 
1538   D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
1539 }
1540 
handlePackOffsetAttr(Decl * D,const ParsedAttr & AL)1541 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
1542   if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
1543     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
1544         << AL << "shader constant in a constant buffer";
1545     return;
1546   }
1547 
1548   uint32_t SubComponent;
1549   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
1550     return;
1551   uint32_t Component;
1552   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
1553     return;
1554 
1555   QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
1556   // Check if T is an array or struct type.
1557   // TODO: mark matrix type as aggregate type.
1558   bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1559 
1560   // Check Component is valid for T.
1561   if (Component) {
1562     unsigned Size = getASTContext().getTypeSize(T);
1563     if (IsAggregateTy || Size > 128) {
1564       Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1565       return;
1566     } else {
1567       // Make sure Component + sizeof(T) <= 4.
1568       if ((Component * 32 + Size) > 128) {
1569         Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1570         return;
1571       }
1572       QualType EltTy = T;
1573       if (const auto *VT = T->getAs<VectorType>())
1574         EltTy = VT->getElementType();
1575       unsigned Align = getASTContext().getTypeAlign(EltTy);
1576       if (Align > 32 && Component == 1) {
1577         // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
1578         // So we only need to check Component 1 here.
1579         Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
1580             << Align << EltTy;
1581         return;
1582       }
1583     }
1584   }
1585 
1586   D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
1587       getASTContext(), AL, SubComponent, Component));
1588 }
1589 
handleShaderAttr(Decl * D,const ParsedAttr & AL)1590 void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
1591   StringRef Str;
1592   SourceLocation ArgLoc;
1593   if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
1594     return;
1595 
1596   llvm::Triple::EnvironmentType ShaderType;
1597   if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
1598     Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
1599         << AL << Str << ArgLoc;
1600     return;
1601   }
1602 
1603   // FIXME: check function match the shader stage.
1604 
1605   HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1606   if (NewAttr)
1607     D->addAttr(NewAttr);
1608 }
1609 
CreateHLSLAttributedResourceType(Sema & S,QualType Wrapped,ArrayRef<const Attr * > AttrList,QualType & ResType,HLSLAttributedResourceLocInfo * LocInfo)1610 bool clang::CreateHLSLAttributedResourceType(
1611     Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
1612     QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
1613   assert(AttrList.size() && "expected list of resource attributes");
1614 
1615   QualType ContainedTy = QualType();
1616   TypeSourceInfo *ContainedTyInfo = nullptr;
1617   SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
1618   SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
1619 
1620   HLSLAttributedResourceType::Attributes ResAttrs;
1621 
1622   bool HasResourceClass = false;
1623   for (const Attr *A : AttrList) {
1624     if (!A)
1625       continue;
1626     LocEnd = A->getRange().getEnd();
1627     switch (A->getKind()) {
1628     case attr::HLSLResourceClass: {
1629       ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
1630       if (HasResourceClass) {
1631         S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
1632                                      ? diag::warn_duplicate_attribute_exact
1633                                      : diag::warn_duplicate_attribute)
1634             << A;
1635         return false;
1636       }
1637       ResAttrs.ResourceClass = RC;
1638       HasResourceClass = true;
1639       break;
1640     }
1641     case attr::HLSLROV:
1642       if (ResAttrs.IsROV) {
1643         S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1644         return false;
1645       }
1646       ResAttrs.IsROV = true;
1647       break;
1648     case attr::HLSLRawBuffer:
1649       if (ResAttrs.RawBuffer) {
1650         S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1651         return false;
1652       }
1653       ResAttrs.RawBuffer = true;
1654       break;
1655     case attr::HLSLContainedType: {
1656       const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
1657       QualType Ty = CTAttr->getType();
1658       if (!ContainedTy.isNull()) {
1659         S.Diag(A->getLocation(), ContainedTy == Ty
1660                                      ? diag::warn_duplicate_attribute_exact
1661                                      : diag::warn_duplicate_attribute)
1662             << A;
1663         return false;
1664       }
1665       ContainedTy = Ty;
1666       ContainedTyInfo = CTAttr->getTypeLoc();
1667       break;
1668     }
1669     default:
1670       llvm_unreachable("unhandled resource attribute type");
1671     }
1672   }
1673 
1674   if (!HasResourceClass) {
1675     S.Diag(AttrList.back()->getRange().getEnd(),
1676            diag::err_hlsl_missing_resource_class);
1677     return false;
1678   }
1679 
1680   ResType = S.getASTContext().getHLSLAttributedResourceType(
1681       Wrapped, ContainedTy, ResAttrs);
1682 
1683   if (LocInfo && ContainedTyInfo) {
1684     LocInfo->Range = SourceRange(LocBegin, LocEnd);
1685     LocInfo->ContainedTyInfo = ContainedTyInfo;
1686   }
1687   return true;
1688 }
1689 
// Validates and creates an HLSL attribute that is applied as a type attribute
// on an HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs,
// and at the end of the declaration they are applied to the declaration type
// by wrapping it in HLSLAttributedResourceType.
handleResourceTypeAttr(QualType T,const ParsedAttr & AL)1694 bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
1695   // only allow resource type attributes on intangible types
1696   if (!T->isHLSLResourceType()) {
1697     Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
1698         << AL << getASTContext().HLSLResourceTy;
1699     return false;
1700   }
1701 
1702   // validate number of arguments
1703   if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
1704     return false;
1705 
1706   Attr *A = nullptr;
1707 
1708   AttributeCommonInfo ACI(
1709       AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
1710       AttributeCommonInfo::NoSemaHandlerAttribute,
1711       {
1712           AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
1713           false /*IsRegularKeywordAttribute*/
1714       });
1715 
1716   switch (AL.getKind()) {
1717   case ParsedAttr::AT_HLSLResourceClass: {
1718     if (!AL.isArgIdent(0)) {
1719       Diag(AL.getLoc(), diag::err_attribute_argument_type)
1720           << AL << AANT_ArgumentIdentifier;
1721       return false;
1722     }
1723 
1724     IdentifierLoc *Loc = AL.getArgAsIdent(0);
1725     StringRef Identifier = Loc->getIdentifierInfo()->getName();
1726     SourceLocation ArgLoc = Loc->getLoc();
1727 
1728     // Validate resource class value
1729     ResourceClass RC;
1730     if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
1731       Diag(ArgLoc, diag::warn_attribute_type_not_supported)
1732           << "ResourceClass" << Identifier;
1733       return false;
1734     }
1735     A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
1736     break;
1737   }
1738 
1739   case ParsedAttr::AT_HLSLROV:
1740     A = HLSLROVAttr::Create(getASTContext(), ACI);
1741     break;
1742 
1743   case ParsedAttr::AT_HLSLRawBuffer:
1744     A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
1745     break;
1746 
1747   case ParsedAttr::AT_HLSLContainedType: {
1748     if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
1749       Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1750       return false;
1751     }
1752 
1753     TypeSourceInfo *TSI = nullptr;
1754     QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
1755     assert(TSI && "no type source info for attribute argument");
1756     if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
1757                                     diag::err_incomplete_type))
1758       return false;
1759     A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
1760     break;
1761   }
1762 
1763   default:
1764     llvm_unreachable("unhandled HLSL attribute");
1765   }
1766 
1767   HLSLResourcesTypeAttrs.emplace_back(A);
1768   return true;
1769 }
1770 
1771 // Combines all resource type attributes and creates HLSLAttributedResourceType.
ProcessResourceTypeAttributes(QualType CurrentType)1772 QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
1773   if (!HLSLResourcesTypeAttrs.size())
1774     return CurrentType;
1775 
1776   QualType QT = CurrentType;
1777   HLSLAttributedResourceLocInfo LocInfo;
1778   if (CreateHLSLAttributedResourceType(SemaRef, CurrentType,
1779                                        HLSLResourcesTypeAttrs, QT, &LocInfo)) {
1780     const HLSLAttributedResourceType *RT =
1781         cast<HLSLAttributedResourceType>(QT.getTypePtr());
1782 
1783     // Temporarily store TypeLoc information for the new type.
1784     // It will be transferred to HLSLAttributesResourceTypeLoc
1785     // shortly after the type is created by TypeSpecLocFiller which
1786     // will call the TakeLocForHLSLAttribute method below.
1787     LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
1788   }
1789   HLSLResourcesTypeAttrs.clear();
1790   return QT;
1791 }
1792 
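// Returns the source location info recorded for the given
// HLSLAttributedResourceType and removes it from the pending map; returns info
// with an empty range if nothing was recorded.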
1794 HLSLAttributedResourceLocInfo
TakeLocForHLSLAttribute(const HLSLAttributedResourceType * RT)1795 SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
1796   HLSLAttributedResourceLocInfo LocInfo = {};
1797   auto I = LocsForHLSLAttributedResources.find(RT);
1798   if (I != LocsForHLSLAttributedResources.end()) {
1799     LocInfo = I->second;
1800     LocsForHLSLAttributedResources.erase(I);
1801     return LocInfo;
1802   }
1803   LocInfo.Range = SourceRange();
1804   return LocInfo;
1805 }
1806 
// Walks through the fields of the record type RT used by the global variable
// declaration VD (recursing into nested records), collects all resource
// binding requirements, and adds them to Bindings.
collectResourceBindingsOnUserRecordDecl(const VarDecl * VD,const RecordType * RT)1809 void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
1810                                                        const RecordType *RT) {
1811   const RecordDecl *RD = RT->getDecl();
1812   for (FieldDecl *FD : RD->fields()) {
1813     const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
1814 
1815     // Unwrap arrays
1816     // FIXME: Calculate array size while unwrapping
1817     assert(!Ty->isIncompleteArrayType() &&
1818            "incomplete arrays inside user defined types are not supported");
1819     while (Ty->isConstantArrayType()) {
1820       const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
1821       Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
1822     }
1823 
1824     if (!Ty->isRecordType())
1825       continue;
1826 
1827     if (const HLSLAttributedResourceType *AttrResType =
1828             HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
1829       // Add a new DeclBindingInfo to Bindings if it does not already exist
1830       ResourceClass RC = AttrResType->getAttrs().ResourceClass;
1831       DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
1832       if (!DBI)
1833         Bindings.addDeclBindingInfo(VD, RC);
1834     } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
1835       // Recursively scan embedded struct or class; it would be nice to do this
1836       // without recursion, but tricky to correctly calculate the size of the
1837       // binding, which is something we are probably going to need to do later
1838       // on. Hopefully nesting of structs in structs too many levels is
1839       // unlikely.
1840       collectResourceBindingsOnUserRecordDecl(VD, RT);
1841     }
1842   }
1843 }
1844 
1845 // Diagnose localized register binding errors for a single binding; does not
1846 // diagnose resource binding on user record types, that will be done later
1847 // in processResourceBindingOnDecl based on the information collected in
1848 // collectResourceBindingsOnVarDecl.
1849 // Returns false if the register binding is not valid.
DiagnoseLocalRegisterBinding(Sema & S,SourceLocation & ArgLoc,Decl * D,RegisterType RegType,bool SpecifiedSpace)1850 static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
1851                                          Decl *D, RegisterType RegType,
1852                                          bool SpecifiedSpace) {
1853   int RegTypeNum = static_cast<int>(RegType);
1854 
1855   // check if the decl type is groupshared
1856   if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
1857     S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1858     return false;
1859   }
1860 
1861   // Cbuffers and Tbuffers are HLSLBufferDecl types
1862   if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
1863     ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
1864                                                      : ResourceClass::SRV;
1865     if (RegType == getRegisterType(RC))
1866       return true;
1867 
1868     S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
1869         << RegTypeNum;
1870     return false;
1871   }
1872 
1873   // Samplers, UAVs, and SRVs are VarDecl types
1874   assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
1875   VarDecl *VD = cast<VarDecl>(D);
1876 
1877   // Resource
1878   if (const HLSLAttributedResourceType *AttrResType =
1879           HLSLAttributedResourceType::findHandleTypeOnResource(
1880               VD->getType().getTypePtr())) {
1881     if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
1882       return true;
1883 
1884     S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
1885         << RegTypeNum;
1886     return false;
1887   }
1888 
1889   const clang::Type *Ty = VD->getType().getTypePtr();
1890   while (Ty->isArrayType())
1891     Ty = Ty->getArrayElementTypeNoTypeQual();
1892 
1893   // Basic types
1894   if (Ty->isArithmeticType() || Ty->isVectorType()) {
1895     bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
1896     if (SpecifiedSpace && !DeclaredInCOrTBuffer)
1897       S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
1898 
1899     if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
1900                                   Ty->isFloatingType() || Ty->isVectorType())) {
1901       // Register annotation on default constant buffer declaration ($Globals)
1902       if (RegType == RegisterType::CBuffer)
1903         S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
1904       else if (RegType != RegisterType::C)
1905         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1906       else
1907         return true;
1908     } else {
1909       if (RegType == RegisterType::C)
1910         S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
1911       else
1912         S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1913     }
1914     return false;
1915   }
1916   if (Ty->isRecordType())
1917     // RecordTypes will be diagnosed in processResourceBindingOnDecl
1918     // that is called from ActOnVariableDeclarator
1919     return true;
1920 
1921   // Anything else is an error
1922   S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1923   return false;
1924 }
1925 
ValidateMultipleRegisterAnnotations(Sema & S,Decl * TheDecl,RegisterType regType)1926 static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
1927                                                 RegisterType regType) {
1928   // make sure that there are no two register annotations
1929   // applied to the decl with the same register type
1930   bool RegisterTypesDetected[5] = {false};
1931   RegisterTypesDetected[static_cast<int>(regType)] = true;
1932 
1933   for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
1934     if (HLSLResourceBindingAttr *attr =
1935             dyn_cast<HLSLResourceBindingAttr>(*it)) {
1936 
1937       RegisterType otherRegType = attr->getRegisterType();
1938       if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
1939         int otherRegTypeNum = static_cast<int>(otherRegType);
1940         S.Diag(TheDecl->getLocation(),
1941                diag::err_hlsl_duplicate_register_annotation)
1942             << otherRegTypeNum;
1943         return false;
1944       }
1945       RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
1946     }
1947   }
1948   return true;
1949 }
1950 
DiagnoseHLSLRegisterAttribute(Sema & S,SourceLocation & ArgLoc,Decl * D,RegisterType RegType,bool SpecifiedSpace)1951 static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
1952                                           Decl *D, RegisterType RegType,
1953                                           bool SpecifiedSpace) {
1954 
1955   // exactly one of these two types should be set
1956   assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
1957           (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
1958          "expecting VarDecl or HLSLBufferDecl");
1959 
1960   // check if the declaration contains resource matching the register type
1961   if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
1962     return false;
1963 
1964   // next, if multiple register annotations exist, check that none conflict.
1965   return ValidateMultipleRegisterAnnotations(S, D, RegType);
1966 }
1967 
handleResourceBindingAttr(Decl * TheDecl,const ParsedAttr & AL)1968 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
1969   if (isa<VarDecl>(TheDecl)) {
1970     if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
1971                                     cast<ValueDecl>(TheDecl)->getType(),
1972                                     diag::err_incomplete_type))
1973       return;
1974   }
1975 
1976   StringRef Slot = "";
1977   StringRef Space = "";
1978   SourceLocation SlotLoc, SpaceLoc;
1979 
1980   if (!AL.isArgIdent(0)) {
1981     Diag(AL.getLoc(), diag::err_attribute_argument_type)
1982         << AL << AANT_ArgumentIdentifier;
1983     return;
1984   }
1985   IdentifierLoc *Loc = AL.getArgAsIdent(0);
1986 
1987   if (AL.getNumArgs() == 2) {
1988     Slot = Loc->getIdentifierInfo()->getName();
1989     SlotLoc = Loc->getLoc();
1990     if (!AL.isArgIdent(1)) {
1991       Diag(AL.getLoc(), diag::err_attribute_argument_type)
1992           << AL << AANT_ArgumentIdentifier;
1993       return;
1994     }
1995     Loc = AL.getArgAsIdent(1);
1996     Space = Loc->getIdentifierInfo()->getName();
1997     SpaceLoc = Loc->getLoc();
1998   } else {
1999     StringRef Str = Loc->getIdentifierInfo()->getName();
2000     if (Str.starts_with("space")) {
2001       Space = Str;
2002       SpaceLoc = Loc->getLoc();
2003     } else {
2004       Slot = Str;
2005       SlotLoc = Loc->getLoc();
2006       Space = "space0";
2007     }
2008   }
2009 
2010   RegisterType RegType = RegisterType::SRV;
2011   std::optional<unsigned> SlotNum;
2012   unsigned SpaceNum = 0;
2013 
2014   // Validate slot
2015   if (!Slot.empty()) {
2016     if (!convertToRegisterType(Slot, &RegType)) {
2017       Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
2018       return;
2019     }
2020     if (RegType == RegisterType::I) {
2021       Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
2022       return;
2023     }
2024     StringRef SlotNumStr = Slot.substr(1);
2025     unsigned N;
2026     if (SlotNumStr.getAsInteger(10, N)) {
2027       Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
2028       return;
2029     }
2030     SlotNum = N;
2031   }
2032 
2033   // Validate space
2034   if (!Space.starts_with("space")) {
2035     Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2036     return;
2037   }
2038   StringRef SpaceNumStr = Space.substr(5);
2039   if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
2040     Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
2041     return;
2042   }
2043 
2044   // If we have slot, diagnose it is the right register type for the decl
2045   if (SlotNum.has_value())
2046     if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
2047                                        !SpaceLoc.isInvalid()))
2048       return;
2049 
2050   HLSLResourceBindingAttr *NewAttr =
2051       HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
2052   if (NewAttr) {
2053     NewAttr->setBinding(RegType, SlotNum, SpaceNum);
2054     TheDecl->addAttr(NewAttr);
2055   }
2056 }
2057 
handleParamModifierAttr(Decl * D,const ParsedAttr & AL)2058 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2059   HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2060       D, AL,
2061       static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2062   if (NewAttr)
2063     D->addAttr(NewAttr);
2064 }
2065 
2066 namespace {
2067 
2068 /// This class implements HLSL availability diagnostics for default
2069 /// and relaxed mode
2070 ///
2071 /// The goal of this diagnostic is to emit an error or warning when an
2072 /// unavailable API is found in code that is reachable from the shader
2073 /// entry function or from an exported function (when compiling a shader
2074 /// library).
2075 ///
2076 /// This is done by traversing the AST of all shader entry point functions
2077 /// and of all exported functions, and any functions that are referenced
2078 /// from this AST. In other words, any functions that are reachable from
2079 /// the entry points.
2080 class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2081   Sema &SemaRef;
2082 
2083   // Stack of functions to be scaned
2084   llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;
2085 
2086   // Tracks which environments functions have been scanned in.
2087   //
2088   // Maps FunctionDecl to an unsigned number that represents the set of shader
2089   // environments the function has been scanned for.
2090   // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2091   // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2092   // (verified by static_asserts in Triple.cpp), we can use it to index
2093   // individual bits in the set, as long as we shift the values to start with 0
2094   // by subtracting the value of llvm::Triple::Pixel first.
2095   //
2096   // The N'th bit in the set will be set if the function has been scanned
2097   // in shader environment whose llvm::Triple::EnvironmentType integer value
2098   // equals (llvm::Triple::Pixel + N).
2099   //
2100   // For example, if a function has been scanned in compute and pixel stage
2101   // environment, the value will be 0x21 (100001 binary) because:
2102   //
2103   //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2104   //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2105   //
2106   // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2107   // been scanned in any environment.
2108   llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2109 
2110   // Do not access these directly, use the get/set methods below to make
2111   // sure the values are in sync
2112   llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2113   unsigned CurrentShaderStageBit;
2114 
2115   // True if scanning a function that was already scanned in a different
2116   // shader stage context, and therefore we should not report issues that
2117   // depend only on shader model version because they would be duplicate.
2118   bool ReportOnlyShaderStageIssues;
2119 
2120   // Helper methods for dealing with current stage context / environment
SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType)2121   void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2122     static_assert(sizeof(unsigned) >= 4);
2123     assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2124     assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2125            "ShaderType is too big for this bitmap"); // 31 is reserved for
2126                                                      // "unknown"
2127 
2128     unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2129     CurrentShaderEnvironment = ShaderType;
2130     CurrentShaderStageBit = (1 << bitmapIndex);
2131   }
2132 
SetUnknownShaderStageContext()2133   void SetUnknownShaderStageContext() {
2134     CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2135     CurrentShaderStageBit = (1 << 31);
2136   }
2137 
GetCurrentShaderEnvironment() const2138   llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2139     return CurrentShaderEnvironment;
2140   }
2141 
InUnknownShaderStageContext() const2142   bool InUnknownShaderStageContext() const {
2143     return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2144   }
2145 
2146   // Helper methods for dealing with shader stage bitmap
AddToScannedFunctions(const FunctionDecl * FD)2147   void AddToScannedFunctions(const FunctionDecl *FD) {
2148     unsigned &ScannedStages = ScannedDecls[FD];
2149     ScannedStages |= CurrentShaderStageBit;
2150   }
2151 
GetScannedStages(const FunctionDecl * FD)2152   unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2153 
WasAlreadyScannedInCurrentStage(const FunctionDecl * FD)2154   bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2155     return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2156   }
2157 
WasAlreadyScannedInCurrentStage(unsigned ScannerStages)2158   bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2159     return ScannerStages & CurrentShaderStageBit;
2160   }
2161 
NeverBeenScanned(unsigned ScannedStages)2162   static bool NeverBeenScanned(unsigned ScannedStages) {
2163     return ScannedStages == 0;
2164   }
2165 
2166   // Scanning methods
2167   void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2168   void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2169                              SourceRange Range);
2170   const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2171   bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2172 
2173 public:
DiagnoseHLSLAvailability(Sema & SemaRef)2174   DiagnoseHLSLAvailability(Sema &SemaRef)
2175       : SemaRef(SemaRef),
2176         CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2177         CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2178 
2179   // AST traversal methods
2180   void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2181   void RunOnFunction(const FunctionDecl *FD);
2182 
VisitDeclRefExpr(DeclRefExpr * DRE)2183   bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2184     FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2185     if (FD)
2186       HandleFunctionOrMethodRef(FD, DRE);
2187     return true;
2188   }
2189 
VisitMemberExpr(MemberExpr * ME)2190   bool VisitMemberExpr(MemberExpr *ME) override {
2191     FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2192     if (FD)
2193       HandleFunctionOrMethodRef(FD, ME);
2194     return true;
2195   }
2196 };
2197 
HandleFunctionOrMethodRef(FunctionDecl * FD,Expr * RefExpr)2198 void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2199                                                          Expr *RefExpr) {
2200   assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2201          "expected DeclRefExpr or MemberExpr");
2202 
2203   // has a definition -> add to stack to be scanned
2204   const FunctionDecl *FDWithBody = nullptr;
2205   if (FD->hasBody(FDWithBody)) {
2206     if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2207       DeclsToScan.push_back(FDWithBody);
2208     return;
2209   }
2210 
2211   // no body -> diagnose availability
2212   const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2213   if (AA)
2214     CheckDeclAvailability(
2215         FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2216 }
2217 
RunOnTranslationUnit(const TranslationUnitDecl * TU)2218 void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2219     const TranslationUnitDecl *TU) {
2220 
2221   // Iterate over all shader entry functions and library exports, and for those
2222   // that have a body (definiton), run diag scan on each, setting appropriate
2223   // shader environment context based on whether it is a shader entry function
2224   // or an exported function. Exported functions can be in namespaces and in
2225   // export declarations so we need to scan those declaration contexts as well.
2226   llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
2227   DeclContextsToScan.push_back(TU);
2228 
2229   while (!DeclContextsToScan.empty()) {
2230     const DeclContext *DC = DeclContextsToScan.pop_back_val();
2231     for (auto &D : DC->decls()) {
2232       // do not scan implicit declaration generated by the implementation
2233       if (D->isImplicit())
2234         continue;
2235 
2236       // for namespace or export declaration add the context to the list to be
2237       // scanned later
2238       if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2239         DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2240         continue;
2241       }
2242 
2243       // skip over other decls or function decls without body
2244       const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2245       if (!FD || !FD->isThisDeclarationADefinition())
2246         continue;
2247 
2248       // shader entry point
2249       if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2250         SetShaderStageContext(ShaderAttr->getType());
2251         RunOnFunction(FD);
2252         continue;
2253       }
2254       // exported library function
2255       // FIXME: replace this loop with external linkage check once issue #92071
2256       // is resolved
2257       bool isExport = FD->isInExportDeclContext();
2258       if (!isExport) {
2259         for (const auto *Redecl : FD->redecls()) {
2260           if (Redecl->isInExportDeclContext()) {
2261             isExport = true;
2262             break;
2263           }
2264         }
2265       }
2266       if (isExport) {
2267         SetUnknownShaderStageContext();
2268         RunOnFunction(FD);
2269         continue;
2270       }
2271     }
2272   }
2273 }
2274 
RunOnFunction(const FunctionDecl * FD)2275 void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2276   assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2277   DeclsToScan.push_back(FD);
2278 
2279   while (!DeclsToScan.empty()) {
2280     // Take one decl from the stack and check it by traversing its AST.
2281     // For any CallExpr found during the traversal add it's callee to the top of
2282     // the stack to be processed next. Functions already processed are stored in
2283     // ScannedDecls.
2284     const FunctionDecl *FD = DeclsToScan.pop_back_val();
2285 
2286     // Decl was already scanned
2287     const unsigned ScannedStages = GetScannedStages(FD);
2288     if (WasAlreadyScannedInCurrentStage(ScannedStages))
2289       continue;
2290 
2291     ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2292 
2293     AddToScannedFunctions(FD);
2294     TraverseStmt(FD->getBody());
2295   }
2296 }
2297 
HasMatchingEnvironmentOrNone(const AvailabilityAttr * AA)2298 bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2299     const AvailabilityAttr *AA) {
2300   IdentifierInfo *IIEnvironment = AA->getEnvironment();
2301   if (!IIEnvironment)
2302     return true;
2303 
2304   llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2305   if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2306     return false;
2307 
2308   llvm::Triple::EnvironmentType AttrEnv =
2309       AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
2310 
2311   return CurrentEnv == AttrEnv;
2312 }
2313 
2314 const AvailabilityAttr *
FindAvailabilityAttr(const Decl * D)2315 DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2316   AvailabilityAttr const *PartialMatch = nullptr;
2317   // Check each AvailabilityAttr to find the one for this platform.
2318   // For multiple attributes with the same platform try to find one for this
2319   // environment.
2320   for (const auto *A : D->attrs()) {
2321     if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
2322       StringRef AttrPlatform = Avail->getPlatform()->getName();
2323       StringRef TargetPlatform =
2324           SemaRef.getASTContext().getTargetInfo().getPlatformName();
2325 
2326       // Match the platform name.
2327       if (AttrPlatform == TargetPlatform) {
2328         // Find the best matching attribute for this environment
2329         if (HasMatchingEnvironmentOrNone(Avail))
2330           return Avail;
2331         PartialMatch = Avail;
2332       }
2333     }
2334   }
2335   return PartialMatch;
2336 }
2337 
2338 // Check availability against target shader model version and current shader
2339 // stage and emit diagnostic
CheckDeclAvailability(NamedDecl * D,const AvailabilityAttr * AA,SourceRange Range)2340 void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
2341                                                      const AvailabilityAttr *AA,
2342                                                      SourceRange Range) {
2343 
2344   IdentifierInfo *IIEnv = AA->getEnvironment();
2345 
2346   if (!IIEnv) {
2347     // The availability attribute does not have environment -> it depends only
2348     // on shader model version and not on specific the shader stage.
2349 
2350     // Skip emitting the diagnostics if the diagnostic mode is set to
2351     // strict (-fhlsl-strict-availability) because all relevant diagnostics
2352     // were already emitted in the DiagnoseUnguardedAvailability scan
2353     // (SemaAvailability.cpp).
2354     if (SemaRef.getLangOpts().HLSLStrictAvailability)
2355       return;
2356 
2357     // Do not report shader-stage-independent issues if scanning a function
2358     // that was already scanned in a different shader stage context (they would
2359     // be duplicate)
2360     if (ReportOnlyShaderStageIssues)
2361       return;
2362 
2363   } else {
2364     // The availability attribute has environment -> we need to know
2365     // the current stage context to property diagnose it.
2366     if (InUnknownShaderStageContext())
2367       return;
2368   }
2369 
2370   // Check introduced version and if environment matches
2371   bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
2372   VersionTuple Introduced = AA->getIntroduced();
2373   VersionTuple TargetVersion =
2374       SemaRef.Context.getTargetInfo().getPlatformMinVersion();
2375 
2376   if (TargetVersion >= Introduced && EnvironmentMatches)
2377     return;
2378 
2379   // Emit diagnostic message
2380   const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2381   llvm::StringRef PlatformName(
2382       AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
2383 
2384   llvm::StringRef CurrentEnvStr =
2385       llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
2386 
2387   llvm::StringRef AttrEnvStr =
2388       AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
2389   bool UseEnvironment = !AttrEnvStr.empty();
2390 
2391   if (EnvironmentMatches) {
2392     SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
2393         << Range << D << PlatformName << Introduced.getAsString()
2394         << UseEnvironment << CurrentEnvStr;
2395   } else {
2396     SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
2397         << Range << D;
2398   }
2399 
2400   SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
2401       << D << PlatformName << Introduced.getAsString()
2402       << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
2403       << UseEnvironment << AttrEnvStr << CurrentEnvStr;
2404 }
2405 
2406 } // namespace
2407 
ActOnEndOfTranslationUnit(TranslationUnitDecl * TU)2408 void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
2409   // process default CBuffer - create buffer layout struct and invoke codegenCGH
2410   if (!DefaultCBufferDecls.empty()) {
2411     HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
2412         SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
2413         DefaultCBufferDecls);
2414     addImplicitBindingAttrToBuffer(SemaRef, DefaultCBuffer,
2415                                    getNextImplicitBindingOrderID());
2416     SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
2417     createHostLayoutStructForBuffer(SemaRef, DefaultCBuffer);
2418 
2419     // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
2420     for (const Decl *VD : DefaultCBufferDecls) {
2421       const HLSLResourceBindingAttr *RBA =
2422           VD->getAttr<HLSLResourceBindingAttr>();
2423       if (RBA && RBA->hasRegisterSlot() &&
2424           RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
2425         DefaultCBuffer->setHasValidPackoffset(true);
2426         break;
2427       }
2428     }
2429 
2430     DeclGroupRef DG(DefaultCBuffer);
2431     SemaRef.Consumer.HandleTopLevelDecl(DG);
2432   }
2433   diagnoseAvailabilityViolations(TU);
2434 }
2435 
diagnoseAvailabilityViolations(TranslationUnitDecl * TU)2436 void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2437   // Skip running the diagnostics scan if the diagnostic mode is
2438   // strict (-fhlsl-strict-availability) and the target shader stage is known
2439   // because all relevant diagnostics were already emitted in the
2440   // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
2441   const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2442   if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2443       TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2444     return;
2445 
2446   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2447 }
2448 
CheckAllArgsHaveSameType(Sema * S,CallExpr * TheCall)2449 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2450   assert(TheCall->getNumArgs() > 1);
2451   QualType ArgTy0 = TheCall->getArg(0)->getType();
2452 
2453   for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2454     if (!S->getASTContext().hasSameUnqualifiedType(
2455             ArgTy0, TheCall->getArg(I)->getType())) {
2456       S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
2457           << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2458           << SourceRange(TheCall->getArg(0)->getBeginLoc(),
2459                          TheCall->getArg(N - 1)->getEndLoc());
2460       return true;
2461     }
2462   }
2463   return false;
2464 }
2465 
CheckArgTypeMatches(Sema * S,Expr * Arg,QualType ExpectedType)2466 static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
2467   QualType ArgType = Arg->getType();
2468   if (!S->getASTContext().hasSameUnqualifiedType(ArgType, ExpectedType)) {
2469     S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
2470         << ArgType << ExpectedType << 1 << 0 << 0;
2471     return true;
2472   }
2473   return false;
2474 }
2475 
CheckAllArgTypesAreCorrect(Sema * S,CallExpr * TheCall,llvm::function_ref<bool (Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)> Check)2476 static bool CheckAllArgTypesAreCorrect(
2477     Sema *S, CallExpr *TheCall,
2478     llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2479                             clang::QualType PassedType)>
2480         Check) {
2481   for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2482     Expr *Arg = TheCall->getArg(I);
2483     if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2484       return true;
2485   }
2486   return false;
2487 }
2488 
CheckFloatOrHalfRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2489 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2490                                            int ArgOrdinal,
2491                                            clang::QualType PassedType) {
2492   clang::QualType BaseType =
2493       PassedType->isVectorType()
2494           ? PassedType->castAs<clang::VectorType>()->getElementType()
2495           : PassedType;
2496   if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
2497     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2498            << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
2499            << /* half or float */ 2 << PassedType;
2500   return false;
2501 }
2502 
CheckModifiableLValue(Sema * S,CallExpr * TheCall,unsigned ArgIndex)2503 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
2504                                   unsigned ArgIndex) {
2505   auto *Arg = TheCall->getArg(ArgIndex);
2506   SourceLocation OrigLoc = Arg->getExprLoc();
2507   if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
2508       Expr::MLV_Valid)
2509     return false;
2510   S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
2511   return true;
2512 }
2513 
CheckNoDoubleVectors(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2514 static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
2515                                  clang::QualType PassedType) {
2516   const auto *VecTy = PassedType->getAs<VectorType>();
2517   if (!VecTy)
2518     return false;
2519 
2520   if (VecTy->getElementType()->isDoubleType())
2521     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2522            << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
2523            << PassedType;
2524   return false;
2525 }
2526 
CheckFloatingOrIntRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2527 static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
2528                                              int ArgOrdinal,
2529                                              clang::QualType PassedType) {
2530   if (!PassedType->hasIntegerRepresentation() &&
2531       !PassedType->hasFloatingRepresentation())
2532     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2533            << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
2534            << /* fp */ 1 << PassedType;
2535   return false;
2536 }
2537 
CheckUnsignedIntVecRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2538 static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
2539                                               int ArgOrdinal,
2540                                               clang::QualType PassedType) {
2541   if (auto *VecTy = PassedType->getAs<VectorType>())
2542     if (VecTy->getElementType()->isUnsignedIntegerType())
2543       return false;
2544 
2545   return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2546          << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
2547          << PassedType;
2548 }
2549 
2550 // checks for unsigned ints of all sizes
CheckUnsignedIntRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2551 static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
2552                                            int ArgOrdinal,
2553                                            clang::QualType PassedType) {
2554   if (!PassedType->hasUnsignedIntegerRepresentation())
2555     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2556            << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
2557            << /* no fp */ 0 << PassedType;
2558   return false;
2559 }
2560 
SetElementTypeAsReturnType(Sema * S,CallExpr * TheCall,QualType ReturnType)2561 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
2562                                        QualType ReturnType) {
2563   auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
2564   if (VecTyA)
2565     ReturnType =
2566         S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
2567 
2568   TheCall->setType(ReturnType);
2569 }
2570 
CheckScalarOrVector(Sema * S,CallExpr * TheCall,QualType Scalar,unsigned ArgIndex)2571 static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
2572                                 unsigned ArgIndex) {
2573   assert(TheCall->getNumArgs() >= ArgIndex);
2574   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2575   auto *VTy = ArgType->getAs<VectorType>();
2576   // not the scalar or vector<scalar>
2577   if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
2578         (VTy &&
2579          S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
2580     S->Diag(TheCall->getArg(0)->getBeginLoc(),
2581             diag::err_typecheck_expect_scalar_or_vector)
2582         << ArgType << Scalar;
2583     return true;
2584   }
2585   return false;
2586 }
2587 
CheckAnyScalarOrVector(Sema * S,CallExpr * TheCall,unsigned ArgIndex)2588 static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
2589                                    unsigned ArgIndex) {
2590   assert(TheCall->getNumArgs() >= ArgIndex);
2591   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2592   auto *VTy = ArgType->getAs<VectorType>();
2593   // not the scalar or vector<scalar>
2594   if (!(ArgType->isScalarType() ||
2595         (VTy && VTy->getElementType()->isScalarType()))) {
2596     S->Diag(TheCall->getArg(0)->getBeginLoc(),
2597             diag::err_typecheck_expect_any_scalar_or_vector)
2598         << ArgType << 1;
2599     return true;
2600   }
2601   return false;
2602 }
2603 
CheckWaveActive(Sema * S,CallExpr * TheCall)2604 static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
2605   QualType BoolType = S->getASTContext().BoolTy;
2606   assert(TheCall->getNumArgs() >= 1);
2607   QualType ArgType = TheCall->getArg(0)->getType();
2608   auto *VTy = ArgType->getAs<VectorType>();
2609   // is the bool or vector<bool>
2610   if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
2611       (VTy &&
2612        S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
2613     S->Diag(TheCall->getArg(0)->getBeginLoc(),
2614             diag::err_typecheck_expect_any_scalar_or_vector)
2615         << ArgType << 0;
2616     return true;
2617   }
2618   return false;
2619 }
2620 
CheckBoolSelect(Sema * S,CallExpr * TheCall)2621 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
2622   assert(TheCall->getNumArgs() == 3);
2623   Expr *Arg1 = TheCall->getArg(1);
2624   Expr *Arg2 = TheCall->getArg(2);
2625   if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2626     S->Diag(TheCall->getBeginLoc(),
2627             diag::err_typecheck_call_different_arg_types)
2628         << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2629         << Arg2->getSourceRange();
2630     return true;
2631   }
2632 
2633   TheCall->setType(Arg1->getType());
2634   return false;
2635 }
2636 
CheckVectorSelect(Sema * S,CallExpr * TheCall)2637 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
2638   assert(TheCall->getNumArgs() == 3);
2639   Expr *Arg1 = TheCall->getArg(1);
2640   QualType Arg1Ty = Arg1->getType();
2641   Expr *Arg2 = TheCall->getArg(2);
2642   QualType Arg2Ty = Arg2->getType();
2643 
2644   QualType Arg1ScalarTy = Arg1Ty;
2645   if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2646     Arg1ScalarTy = VTy->getElementType();
2647 
2648   QualType Arg2ScalarTy = Arg2Ty;
2649   if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2650     Arg2ScalarTy = VTy->getElementType();
2651 
2652   if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
2653     S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
2654         << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2655 
2656   QualType Arg0Ty = TheCall->getArg(0)->getType();
2657   unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2658   unsigned Arg1Length = Arg1Ty->isVectorType()
2659                             ? Arg1Ty->getAs<VectorType>()->getNumElements()
2660                             : 0;
2661   unsigned Arg2Length = Arg2Ty->isVectorType()
2662                             ? Arg2Ty->getAs<VectorType>()->getNumElements()
2663                             : 0;
2664   if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2665     S->Diag(TheCall->getBeginLoc(),
2666             diag::err_typecheck_vector_lengths_not_equal)
2667         << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
2668         << Arg1->getSourceRange();
2669     return true;
2670   }
2671 
2672   if (Arg2Length > 0 && Arg0Length != Arg2Length) {
2673     S->Diag(TheCall->getBeginLoc(),
2674             diag::err_typecheck_vector_lengths_not_equal)
2675         << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
2676         << Arg2->getSourceRange();
2677     return true;
2678   }
2679 
2680   TheCall->setType(
2681       S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
2682   return false;
2683 }
2684 
CheckResourceHandle(Sema * S,CallExpr * TheCall,unsigned ArgIndex,llvm::function_ref<bool (const HLSLAttributedResourceType * ResType)> Check=nullptr)2685 static bool CheckResourceHandle(
2686     Sema *S, CallExpr *TheCall, unsigned ArgIndex,
2687     llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
2688         nullptr) {
2689   assert(TheCall->getNumArgs() >= ArgIndex);
2690   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2691   const HLSLAttributedResourceType *ResTy =
2692       ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
2693   if (!ResTy) {
2694     S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
2695             diag::err_typecheck_expect_hlsl_resource)
2696         << ArgType;
2697     return true;
2698   }
2699   if (Check && Check(ResTy)) {
2700     S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
2701             diag::err_invalid_hlsl_resource_type)
2702         << ArgType;
2703     return true;
2704   }
2705   return false;
2706 }
2707 
2708 // Note: returning true in this case results in CheckBuiltinFunctionCall
2709 // returning an ExprError
CheckBuiltinFunctionCall(unsigned BuiltinID,CallExpr * TheCall)2710 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
2711   switch (BuiltinID) {
2712   case Builtin::BI__builtin_hlsl_adduint64: {
2713     if (SemaRef.checkArgCount(TheCall, 2))
2714       return true;
2715 
2716     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2717                                    CheckUnsignedIntVecRepresentation))
2718       return true;
2719 
2720     auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
2721     // ensure arg integers are 32-bits
2722     uint64_t ElementBitCount = getASTContext()
2723                                    .getTypeSizeInChars(VTy->getElementType())
2724                                    .getQuantity() *
2725                                8;
2726     if (ElementBitCount != 32) {
2727       SemaRef.Diag(TheCall->getBeginLoc(),
2728                    diag::err_integer_incorrect_bit_count)
2729           << 32 << ElementBitCount;
2730       return true;
2731     }
2732 
2733     // ensure both args are vectors of total bit size of a multiple of 64
2734     int NumElementsArg = VTy->getNumElements();
2735     if (NumElementsArg != 2 && NumElementsArg != 4) {
2736       SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
2737           << 1 /*a multiple of*/ << 64 << NumElementsArg * ElementBitCount;
2738       return true;
2739     }
2740 
2741     // ensure first arg and second arg have the same type
2742     if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
2743       return true;
2744 
2745     ExprResult A = TheCall->getArg(0);
2746     QualType ArgTyA = A.get()->getType();
2747     // return type is the same as the input type
2748     TheCall->setType(ArgTyA);
2749     break;
2750   }
2751   case Builtin::BI__builtin_hlsl_resource_getpointer: {
2752     if (SemaRef.checkArgCount(TheCall, 2) ||
2753         CheckResourceHandle(&SemaRef, TheCall, 0) ||
2754         CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
2755                             SemaRef.getASTContext().UnsignedIntTy))
2756       return true;
2757 
2758     auto *ResourceTy =
2759         TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
2760     QualType ContainedTy = ResourceTy->getContainedType();
2761     auto ReturnType =
2762         SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device);
2763     ReturnType = SemaRef.Context.getPointerType(ReturnType);
2764     TheCall->setType(ReturnType);
2765     TheCall->setValueKind(VK_LValue);
2766 
2767     break;
2768   }
2769   case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
2770     if (SemaRef.checkArgCount(TheCall, 1) ||
2771         CheckResourceHandle(&SemaRef, TheCall, 0))
2772       return true;
2773     // use the type of the handle (arg0) as a return type
2774     QualType ResourceTy = TheCall->getArg(0)->getType();
2775     TheCall->setType(ResourceTy);
2776     break;
2777   }
2778   case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
2779     ASTContext &AST = SemaRef.getASTContext();
2780     if (SemaRef.checkArgCount(TheCall, 6) ||
2781         CheckResourceHandle(&SemaRef, TheCall, 0) ||
2782         CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) ||
2783         CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.UnsignedIntTy) ||
2784         CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.IntTy) ||
2785         CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) ||
2786         CheckArgTypeMatches(&SemaRef, TheCall->getArg(5),
2787                             AST.getPointerType(AST.CharTy.withConst())))
2788       return true;
2789     // use the type of the handle (arg0) as a return type
2790     QualType ResourceTy = TheCall->getArg(0)->getType();
2791     TheCall->setType(ResourceTy);
2792     break;
2793   }
2794   case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
2795     ASTContext &AST = SemaRef.getASTContext();
2796     if (SemaRef.checkArgCount(TheCall, 6) ||
2797         CheckResourceHandle(&SemaRef, TheCall, 0) ||
2798         CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) ||
2799         CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.IntTy) ||
2800         CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.UnsignedIntTy) ||
2801         CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) ||
2802         CheckArgTypeMatches(&SemaRef, TheCall->getArg(5),
2803                             AST.getPointerType(AST.CharTy.withConst())))
2804       return true;
2805     // use the type of the handle (arg0) as a return type
2806     QualType ResourceTy = TheCall->getArg(0)->getType();
2807     TheCall->setType(ResourceTy);
2808     break;
2809   }
2810   case Builtin::BI__builtin_hlsl_and:
2811   case Builtin::BI__builtin_hlsl_or: {
2812     if (SemaRef.checkArgCount(TheCall, 2))
2813       return true;
2814     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
2815       return true;
2816     if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
2817       return true;
2818 
2819     ExprResult A = TheCall->getArg(0);
2820     QualType ArgTyA = A.get()->getType();
2821     // return type is the same as the input type
2822     TheCall->setType(ArgTyA);
2823     break;
2824   }
2825   case Builtin::BI__builtin_hlsl_all:
2826   case Builtin::BI__builtin_hlsl_any: {
2827     if (SemaRef.checkArgCount(TheCall, 1))
2828       return true;
2829     if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
2830       return true;
2831     break;
2832   }
2833   case Builtin::BI__builtin_hlsl_asdouble: {
2834     if (SemaRef.checkArgCount(TheCall, 2))
2835       return true;
2836     if (CheckScalarOrVector(
2837             &SemaRef, TheCall,
2838             /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
2839             /* arg index */ 0))
2840       return true;
2841     if (CheckScalarOrVector(
2842             &SemaRef, TheCall,
2843             /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
2844             /* arg index */ 1))
2845       return true;
2846     if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
2847       return true;
2848 
2849     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
2850     break;
2851   }
2852   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
2853     if (SemaRef.BuiltinElementwiseTernaryMath(
2854             TheCall, /*ArgTyRestr=*/
2855             Sema::EltwiseBuiltinArgTyRestriction::None))
2856       return true;
2857     break;
2858   }
2859   case Builtin::BI__builtin_hlsl_dot: {
2860     // arg count is checked by BuiltinVectorToScalarMath
2861     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
2862       return true;
2863     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
2864       return true;
2865     break;
2866   }
2867   case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
2868   case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
2869     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2870       return true;
2871 
2872     const Expr *Arg = TheCall->getArg(0);
2873     QualType ArgTy = Arg->getType();
2874     QualType EltTy = ArgTy;
2875 
2876     QualType ResTy = SemaRef.Context.UnsignedIntTy;
2877 
2878     if (auto *VecTy = EltTy->getAs<VectorType>()) {
2879       EltTy = VecTy->getElementType();
2880       ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
2881     }
2882 
2883     if (!EltTy->isIntegerType()) {
2884       Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2885           << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
2886           << /* no fp */ 0 << ArgTy;
2887       return true;
2888     }
2889 
2890     TheCall->setType(ResTy);
2891     break;
2892   }
2893   case Builtin::BI__builtin_hlsl_select: {
2894     if (SemaRef.checkArgCount(TheCall, 3))
2895       return true;
2896     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
2897       return true;
2898     QualType ArgTy = TheCall->getArg(0)->getType();
2899     if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
2900       return true;
2901     auto *VTy = ArgTy->getAs<VectorType>();
2902     if (VTy && VTy->getElementType()->isBooleanType() &&
2903         CheckVectorSelect(&SemaRef, TheCall))
2904       return true;
2905     break;
2906   }
2907   case Builtin::BI__builtin_hlsl_elementwise_saturate:
2908   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
2909     if (SemaRef.checkArgCount(TheCall, 1))
2910       return true;
2911     if (!TheCall->getArg(0)
2912              ->getType()
2913              ->hasFloatingRepresentation()) // half or float or double
2914       return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
2915                           diag::err_builtin_invalid_arg_type)
2916              << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
2917              << /* fp */ 1 << TheCall->getArg(0)->getType();
2918     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2919       return true;
2920     break;
2921   }
2922   case Builtin::BI__builtin_hlsl_elementwise_degrees:
2923   case Builtin::BI__builtin_hlsl_elementwise_radians:
2924   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
2925   case Builtin::BI__builtin_hlsl_elementwise_frac: {
2926     if (SemaRef.checkArgCount(TheCall, 1))
2927       return true;
2928     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2929                                    CheckFloatOrHalfRepresentation))
2930       return true;
2931     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2932       return true;
2933     break;
2934   }
2935   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
2936     if (SemaRef.checkArgCount(TheCall, 1))
2937       return true;
2938     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2939                                    CheckFloatOrHalfRepresentation))
2940       return true;
2941     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2942       return true;
2943     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy);
2944     break;
2945   }
2946   case Builtin::BI__builtin_hlsl_lerp: {
2947     if (SemaRef.checkArgCount(TheCall, 3))
2948       return true;
2949     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2950                                    CheckFloatOrHalfRepresentation))
2951       return true;
2952     if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
2953       return true;
2954     if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
2955       return true;
2956     break;
2957   }
2958   case Builtin::BI__builtin_hlsl_mad: {
2959     if (SemaRef.BuiltinElementwiseTernaryMath(
2960             TheCall, /*ArgTyRestr=*/
2961             Sema::EltwiseBuiltinArgTyRestriction::None))
2962       return true;
2963     break;
2964   }
2965   case Builtin::BI__builtin_hlsl_normalize: {
2966     if (SemaRef.checkArgCount(TheCall, 1))
2967       return true;
2968     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2969                                    CheckFloatOrHalfRepresentation))
2970       return true;
2971     ExprResult A = TheCall->getArg(0);
2972     QualType ArgTyA = A.get()->getType();
2973     // return type is the same as the input type
2974     TheCall->setType(ArgTyA);
2975     break;
2976   }
2977   case Builtin::BI__builtin_hlsl_elementwise_sign: {
2978     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2979       return true;
2980     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2981                                    CheckFloatingOrIntRepresentation))
2982       return true;
2983     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
2984     break;
2985   }
2986   case Builtin::BI__builtin_hlsl_step: {
2987     if (SemaRef.checkArgCount(TheCall, 2))
2988       return true;
2989     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
2990                                    CheckFloatOrHalfRepresentation))
2991       return true;
2992 
2993     ExprResult A = TheCall->getArg(0);
2994     QualType ArgTyA = A.get()->getType();
2995     // return type is the same as the input type
2996     TheCall->setType(ArgTyA);
2997     break;
2998   }
2999   case Builtin::BI__builtin_hlsl_wave_active_max:
3000   case Builtin::BI__builtin_hlsl_wave_active_sum: {
3001     if (SemaRef.checkArgCount(TheCall, 1))
3002       return true;
3003 
3004     // Ensure input expr type is a scalar/vector and the same as the return type
3005     if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3006       return true;
3007     if (CheckWaveActive(&SemaRef, TheCall))
3008       return true;
3009     ExprResult Expr = TheCall->getArg(0);
3010     QualType ArgTyExpr = Expr.get()->getType();
3011     TheCall->setType(ArgTyExpr);
3012     break;
3013   }
3014   // Note these are llvm builtins that we want to catch invalid intrinsic
3015   // generation. Normal handling of these builitns will occur elsewhere.
3016   case Builtin::BI__builtin_elementwise_bitreverse: {
3017     // does not include a check for number of arguments
3018     // because that is done previously
3019     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3020                                    CheckUnsignedIntRepresentation))
3021       return true;
3022     break;
3023   }
3024   case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
3025     if (SemaRef.checkArgCount(TheCall, 2))
3026       return true;
3027 
3028     // Ensure index parameter type can be interpreted as a uint
3029     ExprResult Index = TheCall->getArg(1);
3030     QualType ArgTyIndex = Index.get()->getType();
3031     if (!ArgTyIndex->isIntegerType()) {
3032       SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
3033                    diag::err_typecheck_convert_incompatible)
3034           << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
3035       return true;
3036     }
3037 
3038     // Ensure input expr type is a scalar/vector and the same as the return type
3039     if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
3040       return true;
3041 
3042     ExprResult Expr = TheCall->getArg(0);
3043     QualType ArgTyExpr = Expr.get()->getType();
3044     TheCall->setType(ArgTyExpr);
3045     break;
3046   }
3047   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
3048     if (SemaRef.checkArgCount(TheCall, 0))
3049       return true;
3050     break;
3051   }
3052   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
3053     if (SemaRef.checkArgCount(TheCall, 3))
3054       return true;
3055 
3056     if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
3057         CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
3058                             1) ||
3059         CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
3060                             2))
3061       return true;
3062 
3063     if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
3064         CheckModifiableLValue(&SemaRef, TheCall, 2))
3065       return true;
3066     break;
3067   }
3068   case Builtin::BI__builtin_hlsl_elementwise_clip: {
3069     if (SemaRef.checkArgCount(TheCall, 1))
3070       return true;
3071 
3072     if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
3073       return true;
3074     break;
3075   }
3076   case Builtin::BI__builtin_elementwise_acos:
3077   case Builtin::BI__builtin_elementwise_asin:
3078   case Builtin::BI__builtin_elementwise_atan:
3079   case Builtin::BI__builtin_elementwise_atan2:
3080   case Builtin::BI__builtin_elementwise_ceil:
3081   case Builtin::BI__builtin_elementwise_cos:
3082   case Builtin::BI__builtin_elementwise_cosh:
3083   case Builtin::BI__builtin_elementwise_exp:
3084   case Builtin::BI__builtin_elementwise_exp2:
3085   case Builtin::BI__builtin_elementwise_exp10:
3086   case Builtin::BI__builtin_elementwise_floor:
3087   case Builtin::BI__builtin_elementwise_fmod:
3088   case Builtin::BI__builtin_elementwise_log:
3089   case Builtin::BI__builtin_elementwise_log2:
3090   case Builtin::BI__builtin_elementwise_log10:
3091   case Builtin::BI__builtin_elementwise_pow:
3092   case Builtin::BI__builtin_elementwise_roundeven:
3093   case Builtin::BI__builtin_elementwise_sin:
3094   case Builtin::BI__builtin_elementwise_sinh:
3095   case Builtin::BI__builtin_elementwise_sqrt:
3096   case Builtin::BI__builtin_elementwise_tan:
3097   case Builtin::BI__builtin_elementwise_tanh:
3098   case Builtin::BI__builtin_elementwise_trunc: {
3099     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
3100                                    CheckFloatOrHalfRepresentation))
3101       return true;
3102     break;
3103   }
3104   case Builtin::BI__builtin_hlsl_buffer_update_counter: {
3105     auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
3106       return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
3107                ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
3108     };
3109     if (SemaRef.checkArgCount(TheCall, 2) ||
3110         CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy) ||
3111         CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
3112                             SemaRef.getASTContext().IntTy))
3113       return true;
3114     Expr *OffsetExpr = TheCall->getArg(1);
3115     std::optional<llvm::APSInt> Offset =
3116         OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
3117     if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
3118       SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
3119                    diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
3120           << 1;
3121       return true;
3122     }
3123     break;
3124   }
3125   }
3126   return false;
3127 }
3128 
// Flattens BaseTy into its sequence of leaf types and appends them to List in
// declaration order. Every type taken off the work list is first canonicalized
// and stripped of qualifiers.
//  - Constant-size arrays contribute their element's flattened list once per
//    array element.
//  - Vectors contribute their element type once per element.
//  - Aggregate, non-union records contribute the flattened lists of their base
//    classes (only for non-standard-layout records) followed by their fields.
//  - Unions, non-aggregate records and every other type are appended as a
//    single entry.
// Matrix types are asserted not to occur.
static void BuildFlattenedTypeList(QualType BaseTy,
                                   llvm::SmallVectorImpl<QualType> &List) {
  // LIFO work list: children are pushed in reverse order so that they are
  // popped, and therefore appended to List, in declaration order.
  llvm::SmallVector<QualType, 16> WorkList;
  WorkList.push_back(BaseTy);
  while (!WorkList.empty()) {
    QualType T = WorkList.pop_back_val();
    T = T.getCanonicalType().getUnqualifiedType();
    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
    if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
      llvm::SmallVector<QualType, 16> ElementFields;
      // Generally I've avoided recursion in this algorithm, but arrays of
      // structs could be time-consuming to flatten and churn through on the
      // work list. Hopefully nesting arrays of structs containing arrays
      // of structs too many levels deep is unlikely.
      BuildFlattenedTypeList(AT->getElementType(), ElementFields);
      // Repeat the element's field list n times.
      for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
        llvm::append_range(List, ElementFields);
      continue;
    }
    // Vectors can only have element types that are builtin types, so this can
    // add directly to the list instead of to the WorkList.
    if (const auto *VT = dyn_cast<VectorType>(T)) {
      List.insert(List.end(), VT->getNumElements(), VT->getElementType());
      continue;
    }
    if (const auto *RT = dyn_cast<RecordType>(T)) {
      const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
      assert(RD && "HLSL record types should all be CXXRecordDecls!");

      // A standard-layout record keeps all of its fields in a single class
      // (itself or one base); flatten that class's fields directly.
      if (RD->isStandardLayout())
        RD = RD->getStandardLayoutBaseWithFields();

      // For types that we shouldn't decompose (unions and non-aggregates), just
      // add the type itself to the list.
      if (RD->isUnion() || !RD->isAggregate()) {
        List.push_back(T);
        continue;
      }

      llvm::SmallVector<QualType, 16> FieldTypes;
      for (const auto *FD : RD->fields())
        FieldTypes.push_back(FD->getType());
      // Push the fields in reverse so the LIFO work list pops them in
      // declaration order.
      std::reverse(FieldTypes.begin(), FieldTypes.end());
      llvm::append_range(WorkList, FieldTypes);

      // If this wasn't a standard layout type we may also have some base
      // classes to deal with. They are pushed after the fields, so they are
      // popped (and flattened) before them, again in declaration order.
      if (!RD->isStandardLayout()) {
        FieldTypes.clear();
        for (const auto &Base : RD->bases())
          FieldTypes.push_back(Base.getType());
        std::reverse(FieldTypes.begin(), FieldTypes.end());
        llvm::append_range(WorkList, FieldTypes);
      }
      continue;
    }
    // Any other type is a leaf.
    List.push_back(T);
  }
}
3190 
IsTypedResourceElementCompatible(clang::QualType QT)3191 bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
3192   // null and array types are not allowed.
3193   if (QT.isNull() || QT->isArrayType())
3194     return false;
3195 
3196   // UDT types are not allowed
3197   if (QT->isRecordType())
3198     return false;
3199 
3200   if (QT->isBooleanType() || QT->isEnumeralType())
3201     return false;
3202 
3203   // the only other valid builtin types are scalars or vectors
3204   if (QT->isArithmeticType()) {
3205     if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
3206       return false;
3207     return true;
3208   }
3209 
3210   if (const VectorType *VT = QT->getAs<VectorType>()) {
3211     int ArraySize = VT->getNumElements();
3212 
3213     if (ArraySize > 4)
3214       return false;
3215 
3216     QualType ElTy = VT->getElementType();
3217     if (ElTy->isBooleanType())
3218       return false;
3219 
3220     if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
3221       return false;
3222     return true;
3223   }
3224 
3225   return false;
3226 }
3227 
IsScalarizedLayoutCompatible(QualType T1,QualType T2) const3228 bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
3229   if (T1.isNull() || T2.isNull())
3230     return false;
3231 
3232   T1 = T1.getCanonicalType().getUnqualifiedType();
3233   T2 = T2.getCanonicalType().getUnqualifiedType();
3234 
3235   // If both types are the same canonical type, they're obviously compatible.
3236   if (SemaRef.getASTContext().hasSameType(T1, T2))
3237     return true;
3238 
3239   llvm::SmallVector<QualType, 16> T1Types;
3240   BuildFlattenedTypeList(T1, T1Types);
3241   llvm::SmallVector<QualType, 16> T2Types;
3242   BuildFlattenedTypeList(T2, T2Types);
3243 
3244   // Check the flattened type list
3245   return llvm::equal(T1Types, T2Types,
3246                      [this](QualType LHS, QualType RHS) -> bool {
3247                        return SemaRef.IsLayoutCompatible(LHS, RHS);
3248                      });
3249 }
3250 
CheckCompatibleParameterABI(FunctionDecl * New,FunctionDecl * Old)3251 bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
3252                                            FunctionDecl *Old) {
3253   if (New->getNumParams() != Old->getNumParams())
3254     return true;
3255 
3256   bool HadError = false;
3257 
3258   for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
3259     ParmVarDecl *NewParam = New->getParamDecl(i);
3260     ParmVarDecl *OldParam = Old->getParamDecl(i);
3261 
3262     // HLSL parameter declarations for inout and out must match between
3263     // declarations. In HLSL inout and out are ambiguous at the call site,
3264     // but have different calling behavior, so you cannot overload a
3265     // method based on a difference between inout and out annotations.
3266     const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
3267     unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
3268     const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
3269     unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
3270 
3271     if (NSpellingIdx != OSpellingIdx) {
3272       SemaRef.Diag(NewParam->getLocation(),
3273                    diag::err_hlsl_param_qualifier_mismatch)
3274           << NDAttr << NewParam;
3275       SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
3276           << ODAttr;
3277       HadError = true;
3278     }
3279   }
3280   return HadError;
3281 }
3282 
// Generally follows PerformScalarCast, with cases reordered for
// clarity of what types are supported
//
// Returns true if a value of scalar type SrcTy can be converted to scalar type
// DestTy. Returns false if either type is not a scalar. Same unqualified types,
// and any pairing of the bool, integral and floating kinds, are accepted.
// Pointer, complex and fixed-point kinds are treated as unreachable because
// HLSL does not support them. The switches are deliberately exhaustive (no
// default) so the compiler flags any newly added ScalarTypeKind.
bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {

  if (!SrcTy->isScalarType() || !DestTy->isScalarType())
    return false;

  if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
    return true;

  switch (SrcTy->getScalarTypeKind()) {
  case Type::STK_Bool: // casting from bool is like casting from an integer
  case Type::STK_Integral:
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Bool:
    case Type::STK_Integral:
    case Type::STK_Floating:
      return true;
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    case Type::STK_IntegralComplex:
    case Type::STK_FloatingComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    }
    llvm_unreachable("Should have returned before this");

  case Type::STK_Floating:
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Floating:
    case Type::STK_Bool:
    case Type::STK_Integral:
      return true;
    case Type::STK_FloatingComplex:
    case Type::STK_IntegralComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    }
    llvm_unreachable("Should have returned before this");

  // The remaining source kinds cannot occur in HLSL.
  case Type::STK_MemberPointer:
  case Type::STK_CPointer:
  case Type::STK_BlockPointer:
  case Type::STK_ObjCObjectPointer:
    llvm_unreachable("HLSL doesn't support pointers.");

  case Type::STK_FixedPoint:
    llvm_unreachable("HLSL doesn't support fixed point types.");

  case Type::STK_FloatingComplex:
  case Type::STK_IntegralComplex:
    llvm_unreachable("HLSL doesn't support complex types.");
  }

  llvm_unreachable("Unhandled scalar cast");
}
3349 
3350 // Detect if a type contains a bitfield. Will be removed when
3351 // bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
ContainsBitField(QualType BaseTy)3352 bool SemaHLSL::ContainsBitField(QualType BaseTy) {
3353   llvm::SmallVector<QualType, 16> WorkList;
3354   WorkList.push_back(BaseTy);
3355   while (!WorkList.empty()) {
3356     QualType T = WorkList.pop_back_val();
3357     T = T.getCanonicalType().getUnqualifiedType();
3358     // only check aggregate types
3359     if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
3360       WorkList.push_back(AT->getElementType());
3361       continue;
3362     }
3363     if (const auto *RT = dyn_cast<RecordType>(T)) {
3364       const RecordDecl *RD = RT->getDecl();
3365       if (RD->isUnion())
3366         continue;
3367 
3368       const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(RD);
3369 
3370       if (CXXD && CXXD->isStandardLayout())
3371         RD = CXXD->getStandardLayoutBaseWithFields();
3372 
3373       for (const auto *FD : RD->fields()) {
3374         if (FD->isBitField())
3375           return true;
3376         WorkList.push_back(FD->getType());
3377       }
3378       continue;
3379     }
3380   }
3381   return false;
3382 }
3383 
3384 // Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
3385 // Src is a scalar or a vector of length 1
3386 // Or if Dest is a vector and Src is a vector of length 1
CanPerformAggregateSplatCast(Expr * Src,QualType DestTy)3387 bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
3388 
3389   QualType SrcTy = Src->getType();
3390   // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
3391   // going to be a vector splat from a scalar.
3392   if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
3393       DestTy->isScalarType())
3394     return false;
3395 
3396   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
3397 
3398   // Src isn't a scalar or a vector of length 1
3399   if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
3400     return false;
3401 
3402   if (SrcVecTy)
3403     SrcTy = SrcVecTy->getElementType();
3404 
3405   if (ContainsBitField(DestTy))
3406     return false;
3407 
3408   llvm::SmallVector<QualType> DestTypes;
3409   BuildFlattenedTypeList(DestTy, DestTypes);
3410 
3411   for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
3412     if (DestTypes[I]->isUnionType())
3413       return false;
3414     if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
3415       return false;
3416   }
3417   return true;
3418 }
3419 
3420 // Can we perform an HLSL Elementwise cast?
3421 // TODO: update this code when matrices are added; see issue #88060
CanPerformElementwiseCast(Expr * Src,QualType DestTy)3422 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
3423 
3424   // Don't handle casts where LHS and RHS are any combination of scalar/vector
3425   // There must be an aggregate somewhere
3426   QualType SrcTy = Src->getType();
3427   if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
3428     return false;
3429 
3430   if (SrcTy->isVectorType() &&
3431       (DestTy->isScalarType() || DestTy->isVectorType()))
3432     return false;
3433 
3434   if (ContainsBitField(DestTy) || ContainsBitField(SrcTy))
3435     return false;
3436 
3437   llvm::SmallVector<QualType> DestTypes;
3438   BuildFlattenedTypeList(DestTy, DestTypes);
3439   llvm::SmallVector<QualType> SrcTypes;
3440   BuildFlattenedTypeList(SrcTy, SrcTypes);
3441 
3442   // Usually the size of SrcTypes must be greater than or equal to the size of
3443   // DestTypes.
3444   if (SrcTypes.size() < DestTypes.size())
3445     return false;
3446 
3447   unsigned SrcSize = SrcTypes.size();
3448   unsigned DstSize = DestTypes.size();
3449   unsigned I;
3450   for (I = 0; I < DstSize && I < SrcSize; I++) {
3451     if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
3452       return false;
3453     if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
3454       return false;
3455     }
3456   }
3457 
3458   // check the rest of the source type for unions.
3459   for (; I < SrcSize; I++) {
3460     if (SrcTypes[I]->isUnionType())
3461       return false;
3462   }
3463   return true;
3464 }
3465 
// Builds the call-argument expression for Arg bound to Param, which must carry
// an HLSLParamModifierAttr.
//  - If the modifier's ABI is ParameterABI::Ordinary, Arg is returned unchanged.
//  - Otherwise Arg must be an lvalue, and must not differ from the parameter
//    type in scalar-ness. The result is an HLSLOutArgExpr that bundles:
//      * the opaque caller lvalue,
//      * the copy-initialized temporary of the parameter type, and
//      * the `=` writeback from that temporary into the caller's lvalue.
// Returns ExprError() after emitting a diagnostic, or if initialization or
// writeback construction fails.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  // Selects the inout (1) vs out (0) wording in the diagnostics below.
  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  if (!Arg->isLValue()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  // The parameter's type with any reference stripped.
  QualType Ty = Param->getType().getNonLValueExprType(Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Opaque stand-in for the caller's lvalue; it is shared by the copy-in
  // initialization and the writeback assignment below.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Ctx, Ty, false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  Ty = Ty.getNonLValueExprType(Ctx);
  // Opaque lvalue for the parameter temporary, wrapping the initializer.
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
                           tok::equal, ArgOpV, OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);

  return ExprResult(OutExpr);
}
3523 
getInoutParameterType(QualType Ty)3524 QualType SemaHLSL::getInoutParameterType(QualType Ty) {
3525   // If HLSL gains support for references, all the cites that use this will need
3526   // to be updated with semantic checking to produce errors for
3527   // pointers/references.
3528   assert(!Ty->isReferenceType() &&
3529          "Pointer and reference types cannot be inout or out parameters");
3530   Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
3531   Ty.addRestrict();
3532   return Ty;
3533 }
3534 
IsDefaultBufferConstantDecl(VarDecl * VD)3535 static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
3536   QualType QT = VD->getType();
3537   return VD->getDeclContext()->isTranslationUnit() &&
3538          QT.getAddressSpace() == LangAS::Default &&
3539          VD->getStorageClass() != SC_Static &&
3540          !VD->hasAttr<HLSLVkConstantIdAttr>() &&
3541          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
3542 }
3543 
deduceAddressSpace(VarDecl * Decl)3544 void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
3545   // The variable already has an address space (groupshared for ex).
3546   if (Decl->getType().hasAddressSpace())
3547     return;
3548 
3549   if (Decl->getType()->isDependentType())
3550     return;
3551 
3552   QualType Type = Decl->getType();
3553 
3554   if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
3555     LangAS ImplAS = LangAS::hlsl_input;
3556     Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
3557     Decl->setType(Type);
3558     return;
3559   }
3560 
3561   if (Type->isSamplerT() || Type->isVoidType())
3562     return;
3563 
3564   // Resource handles.
3565   if (isResourceRecordTypeOrArrayOf(Type->getUnqualifiedDesugaredType()))
3566     return;
3567 
3568   // Only static globals belong to the Private address space.
3569   // Non-static globals belongs to the cbuffer.
3570   if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
3571     return;
3572 
3573   LangAS ImplAS = LangAS::hlsl_private;
3574   Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
3575   Decl->setType(Type);
3576 }
3577 
ActOnVariableDeclarator(VarDecl * VD)3578 void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
3579   if (VD->hasGlobalStorage()) {
3580     // make sure the declaration has a complete type
3581     if (SemaRef.RequireCompleteType(
3582             VD->getLocation(),
3583             SemaRef.getASTContext().getBaseElementType(VD->getType()),
3584             diag::err_typecheck_decl_incomplete_type)) {
3585       VD->setInvalidDecl();
3586       deduceAddressSpace(VD);
3587       return;
3588     }
3589 
3590     // Global variables outside a cbuffer block that are not a resource, static,
3591     // groupshared, or an empty array or struct belong to the default constant
3592     // buffer $Globals (to be created at the end of the translation unit).
3593     if (IsDefaultBufferConstantDecl(VD)) {
3594       // update address space to hlsl_constant
3595       QualType NewTy = getASTContext().getAddrSpaceQualType(
3596           VD->getType(), LangAS::hlsl_constant);
3597       VD->setType(NewTy);
3598       DefaultCBufferDecls.push_back(VD);
3599     }
3600 
3601     // find all resources bindings on decl
3602     if (VD->getType()->isHLSLIntangibleType())
3603       collectResourceBindingsOnVarDecl(VD);
3604 
3605     const Type *VarType = VD->getType().getTypePtr();
3606     while (VarType->isArrayType())
3607       VarType = VarType->getArrayElementTypeNoTypeQual();
3608     if (VarType->isHLSLResourceRecord() ||
3609         VD->hasAttr<HLSLVkConstantIdAttr>()) {
3610       // Make the variable for resources static. The global externally visible
3611       // storage is accessed through the handle, which is a member. The variable
3612       // itself is not externally visible.
3613       VD->setStorageClass(StorageClass::SC_Static);
3614     }
3615 
3616     // process explicit bindings
3617     processExplicitBindingsOnDecl(VD);
3618   }
3619 
3620   deduceAddressSpace(VD);
3621 }
3622 
initVarDeclWithCtor(Sema & S,VarDecl * VD,MutableArrayRef<Expr * > Args)3623 static bool initVarDeclWithCtor(Sema &S, VarDecl *VD,
3624                                 MutableArrayRef<Expr *> Args) {
3625   InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
3626   InitializationKind Kind = InitializationKind::CreateDirect(
3627       VD->getLocation(), SourceLocation(), SourceLocation());
3628 
3629   InitializationSequence InitSeq(S, Entity, Kind, Args);
3630   if (InitSeq.Failed())
3631     return false;
3632 
3633   ExprResult Init = InitSeq.Perform(S, Entity, Kind, Args);
3634   if (!Init.get())
3635     return false;
3636 
3637   VD->setInit(S.MaybeCreateExprWithCleanups(Init.get()));
3638   VD->setInitStyle(VarDecl::CallInit);
3639   S.CheckCompleteVariableDeclaration(VD);
3640   return true;
3641 }
3642 
initGlobalResourceDecl(VarDecl * VD)3643 bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
3644   std::optional<uint32_t> RegisterSlot;
3645   uint32_t SpaceNo = 0;
3646   HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
3647   if (RBA) {
3648     if (RBA->hasRegisterSlot())
3649       RegisterSlot = RBA->getSlotNumber();
3650     SpaceNo = RBA->getSpaceNumber();
3651   }
3652 
3653   ASTContext &AST = SemaRef.getASTContext();
3654   uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
3655   uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
3656   IntegerLiteral *RangeSize = IntegerLiteral::Create(
3657       AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
3658   IntegerLiteral *Index = IntegerLiteral::Create(
3659       AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
3660   IntegerLiteral *Space =
3661       IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, SpaceNo),
3662                              AST.UnsignedIntTy, SourceLocation());
3663   StringRef VarName = VD->getName();
3664   StringLiteral *Name = StringLiteral::Create(
3665       AST, VarName, StringLiteralKind::Ordinary, false,
3666       AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
3667       SourceLocation());
3668 
3669   // resource with explicit binding
3670   if (RegisterSlot.has_value()) {
3671     IntegerLiteral *RegSlot = IntegerLiteral::Create(
3672         AST, llvm::APInt(UIntTySize, RegisterSlot.value()), AST.UnsignedIntTy,
3673         SourceLocation());
3674     Expr *Args[] = {RegSlot, Space, RangeSize, Index, Name};
3675     return initVarDeclWithCtor(SemaRef, VD, Args);
3676   }
3677 
3678   // resource with implicit binding
3679   IntegerLiteral *OrderId = IntegerLiteral::Create(
3680       AST, llvm::APInt(UIntTySize, getNextImplicitBindingOrderID()),
3681       AST.UnsignedIntTy, SourceLocation());
3682   Expr *Args[] = {Space, RangeSize, Index, OrderId, Name};
3683   return initVarDeclWithCtor(SemaRef, VD, Args);
3684 }
3685 
3686 // Returns true if the initialization has been handled.
3687 // Returns false to use default initialization.
ActOnUninitializedVarDecl(VarDecl * VD)3688 bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
3689   // Objects in the hlsl_constant address space are initialized
3690   // externally, so don't synthesize an implicit initializer.
3691   if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
3692     return true;
3693 
3694   // Initialize resources
3695   if (!isResourceRecordTypeOrArrayOf(VD))
3696     return false;
3697 
3698   // FIXME: We currectly support only simple resources - no arrays of resources
3699   // or resources in user defined structs.
3700   // (llvm/llvm-project#133835, llvm/llvm-project#133837)
3701   // Initialize resources at the global scope
3702   if (VD->hasGlobalStorage() && VD->getType()->isHLSLResourceRecord())
3703     return initGlobalResourceDecl(VD);
3704 
3705   return false;
3706 }
3707 
3708 // Walks though the global variable declaration, collects all resource binding
3709 // requirements and adds them to Bindings
collectResourceBindingsOnVarDecl(VarDecl * VD)3710 void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
3711   assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
3712          "expected global variable that contains HLSL resource");
3713 
3714   // Cbuffers and Tbuffers are HLSLBufferDecl types
3715   if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
3716     Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
3717                                         ? ResourceClass::CBuffer
3718                                         : ResourceClass::SRV);
3719     return;
3720   }
3721 
3722   // Unwrap arrays
3723   // FIXME: Calculate array size while unwrapping
3724   const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
3725   while (Ty->isConstantArrayType()) {
3726     const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
3727     Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
3728   }
3729 
3730   // Resource (or array of resources)
3731   if (const HLSLAttributedResourceType *AttrResType =
3732           HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
3733     Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
3734     return;
3735   }
3736 
3737   // User defined record type
3738   if (const RecordType *RT = dyn_cast<RecordType>(Ty))
3739     collectResourceBindingsOnUserRecordDecl(VD, RT);
3740 }
3741 
3742 // Walks though the explicit resource binding attributes on the declaration,
3743 // and makes sure there is a resource that matched the binding and updates
3744 // DeclBindingInfoLists
processExplicitBindingsOnDecl(VarDecl * VD)3745 void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
3746   assert(VD->hasGlobalStorage() && "expected global variable");
3747 
3748   bool HasBinding = false;
3749   for (Attr *A : VD->attrs()) {
3750     HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
3751     if (!RBA || !RBA->hasRegisterSlot())
3752       continue;
3753     HasBinding = true;
3754 
3755     RegisterType RT = RBA->getRegisterType();
3756     assert(RT != RegisterType::I && "invalid or obsolete register type should "
3757                                     "never have an attribute created");
3758 
3759     if (RT == RegisterType::C) {
3760       if (Bindings.hasBindingInfoForDecl(VD))
3761         SemaRef.Diag(VD->getLocation(),
3762                      diag::warn_hlsl_user_defined_type_missing_member)
3763             << static_cast<int>(RT);
3764       continue;
3765     }
3766 
3767     // Find DeclBindingInfo for this binding and update it, or report error
3768     // if it does not exist (user type does to contain resources with the
3769     // expected resource class).
3770     ResourceClass RC = getResourceClass(RT);
3771     if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
3772       // update binding info
3773       BI->setBindingAttribute(RBA, BindingType::Explicit);
3774     } else {
3775       SemaRef.Diag(VD->getLocation(),
3776                    diag::warn_hlsl_user_defined_type_missing_member)
3777           << static_cast<int>(RT);
3778     }
3779   }
3780 
3781   if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
3782     SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
3783 }
3784 namespace {
3785 class InitListTransformer {
3786   Sema &S;
3787   ASTContext &Ctx;
3788   QualType InitTy;
3789   QualType *DstIt = nullptr;
3790   Expr **ArgIt = nullptr;
3791   // Is wrapping the destination type iterator required? This is only used for
3792   // incomplete array types where we loop over the destination type since we
3793   // don't know the full number of elements from the declaration.
3794   bool Wrap;
3795 
castInitializer(Expr * E)3796   bool castInitializer(Expr *E) {
3797     assert(DstIt && "This should always be something!");
3798     if (DstIt == DestTypes.end()) {
3799       if (!Wrap) {
3800         ArgExprs.push_back(E);
3801         // This is odd, but it isn't technically a failure due to conversion, we
3802         // handle mismatched counts of arguments differently.
3803         return true;
3804       }
3805       DstIt = DestTypes.begin();
3806     }
3807     InitializedEntity Entity = InitializedEntity::InitializeParameter(
3808         Ctx, *DstIt, /* Consumed (ObjC) */ false);
3809     ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
3810     if (Res.isInvalid())
3811       return false;
3812     Expr *Init = Res.get();
3813     ArgExprs.push_back(Init);
3814     DstIt++;
3815     return true;
3816   }
3817 
buildInitializerListImpl(Expr * E)3818   bool buildInitializerListImpl(Expr *E) {
3819     // If this is an initialization list, traverse the sub initializers.
3820     if (auto *Init = dyn_cast<InitListExpr>(E)) {
3821       for (auto *SubInit : Init->inits())
3822         if (!buildInitializerListImpl(SubInit))
3823           return false;
3824       return true;
3825     }
3826 
3827     // If this is a scalar type, just enqueue the expression.
3828     QualType Ty = E->getType();
3829 
3830     if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
3831       return castInitializer(E);
3832 
3833     if (auto *VecTy = Ty->getAs<VectorType>()) {
3834       uint64_t Size = VecTy->getNumElements();
3835 
3836       QualType SizeTy = Ctx.getSizeType();
3837       uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
3838       for (uint64_t I = 0; I < Size; ++I) {
3839         auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
3840                                            SizeTy, SourceLocation());
3841 
3842         ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
3843             E, E->getBeginLoc(), Idx, E->getEndLoc());
3844         if (ElExpr.isInvalid())
3845           return false;
3846         if (!castInitializer(ElExpr.get()))
3847           return false;
3848       }
3849       return true;
3850     }
3851 
3852     if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
3853       uint64_t Size = ArrTy->getZExtSize();
3854       QualType SizeTy = Ctx.getSizeType();
3855       uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
3856       for (uint64_t I = 0; I < Size; ++I) {
3857         auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
3858                                            SizeTy, SourceLocation());
3859         ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
3860             E, E->getBeginLoc(), Idx, E->getEndLoc());
3861         if (ElExpr.isInvalid())
3862           return false;
3863         if (!buildInitializerListImpl(ElExpr.get()))
3864           return false;
3865       }
3866       return true;
3867     }
3868 
3869     if (auto *RTy = Ty->getAs<RecordType>()) {
3870       llvm::SmallVector<const RecordType *> RecordTypes;
3871       RecordTypes.push_back(RTy);
3872       while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
3873         CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
3874         assert(D->getNumBases() == 1 &&
3875                "HLSL doesn't support multiple inheritance");
3876         RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
3877       }
3878       while (!RecordTypes.empty()) {
3879         const RecordType *RT = RecordTypes.pop_back_val();
3880         for (auto *FD : RT->getDecl()->fields()) {
3881           DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
3882           DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
3883           ExprResult Res = S.BuildFieldReferenceExpr(
3884               E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
3885           if (Res.isInvalid())
3886             return false;
3887           if (!buildInitializerListImpl(Res.get()))
3888             return false;
3889         }
3890       }
3891     }
3892     return true;
3893   }
3894 
generateInitListsImpl(QualType Ty)3895   Expr *generateInitListsImpl(QualType Ty) {
3896     assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
3897     if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
3898       return *(ArgIt++);
3899 
3900     llvm::SmallVector<Expr *> Inits;
3901     assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
3902     Ty = Ty.getDesugaredType(Ctx);
3903     if (Ty->isVectorType() || Ty->isConstantArrayType()) {
3904       QualType ElTy;
3905       uint64_t Size = 0;
3906       if (auto *ATy = Ty->getAs<VectorType>()) {
3907         ElTy = ATy->getElementType();
3908         Size = ATy->getNumElements();
3909       } else {
3910         auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
3911         ElTy = VTy->getElementType();
3912         Size = VTy->getZExtSize();
3913       }
3914       for (uint64_t I = 0; I < Size; ++I)
3915         Inits.push_back(generateInitListsImpl(ElTy));
3916     }
3917     if (auto *RTy = Ty->getAs<RecordType>()) {
3918       llvm::SmallVector<const RecordType *> RecordTypes;
3919       RecordTypes.push_back(RTy);
3920       while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
3921         CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
3922         assert(D->getNumBases() == 1 &&
3923                "HLSL doesn't support multiple inheritance");
3924         RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
3925       }
3926       while (!RecordTypes.empty()) {
3927         const RecordType *RT = RecordTypes.pop_back_val();
3928         for (auto *FD : RT->getDecl()->fields()) {
3929           Inits.push_back(generateInitListsImpl(FD->getType()));
3930         }
3931       }
3932     }
3933     auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
3934                                            Inits, Inits.back()->getEndLoc());
3935     NewInit->setType(Ty);
3936     return NewInit;
3937   }
3938 
3939 public:
3940   llvm::SmallVector<QualType, 16> DestTypes;
3941   llvm::SmallVector<Expr *, 16> ArgExprs;
InitListTransformer(Sema & SemaRef,const InitializedEntity & Entity)3942   InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
3943       : S(SemaRef), Ctx(SemaRef.getASTContext()),
3944         Wrap(Entity.getType()->isIncompleteArrayType()) {
3945     InitTy = Entity.getType().getNonReferenceType();
3946     // When we're generating initializer lists for incomplete array types we
3947     // need to wrap around both when building the initializers and when
3948     // generating the final initializer lists.
3949     if (Wrap) {
3950       assert(InitTy->isIncompleteArrayType());
3951       const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
3952       InitTy = IAT->getElementType();
3953     }
3954     BuildFlattenedTypeList(InitTy, DestTypes);
3955     DstIt = DestTypes.begin();
3956   }
3957 
buildInitializerList(Expr * E)3958   bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
3959 
generateInitLists()3960   Expr *generateInitLists() {
3961     assert(!ArgExprs.empty() &&
3962            "Call buildInitializerList to generate argument expressions.");
3963     ArgIt = ArgExprs.begin();
3964     if (!Wrap)
3965       return generateInitListsImpl(InitTy);
3966     llvm::SmallVector<Expr *> Inits;
3967     while (ArgIt != ArgExprs.end())
3968       Inits.push_back(generateInitListsImpl(InitTy));
3969 
3970     auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
3971                                            Inits, Inits.back()->getEndLoc());
3972     llvm::APInt ArySize(64, Inits.size());
3973     NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
3974                                               ArraySizeModifier::Normal, 0));
3975     return NewInit;
3976   }
3977 };
3978 } // namespace
3979 
transformInitList(const InitializedEntity & Entity,InitListExpr * Init)3980 bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
3981                                  InitListExpr *Init) {
3982   // If the initializer is a scalar, just return it.
3983   if (Init->getType()->isScalarType())
3984     return true;
3985   ASTContext &Ctx = SemaRef.getASTContext();
3986   InitListTransformer ILT(SemaRef, Entity);
3987 
3988   for (unsigned I = 0; I < Init->getNumInits(); ++I) {
3989     Expr *E = Init->getInit(I);
3990     if (E->HasSideEffects(Ctx)) {
3991       QualType Ty = E->getType();
3992       if (Ty->isRecordType())
3993         E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
3994       E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
3995                                     E->getObjectKind(), E);
3996       Init->setInit(I, E);
3997     }
3998     if (!ILT.buildInitializerList(E))
3999       return false;
4000   }
4001   size_t ExpectedSize = ILT.DestTypes.size();
4002   size_t ActualSize = ILT.ArgExprs.size();
4003   // For incomplete arrays it is completely arbitrary to choose whether we think
4004   // the user intended fewer or more elements. This implementation assumes that
4005   // the user intended more, and errors that there are too few initializers to
4006   // complete the final element.
4007   if (Entity.getType()->isIncompleteArrayType())
4008     ExpectedSize =
4009         ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
4010 
4011   // An initializer list might be attempting to initialize a reference or
4012   // rvalue-reference. When checking the initializer we should look through
4013   // the reference.
4014   QualType InitTy = Entity.getType().getNonReferenceType();
4015   if (InitTy.hasAddressSpace())
4016     InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
4017   if (ExpectedSize != ActualSize) {
4018     int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
4019     SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
4020         << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
4021     return false;
4022   }
4023 
4024   // generateInitListsImpl will always return an InitListExpr here, because the
4025   // scalar case is handled above.
4026   auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
4027   Init->resizeInits(Ctx, NewInit->getNumInits());
4028   for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
4029     Init->updateInit(Ctx, I, NewInit->getInit(I));
4030   return true;
4031 }
4032 
handleInitialization(VarDecl * VDecl,Expr * & Init)4033 bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
4034   const HLSLVkConstantIdAttr *ConstIdAttr =
4035       VDecl->getAttr<HLSLVkConstantIdAttr>();
4036   if (!ConstIdAttr)
4037     return true;
4038 
4039   ASTContext &Context = SemaRef.getASTContext();
4040 
4041   APValue InitValue;
4042   if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
4043     Diag(VDecl->getLocation(), diag::err_specialization_const);
4044     VDecl->setInvalidDecl();
4045     return false;
4046   }
4047 
4048   Builtin::ID BID =
4049       getSpecConstBuiltinId(VDecl->getType()->getUnqualifiedDesugaredType());
4050 
4051   // Argument 1: The ID from the attribute
4052   int ConstantID = ConstIdAttr->getId();
4053   llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
4054   Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
4055                                         ConstIdAttr->getLocation());
4056 
4057   SmallVector<Expr *, 2> Args = {IdExpr, Init};
4058   Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
4059   if (C->getType()->getCanonicalTypeUnqualified() !=
4060       VDecl->getType()->getCanonicalTypeUnqualified()) {
4061     C = SemaRef
4062             .BuildCStyleCastExpr(SourceLocation(),
4063                                  Context.getTrivialTypeSourceInfo(
4064                                      Init->getType(), Init->getExprLoc()),
4065                                  SourceLocation(), C)
4066             .get();
4067   }
4068   Init = C;
4069   return true;
4070 }
4071