1 //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for HLSL constructs.
9 //===----------------------------------------------------------------------===//
10
11 #include "clang/Sema/SemaHLSL.h"
12 #include "clang/AST/ASTConsumer.h"
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Attr.h"
15 #include "clang/AST/Attrs.inc"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclBase.h"
18 #include "clang/AST/DeclCXX.h"
19 #include "clang/AST/DeclarationName.h"
20 #include "clang/AST/DynamicRecursiveASTVisitor.h"
21 #include "clang/AST/Expr.h"
22 #include "clang/AST/Type.h"
23 #include "clang/AST/TypeLoc.h"
24 #include "clang/Basic/Builtins.h"
25 #include "clang/Basic/DiagnosticSema.h"
26 #include "clang/Basic/IdentifierTable.h"
27 #include "clang/Basic/LLVM.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "clang/Basic/Specifiers.h"
30 #include "clang/Basic/TargetInfo.h"
31 #include "clang/Sema/Initialization.h"
32 #include "clang/Sema/Lookup.h"
33 #include "clang/Sema/ParsedAttr.h"
34 #include "clang/Sema/Sema.h"
35 #include "clang/Sema/Template.h"
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/ADT/Twine.h"
42 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/DXILABI.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Support/FormatVariadic.h"
47 #include "llvm/TargetParser/Triple.h"
48 #include <cmath>
49 #include <cstddef>
50 #include <iterator>
51 #include <utility>
52
53 using namespace clang;
54 using RegisterType = HLSLResourceBindingAttr::RegisterType;
55
56 static CXXRecordDecl *createHostLayoutStruct(Sema &S,
57 CXXRecordDecl *StructDecl);
58
getRegisterType(ResourceClass RC)59 static RegisterType getRegisterType(ResourceClass RC) {
60 switch (RC) {
61 case ResourceClass::SRV:
62 return RegisterType::SRV;
63 case ResourceClass::UAV:
64 return RegisterType::UAV;
65 case ResourceClass::CBuffer:
66 return RegisterType::CBuffer;
67 case ResourceClass::Sampler:
68 return RegisterType::Sampler;
69 }
70 llvm_unreachable("unexpected ResourceClass value");
71 }
72
73 // Converts the first letter of string Slot to RegisterType.
74 // Returns false if the letter does not correspond to a valid register type.
convertToRegisterType(StringRef Slot,RegisterType * RT)75 static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
76 assert(RT != nullptr);
77 switch (Slot[0]) {
78 case 't':
79 case 'T':
80 *RT = RegisterType::SRV;
81 return true;
82 case 'u':
83 case 'U':
84 *RT = RegisterType::UAV;
85 return true;
86 case 'b':
87 case 'B':
88 *RT = RegisterType::CBuffer;
89 return true;
90 case 's':
91 case 'S':
92 *RT = RegisterType::Sampler;
93 return true;
94 case 'c':
95 case 'C':
96 *RT = RegisterType::C;
97 return true;
98 case 'i':
99 case 'I':
100 *RT = RegisterType::I;
101 return true;
102 default:
103 return false;
104 }
105 }
106
getResourceClass(RegisterType RT)107 static ResourceClass getResourceClass(RegisterType RT) {
108 switch (RT) {
109 case RegisterType::SRV:
110 return ResourceClass::SRV;
111 case RegisterType::UAV:
112 return ResourceClass::UAV;
113 case RegisterType::CBuffer:
114 return ResourceClass::CBuffer;
115 case RegisterType::Sampler:
116 return ResourceClass::Sampler;
117 case RegisterType::C:
118 case RegisterType::I:
119 // Deliberately falling through to the unreachable below.
120 break;
121 }
122 llvm_unreachable("unexpected RegisterType value");
123 }
124
getSpecConstBuiltinId(const Type * Type)125 static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
126 const auto *BT = dyn_cast<BuiltinType>(Type);
127 if (!BT) {
128 if (!Type->isEnumeralType())
129 return Builtin::NotBuiltin;
130 return Builtin::BI__builtin_get_spirv_spec_constant_int;
131 }
132
133 switch (BT->getKind()) {
134 case BuiltinType::Bool:
135 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
136 case BuiltinType::Short:
137 return Builtin::BI__builtin_get_spirv_spec_constant_short;
138 case BuiltinType::Int:
139 return Builtin::BI__builtin_get_spirv_spec_constant_int;
140 case BuiltinType::LongLong:
141 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
142 case BuiltinType::UShort:
143 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
144 case BuiltinType::UInt:
145 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
146 case BuiltinType::ULongLong:
147 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
148 case BuiltinType::Half:
149 return Builtin::BI__builtin_get_spirv_spec_constant_half;
150 case BuiltinType::Float:
151 return Builtin::BI__builtin_get_spirv_spec_constant_float;
152 case BuiltinType::Double:
153 return Builtin::BI__builtin_get_spirv_spec_constant_double;
154 default:
155 return Builtin::NotBuiltin;
156 }
157 }
158
addDeclBindingInfo(const VarDecl * VD,ResourceClass ResClass)159 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
160 ResourceClass ResClass) {
161 assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
162 "DeclBindingInfo already added");
163 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
164 // VarDecl may have multiple entries for different resource classes.
165 // DeclToBindingListIndex stores the index of the first binding we saw
166 // for this decl. If there are any additional ones then that index
167 // shouldn't be updated.
168 DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
169 return &BindingsList.emplace_back(VD, ResClass);
170 }
171
getDeclBindingInfo(const VarDecl * VD,ResourceClass ResClass)172 DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
173 ResourceClass ResClass) {
174 auto Entry = DeclToBindingListIndex.find(VD);
175 if (Entry != DeclToBindingListIndex.end()) {
176 for (unsigned Index = Entry->getSecond();
177 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
178 ++Index) {
179 if (BindingsList[Index].ResClass == ResClass)
180 return &BindingsList[Index];
181 }
182 }
183 return nullptr;
184 }
185
hasBindingInfoForDecl(const VarDecl * VD) const186 bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
187 return DeclToBindingListIndex.contains(VD);
188 }
189
SemaHLSL(Sema & S)190 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
191
ActOnStartBuffer(Scope * BufferScope,bool CBuffer,SourceLocation KwLoc,IdentifierInfo * Ident,SourceLocation IdentLoc,SourceLocation LBrace)192 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
193 SourceLocation KwLoc, IdentifierInfo *Ident,
194 SourceLocation IdentLoc,
195 SourceLocation LBrace) {
196 // For anonymous namespace, take the location of the left brace.
197 DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
198 HLSLBufferDecl *Result = HLSLBufferDecl::Create(
199 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
200
201 // if CBuffer is false, then it's a TBuffer
202 auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
203 : llvm::hlsl::ResourceClass::SRV;
204 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));
205
206 SemaRef.PushOnScopeChains(Result, BufferScope);
207 SemaRef.PushDeclContext(BufferScope, Result);
208
209 return Result;
210 }
211
calculateLegacyCbufferFieldAlign(const ASTContext & Context,QualType T)212 static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
213 QualType T) {
214 // Arrays and Structs are always aligned to new buffer rows
215 if (T->isArrayType() || T->isStructureType())
216 return 16;
217
218 // Vectors are aligned to the type they contain
219 if (const VectorType *VT = T->getAs<VectorType>())
220 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType());
221
222 assert(Context.getTypeSize(T) <= 64 &&
223 "Scalar bit widths larger than 64 not supported");
224
225 // Scalar types are aligned to their byte width
226 return Context.getTypeSize(T) / 8;
227 }
228
229 // Calculate the size of a legacy cbuffer type in bytes based on
230 // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
calculateLegacyCbufferSize(const ASTContext & Context,QualType T)231 static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
232 QualType T) {
233 constexpr unsigned CBufferAlign = 16;
234 if (const RecordType *RT = T->getAs<RecordType>()) {
235 unsigned Size = 0;
236 const RecordDecl *RD = RT->getDecl();
237 for (const FieldDecl *Field : RD->fields()) {
238 QualType Ty = Field->getType();
239 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
240 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);
241
242 // If the field crosses the row boundary after alignment it drops to the
243 // next row
244 unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
245 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
246 FieldAlign = CBufferAlign;
247 }
248
249 Size = llvm::alignTo(Size, FieldAlign);
250 Size += FieldSize;
251 }
252 return Size;
253 }
254
255 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
256 unsigned ElementCount = AT->getSize().getZExtValue();
257 if (ElementCount == 0)
258 return 0;
259
260 unsigned ElementSize =
261 calculateLegacyCbufferSize(Context, AT->getElementType());
262 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
263 return AlignedElementSize * (ElementCount - 1) + ElementSize;
264 }
265
266 if (const VectorType *VT = T->getAs<VectorType>()) {
267 unsigned ElementCount = VT->getNumElements();
268 unsigned ElementSize =
269 calculateLegacyCbufferSize(Context, VT->getElementType());
270 return ElementSize * ElementCount;
271 }
272
273 return Context.getTypeSize(T) / 8;
274 }
275
276 // Validate packoffset:
277 // - if packoffset it used it must be set on all declarations inside the buffer
278 // - packoffset ranges must not overlap
validatePackoffset(Sema & S,HLSLBufferDecl * BufDecl)279 static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
280 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
281
282 // Make sure the packoffset annotations are either on all declarations
283 // or on none.
284 bool HasPackOffset = false;
285 bool HasNonPackOffset = false;
286 for (auto *Field : BufDecl->buffer_decls()) {
287 VarDecl *Var = dyn_cast<VarDecl>(Field);
288 if (!Var)
289 continue;
290 if (Field->hasAttr<HLSLPackOffsetAttr>()) {
291 PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
292 HasPackOffset = true;
293 } else {
294 HasNonPackOffset = true;
295 }
296 }
297
298 if (!HasPackOffset)
299 return;
300
301 if (HasNonPackOffset)
302 S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
303
304 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
305 // and compare adjacent values.
306 bool IsValid = true;
307 ASTContext &Context = S.getASTContext();
308 std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
309 [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
310 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
311 return LHS.second->getOffsetInBytes() <
312 RHS.second->getOffsetInBytes();
313 });
314 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
315 VarDecl *Var = PackOffsetVec[i].first;
316 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
317 unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
318 unsigned Begin = Attr->getOffsetInBytes();
319 unsigned End = Begin + Size;
320 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
321 if (End > NextBegin) {
322 VarDecl *NextVar = PackOffsetVec[i + 1].first;
323 S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
324 << NextVar << Var;
325 IsValid = false;
326 }
327 }
328 BufDecl->setHasValidPackoffset(IsValid);
329 }
330
331 // Returns true if the array has a zero size = if any of the dimensions is 0
isZeroSizedArray(const ConstantArrayType * CAT)332 static bool isZeroSizedArray(const ConstantArrayType *CAT) {
333 while (CAT && !CAT->isZeroSize())
334 CAT = dyn_cast<ConstantArrayType>(
335 CAT->getElementType()->getUnqualifiedDesugaredType());
336 return CAT != nullptr;
337 }
338
339 // Returns true if the record type is an HLSL resource class or an array of
340 // resource classes
isResourceRecordTypeOrArrayOf(const Type * Ty)341 static bool isResourceRecordTypeOrArrayOf(const Type *Ty) {
342 while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty))
343 Ty = CAT->getArrayElementTypeNoTypeQual();
344 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty) != nullptr;
345 }
346
isResourceRecordTypeOrArrayOf(VarDecl * VD)347 static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
348 return isResourceRecordTypeOrArrayOf(VD->getType().getTypePtr());
349 }
350
351 // Returns true if the type is a leaf element type that is not valid to be
352 // included in HLSL Buffer, such as a resource class, empty struct, zero-sized
353 // array, or a builtin intangible type. Returns false it is a valid leaf element
354 // type or if it is a record type that needs to be inspected further.
isInvalidConstantBufferLeafElementType(const Type * Ty)355 static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
356 Ty = Ty->getUnqualifiedDesugaredType();
357 if (isResourceRecordTypeOrArrayOf(Ty))
358 return true;
359 if (Ty->isRecordType())
360 return Ty->getAsCXXRecordDecl()->isEmpty();
361 if (Ty->isConstantArrayType() &&
362 isZeroSizedArray(cast<ConstantArrayType>(Ty)))
363 return true;
364 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
365 return true;
366 return false;
367 }
368
369 // Returns true if the struct contains at least one element that prevents it
370 // from being included inside HLSL Buffer as is, such as an intangible type,
371 // empty struct, or zero-sized array. If it does, a new implicit layout struct
372 // needs to be created for HLSL Buffer use that will exclude these unwanted
373 // declarations (see createHostLayoutStruct function).
requiresImplicitBufferLayoutStructure(const CXXRecordDecl * RD)374 static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
375 if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty())
376 return true;
377 // check fields
378 for (const FieldDecl *Field : RD->fields()) {
379 QualType Ty = Field->getType();
380 if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
381 return true;
382 if (Ty->isRecordType() &&
383 requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
384 return true;
385 }
386 // check bases
387 for (const CXXBaseSpecifier &Base : RD->bases())
388 if (requiresImplicitBufferLayoutStructure(
389 Base.getType()->getAsCXXRecordDecl()))
390 return true;
391 return false;
392 }
393
findRecordDeclInContext(IdentifierInfo * II,DeclContext * DC)394 static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
395 DeclContext *DC) {
396 CXXRecordDecl *RD = nullptr;
397 for (NamedDecl *Decl :
398 DC->getNonTransparentContext()->lookup(DeclarationName(II))) {
399 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) {
400 assert(RD == nullptr &&
401 "there should be at most 1 record by a given name in a scope");
402 RD = FoundRD;
403 }
404 }
405 return RD;
406 }
407
408 // Creates a name for buffer layout struct using the provide name base.
409 // If the name must be unique (not previously defined), a suffix is added
410 // until a unique name is found.
getHostLayoutStructName(Sema & S,NamedDecl * BaseDecl,bool MustBeUnique)411 static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
412 bool MustBeUnique) {
413 ASTContext &AST = S.getASTContext();
414
415 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
416 llvm::SmallString<64> Name("__cblayout_");
417 if (NameBaseII) {
418 Name.append(NameBaseII->getName());
419 } else {
420 // anonymous struct
421 Name.append("anon");
422 MustBeUnique = true;
423 }
424
425 size_t NameLength = Name.size();
426 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
427 if (!MustBeUnique)
428 return II;
429
430 unsigned suffix = 0;
431 while (true) {
432 if (suffix != 0) {
433 Name.append("_");
434 Name.append(llvm::Twine(suffix).str());
435 II = &AST.Idents.get(Name, tok::TokenKind::identifier);
436 }
437 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
438 return II;
439 // declaration with that name already exists - increment suffix and try
440 // again until unique name is found
441 suffix++;
442 Name.truncate(NameLength);
443 };
444 }
445
446 // Creates a field declaration of given name and type for HLSL buffer layout
447 // struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
createFieldForHostLayoutStruct(Sema & S,const Type * Ty,IdentifierInfo * II,CXXRecordDecl * LayoutStruct)448 static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
449 IdentifierInfo *II,
450 CXXRecordDecl *LayoutStruct) {
451 if (isInvalidConstantBufferLeafElementType(Ty))
452 return nullptr;
453
454 if (Ty->isRecordType()) {
455 CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
456 if (requiresImplicitBufferLayoutStructure(RD)) {
457 RD = createHostLayoutStruct(S, RD);
458 if (!RD)
459 return nullptr;
460 Ty = RD->getTypeForDecl();
461 }
462 }
463
464 QualType QT = QualType(Ty, 0);
465 ASTContext &AST = S.getASTContext();
466 TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
467 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
468 SourceLocation(), II, QT, TSI, nullptr, false,
469 InClassInitStyle::ICIS_NoInit);
470 Field->setAccess(AccessSpecifier::AS_public);
471 return Field;
472 }
473
// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
//
// The layout struct is an implicit, packed struct declared in the same
// DeclContext as StructDecl. Its name comes from getHostLayoutStructName with
// MustBeUnique == false. If a record with that name already exists in the
// context, that record is returned instead of creating a new one.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);

  // Reuse the layout struct if it already exists.
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
                            SourceLocation(), II);
  LS->setImplicit(true);
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->startDefinition();

  // Copy the base struct, creating an HLSL Buffer compatible version if
  // needed. If the base itself requires a layout struct, the new struct
  // derives from that layout struct instead. If the base's layout struct
  // would be empty (nullptr), the new struct gets no base at all.
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
      BaseDecl = createHostLayoutStruct(S, BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(
            QualType(BaseDecl->getTypeForDecl(), 0));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(BasesArray, 1);
    }
  }

  // Filter the struct fields: each field is re-created from its unqualified,
  // desugared type, and fields rejected by createFieldForHostLayoutStruct
  // (nullptr) are left out.
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
      LS->addDecl(NewFD);
  }
  LS->completeDefinition();

  // A layout struct with neither fields nor a base is not added to DC.
  // Returning nullptr lets the caller drop the field or base that would have
  // used it.
  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(LS);
  return LS;
}
537
538 // Creates host layout struct for HLSL Buffer. The struct will include only
539 // fields of types that are allowed in HLSL buffer and it will filter out:
540 // - static or groupshared variable declarations
541 // - resource classes
542 // - empty structs
543 // - zero-sized arrays
544 // - non-variable declarations
545 // The layout struct will be added to the HLSLBufferDecl declarations.
createHostLayoutStructForBuffer(Sema & S,HLSLBufferDecl * BufDecl)546 void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
547 ASTContext &AST = S.getASTContext();
548 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);
549
550 CXXRecordDecl *LS =
551 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
552 SourceLocation(), SourceLocation(), II);
553 LS->addAttr(PackedAttr::CreateImplicit(AST));
554 LS->setImplicit(true);
555 LS->startDefinition();
556
557 for (Decl *D : BufDecl->buffer_decls()) {
558 VarDecl *VD = dyn_cast<VarDecl>(D);
559 if (!VD || VD->getStorageClass() == SC_Static ||
560 VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
561 continue;
562 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
563 if (FieldDecl *FD =
564 createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) {
565 // add the field decl to the layout struct
566 LS->addDecl(FD);
567 // update address space of the original decl to hlsl_constant
568 QualType NewTy =
569 AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
570 VD->setType(NewTy);
571 }
572 }
573 LS->completeDefinition();
574 BufDecl->addLayoutStruct(LS);
575 }
576
addImplicitBindingAttrToBuffer(Sema & S,HLSLBufferDecl * BufDecl,uint32_t ImplicitBindingOrderID)577 static void addImplicitBindingAttrToBuffer(Sema &S, HLSLBufferDecl *BufDecl,
578 uint32_t ImplicitBindingOrderID) {
579 RegisterType RT =
580 BufDecl->isCBuffer() ? RegisterType::CBuffer : RegisterType::SRV;
581 auto *Attr =
582 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
583 std::optional<unsigned> RegSlot;
584 Attr->setBinding(RT, RegSlot, 0);
585 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
586 BufDecl->addAttr(Attr);
587 }
588
589 // Handle end of cbuffer/tbuffer declaration
ActOnFinishBuffer(Decl * Dcl,SourceLocation RBrace)590 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
591 auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
592 BufDecl->setRBraceLoc(RBrace);
593
594 validatePackoffset(SemaRef, BufDecl);
595
596 // create buffer layout struct
597 createHostLayoutStructForBuffer(SemaRef, BufDecl);
598
599 HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>();
600 if (!RBA || !RBA->hasRegisterSlot()) {
601 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
602 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
603 // to codegen. If it does not exist, create an implicit attribute.
604 uint32_t OrderID = getNextImplicitBindingOrderID();
605 if (RBA)
606 RBA->setImplicitBindingOrderID(OrderID);
607 else
608 addImplicitBindingAttrToBuffer(SemaRef, BufDecl, OrderID);
609 }
610
611 SemaRef.PopDeclContext();
612 }
613
mergeNumThreadsAttr(Decl * D,const AttributeCommonInfo & AL,int X,int Y,int Z)614 HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
615 const AttributeCommonInfo &AL,
616 int X, int Y, int Z) {
617 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
618 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
619 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
620 Diag(AL.getLoc(), diag::note_conflicting_attribute);
621 }
622 return nullptr;
623 }
624 return ::new (getASTContext())
625 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
626 }
627
mergeWaveSizeAttr(Decl * D,const AttributeCommonInfo & AL,int Min,int Max,int Preferred,int SpelledArgsCount)628 HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
629 const AttributeCommonInfo &AL,
630 int Min, int Max, int Preferred,
631 int SpelledArgsCount) {
632 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
633 if (WS->getMin() != Min || WS->getMax() != Max ||
634 WS->getPreferred() != Preferred ||
635 WS->getSpelledArgsCount() != SpelledArgsCount) {
636 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
637 Diag(AL.getLoc(), diag::note_conflicting_attribute);
638 }
639 return nullptr;
640 }
641 HLSLWaveSizeAttr *Result = ::new (getASTContext())
642 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
643 Result->setSpelledArgsCount(SpelledArgsCount);
644 return Result;
645 }
646
647 HLSLVkConstantIdAttr *
mergeVkConstantIdAttr(Decl * D,const AttributeCommonInfo & AL,int Id)648 SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
649 int Id) {
650
651 auto &TargetInfo = getASTContext().getTargetInfo();
652 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
653 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
654 return nullptr;
655 }
656
657 auto *VD = cast<VarDecl>(D);
658
659 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
660 Builtin::NotBuiltin) {
661 Diag(VD->getLocation(), diag::err_specialization_const);
662 return nullptr;
663 }
664
665 if (!VD->getType().isConstQualified()) {
666 Diag(VD->getLocation(), diag::err_specialization_const);
667 return nullptr;
668 }
669
670 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
671 if (CI->getId() != Id) {
672 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
673 Diag(AL.getLoc(), diag::note_conflicting_attribute);
674 }
675 return nullptr;
676 }
677
678 HLSLVkConstantIdAttr *Result =
679 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
680 return Result;
681 }
682
683 HLSLShaderAttr *
mergeShaderAttr(Decl * D,const AttributeCommonInfo & AL,llvm::Triple::EnvironmentType ShaderType)684 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
685 llvm::Triple::EnvironmentType ShaderType) {
686 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
687 if (NT->getType() != ShaderType) {
688 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
689 Diag(AL.getLoc(), diag::note_conflicting_attribute);
690 }
691 return nullptr;
692 }
693 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
694 }
695
696 HLSLParamModifierAttr *
mergeParamModifierAttr(Decl * D,const AttributeCommonInfo & AL,HLSLParamModifierAttr::Spelling Spelling)697 SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
698 HLSLParamModifierAttr::Spelling Spelling) {
699 // We can only merge an `in` attribute with an `out` attribute. All other
700 // combinations of duplicated attributes are ill-formed.
701 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
702 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
703 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
704 D->dropAttr<HLSLParamModifierAttr>();
705 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
706 return HLSLParamModifierAttr::Create(
707 getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
708 HLSLParamModifierAttr::Keyword_inout);
709 }
710 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
711 Diag(PA->getLocation(), diag::note_conflicting_attribute);
712 return nullptr;
713 }
714 return HLSLParamModifierAttr::Create(getASTContext(), AL);
715 }
716
ActOnTopLevelFunction(FunctionDecl * FD)717 void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
718 auto &TargetInfo = getASTContext().getTargetInfo();
719
720 if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
721 return;
722
723 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
724 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
725 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
726 // The entry point is already annotated - check that it matches the
727 // triple.
728 if (Shader->getType() != Env) {
729 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
730 << Shader;
731 FD->setInvalidDecl();
732 }
733 } else {
734 // Implicitly add the shader attribute if the entry function isn't
735 // explicitly annotated.
736 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
737 FD->getBeginLoc()));
738 }
739 } else {
740 switch (Env) {
741 case llvm::Triple::UnknownEnvironment:
742 case llvm::Triple::Library:
743 break;
744 default:
745 llvm_unreachable("Unhandled environment in triple");
746 }
747 }
748 }
749
// Validates entry-point function FD against the stage named by its
// HLSLShaderAttr, which must be present (asserted).
// - Pixel, vertex, geometry, hull, domain and ray-tracing stages: numthreads
//   and wavesize attributes are rejected.
// - Compute, amplification and mesh stages: numthreads is required. wavesize
//   needs version 6.6 of the triple's OS version (used as the shader model),
//   and 6.8 when it is spelled with more than one argument.
// - Every parameter must carry a semantic annotation, which is then checked
//   by CheckSemanticAnnotation.
// Violations diagnosed directly in this function also mark FD invalid.
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  auto &TargetInfo = getASTContext().getTargetInfo();
  // The triple's OS version is compared against shader model versions below.
  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
  switch (ST) {
  case llvm::Triple::Pixel:
  case llvm::Triple::Vertex:
  case llvm::Triple::Geometry:
  case llvm::Triple::Hull:
  case llvm::Triple::Domain:
  case llvm::Triple::RayGeneration:
  case llvm::Triple::Intersection:
  case llvm::Triple::AnyHit:
  case llvm::Triple::ClosestHit:
  case llvm::Triple::Miss:
  case llvm::Triple::Callable:
    // numthreads and wavesize are only meaningful for compute-like stages.
    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
      DiagnoseAttrStageMismatch(NT, ST,
                                {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      DiagnoseAttrStageMismatch(WS, ST,
                                {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    break;

  case llvm::Triple::Compute:
  case llvm::Triple::Amplification:
  case llvm::Triple::Mesh:
    // numthreads is mandatory for these stages.
    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
          << llvm::Triple::getEnvironmentTypeName(ST);
      FD->setInvalidDecl();
    }
    // wavesize requires 6.6; its multi-argument form requires 6.8.
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      if (Ver < VersionTuple(6, 6)) {
        Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
            << WS << "6.6";
        FD->setInvalidDecl();
      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
        Diag(
            WS->getLocation(),
            diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
            << WS << WS->getSpelledArgsCount() << "6.8";
        FD->setInvalidDecl();
      }
    }
    break;
  default:
    llvm_unreachable("Unhandled environment in triple");
  }

  // Every entry-point parameter needs a semantic annotation.
  for (ParmVarDecl *Param : FD->parameters()) {
    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
      CheckSemanticAnnotation(FD, Param, AnnotationAttr);
    } else {
      // FIXME: Handle struct parameters where annotations are on struct fields.
      // See: https://github.com/llvm/llvm-project/issues/57875
      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
      FD->setInvalidDecl();
    }
  }
  // FIXME: Verify return type semantic annotation.
}
823
CheckSemanticAnnotation(FunctionDecl * EntryPoint,const Decl * Param,const HLSLAnnotationAttr * AnnotationAttr)824 void SemaHLSL::CheckSemanticAnnotation(
825 FunctionDecl *EntryPoint, const Decl *Param,
826 const HLSLAnnotationAttr *AnnotationAttr) {
827 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
828 assert(ShaderAttr && "Entry point has no shader attribute");
829 llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
830
831 switch (AnnotationAttr->getKind()) {
832 case attr::HLSLSV_DispatchThreadID:
833 case attr::HLSLSV_GroupIndex:
834 case attr::HLSLSV_GroupThreadID:
835 case attr::HLSLSV_GroupID:
836 if (ST == llvm::Triple::Compute)
837 return;
838 DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
839 break;
840 case attr::HLSLSV_Position:
841 // TODO(#143523): allow use on other shader types & output once the overall
842 // semantic logic is implemented.
843 if (ST == llvm::Triple::Pixel)
844 return;
845 DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Pixel});
846 break;
847 default:
848 llvm_unreachable("Unknown HLSLAnnotationAttr");
849 }
850 }
851
DiagnoseAttrStageMismatch(const Attr * A,llvm::Triple::EnvironmentType Stage,std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages)852 void SemaHLSL::DiagnoseAttrStageMismatch(
853 const Attr *A, llvm::Triple::EnvironmentType Stage,
854 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
855 SmallVector<StringRef, 8> StageStrings;
856 llvm::transform(AllowedStages, std::back_inserter(StageStrings),
857 [](llvm::Triple::EnvironmentType ST) {
858 return StringRef(
859 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
860 });
861 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
862 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage)
863 << (AllowedStages.size() != 1) << join(StageStrings, ", ");
864 }
865
866 template <CastKind Kind>
castVector(Sema & S,ExprResult & E,QualType & Ty,unsigned Sz)867 static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
868 if (const auto *VTy = Ty->getAs<VectorType>())
869 Ty = VTy->getElementType();
870 Ty = S.getASTContext().getExtVectorType(Ty, Sz);
871 E = S.ImpCastExprToType(E.get(), Ty, Kind);
872 }
873
874 template <CastKind Kind>
castElement(Sema & S,ExprResult & E,QualType Ty)875 static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
876 E = S.ImpCastExprToType(E.get(), Ty, Kind);
877 return Ty;
878 }
879
handleFloatVectorBinOpConversion(Sema & SemaRef,ExprResult & LHS,ExprResult & RHS,QualType LHSType,QualType RHSType,QualType LElTy,QualType RElTy,bool IsCompAssign)880 static QualType handleFloatVectorBinOpConversion(
881 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
882 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
883 bool LHSFloat = LElTy->isRealFloatingType();
884 bool RHSFloat = RElTy->isRealFloatingType();
885
886 if (LHSFloat && RHSFloat) {
887 if (IsCompAssign ||
888 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
889 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
890
891 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
892 }
893
894 if (LHSFloat)
895 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
896
897 assert(RHSFloat);
898 if (IsCompAssign)
899 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
900
901 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
902 }
903
handleIntegerVectorBinOpConversion(Sema & SemaRef,ExprResult & LHS,ExprResult & RHS,QualType LHSType,QualType RHSType,QualType LElTy,QualType RElTy,bool IsCompAssign)904 static QualType handleIntegerVectorBinOpConversion(
905 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
906 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
907
908 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
909 bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
910 bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
911 auto &Ctx = SemaRef.getASTContext();
912
913 // If both types have the same signedness, use the higher ranked type.
914 if (LHSSigned == RHSSigned) {
915 if (IsCompAssign || IntOrder >= 0)
916 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
917
918 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
919 }
920
921 // If the unsigned type has greater than or equal rank of the signed type, use
922 // the unsigned type.
923 if (IntOrder != (LHSSigned ? 1 : -1)) {
924 if (IsCompAssign || RHSSigned)
925 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
926 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
927 }
928
929 // At this point the signed type has higher rank than the unsigned type, which
930 // means it will be the same size or bigger. If the signed type is bigger, it
931 // can represent all the values of the unsigned type, so select it.
932 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
933 if (IsCompAssign || LHSSigned)
934 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
935 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
936 }
937
938 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
939 // to C/C++ leaking through. The place this happens today is long vs long
940 // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
941 // the long long has higher rank than long even though they are the same size.
942
943 // If this is a compound assignment cast the right hand side to the left hand
944 // side's type.
945 if (IsCompAssign)
946 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
947
948 // If this isn't a compound assignment we convert to unsigned long long.
949 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
950 QualType NewTy = Ctx.getExtVectorType(
951 ElTy, RHSType->castAs<VectorType>()->getNumElements());
952 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
953
954 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
955 }
956
getScalarCastKind(ASTContext & Ctx,QualType DestTy,QualType SrcTy)957 static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
958 QualType SrcTy) {
959 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
960 return CK_FloatingCast;
961 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
962 return CK_IntegralCast;
963 if (DestTy->isRealFloatingType())
964 return CK_IntegralToFloating;
965 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
966 return CK_FloatingToIntegral;
967 }
968
handleVectorBinOpConversion(ExprResult & LHS,ExprResult & RHS,QualType LHSType,QualType RHSType,bool IsCompAssign)969 QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
970 QualType LHSType,
971 QualType RHSType,
972 bool IsCompAssign) {
973 const auto *LVecTy = LHSType->getAs<VectorType>();
974 const auto *RVecTy = RHSType->getAs<VectorType>();
975 auto &Ctx = getASTContext();
976
977 // If the LHS is not a vector and this is a compound assignment, we truncate
978 // the argument to a scalar then convert it to the LHS's type.
979 if (!LVecTy && IsCompAssign) {
980 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
981 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
982 RHSType = RHS.get()->getType();
983 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
984 return LHSType;
985 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
986 getScalarCastKind(Ctx, LHSType, RHSType));
987 return LHSType;
988 }
989
990 unsigned EndSz = std::numeric_limits<unsigned>::max();
991 unsigned LSz = 0;
992 if (LVecTy)
993 LSz = EndSz = LVecTy->getNumElements();
994 if (RVecTy)
995 EndSz = std::min(RVecTy->getNumElements(), EndSz);
996 assert(EndSz != std::numeric_limits<unsigned>::max() &&
997 "one of the above should have had a value");
998
999 // In a compound assignment, the left operand does not change type, the right
1000 // operand is converted to the type of the left operand.
1001 if (IsCompAssign && LSz != EndSz) {
1002 Diag(LHS.get()->getBeginLoc(),
1003 diag::err_hlsl_vector_compound_assignment_truncation)
1004 << LHSType << RHSType;
1005 return QualType();
1006 }
1007
1008 if (RVecTy && RVecTy->getNumElements() > EndSz)
1009 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
1010 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
1011 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
1012
1013 if (!RVecTy)
1014 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
1015 if (!IsCompAssign && !LVecTy)
1016 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
1017
1018 // If we're at the same type after resizing we can stop here.
1019 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
1020 return Ctx.getCommonSugaredType(LHSType, RHSType);
1021
1022 QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
1023 QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
1024
1025 // Handle conversion for floating point vectors.
1026 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
1027 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1028 LElTy, RElTy, IsCompAssign);
1029
1030 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
1031 "HLSL Vectors can only contain integer or floating point types");
1032 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
1033 LElTy, RElTy, IsCompAssign);
1034 }
1035
emitLogicalOperatorFixIt(Expr * LHS,Expr * RHS,BinaryOperatorKind Opc)1036 void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1037 BinaryOperatorKind Opc) {
1038 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1039 "Called with non-logical operator");
1040 llvm::SmallVector<char, 256> Buff;
1041 llvm::raw_svector_ostream OS(Buff);
1042 PrintingPolicy PP(SemaRef.getLangOpts());
1043 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1044 OS << NewFnName << "(";
1045 LHS->printPretty(OS, nullptr, PP);
1046 OS << ", ";
1047 RHS->printPretty(OS, nullptr, PP);
1048 OS << ")";
1049 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1050 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
1051 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
1052 }
1053
1054 std::pair<IdentifierInfo *, bool>
ActOnStartRootSignatureDecl(StringRef Signature)1055 SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1056 llvm::hash_code Hash = llvm::hash_value(Signature);
1057 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash);
1058 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr));
1059
1060 // Check if we have already found a decl of the same name.
1061 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1062 Sema::LookupOrdinaryName);
1063 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext);
1064 return {DeclIdent, Found};
1065 }
1066
ActOnFinishRootSignatureDecl(SourceLocation Loc,IdentifierInfo * DeclIdent,ArrayRef<hlsl::RootSignatureElement> RootElements)1067 void SemaHLSL::ActOnFinishRootSignatureDecl(
1068 SourceLocation Loc, IdentifierInfo *DeclIdent,
1069 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1070
1071 if (handleRootSignatureElements(RootElements))
1072 return;
1073
1074 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1075 for (auto &RootSigElement : RootElements)
1076 Elements.push_back(RootSigElement.getElement());
1077
1078 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1079 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
1080 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
1081
1082 SignatureDecl->setImplicit();
1083 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
1084 }
1085
handleRootSignatureElements(ArrayRef<hlsl::RootSignatureElement> Elements)1086 bool SemaHLSL::handleRootSignatureElements(
1087 ArrayRef<hlsl::RootSignatureElement> Elements) {
1088 // Define some common error handling functions
1089 bool HadError = false;
1090 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
1091 uint32_t UpperBound) {
1092 HadError = true;
1093 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1094 << LowerBound << UpperBound;
1095 };
1096
1097 auto ReportFloatError = [this, &HadError](SourceLocation Loc,
1098 float LowerBound,
1099 float UpperBound) {
1100 HadError = true;
1101 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
1102 << llvm::formatv("{0:f}", LowerBound).sstr<6>()
1103 << llvm::formatv("{0:f}", UpperBound).sstr<6>();
1104 };
1105
1106 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
1107 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
1108 ReportError(Loc, 0, 0xfffffffe);
1109 };
1110
1111 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
1112 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
1113 ReportError(Loc, 0, 0xffffffef);
1114 };
1115
1116 const uint32_t Version =
1117 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
1118 const uint32_t VersionEnum = Version - 1;
1119 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
1120 HadError = true;
1121 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
1122 << /*version minor*/ VersionEnum;
1123 };
1124
1125 // Iterate through the elements and do basic validations
1126 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1127 SourceLocation Loc = RootSigElem.getLocation();
1128 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1129 if (const auto *Descriptor =
1130 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1131 VerifyRegister(Loc, Descriptor->Reg.Number);
1132 VerifySpace(Loc, Descriptor->Space);
1133
1134 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
1135 Version, llvm::to_underlying(Descriptor->Flags)))
1136 ReportFlagError(Loc);
1137 } else if (const auto *Constants =
1138 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1139 VerifyRegister(Loc, Constants->Reg.Number);
1140 VerifySpace(Loc, Constants->Space);
1141 } else if (const auto *Sampler =
1142 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1143 VerifyRegister(Loc, Sampler->Reg.Number);
1144 VerifySpace(Loc, Sampler->Space);
1145
1146 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
1147 "By construction, parseFloatParam can't produce a NaN from a "
1148 "float_literal token");
1149
1150 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
1151 ReportError(Loc, 0, 16);
1152 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
1153 ReportFloatError(Loc, -16.f, 15.99);
1154 } else if (const auto *Clause =
1155 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1156 &Elem)) {
1157 VerifyRegister(Loc, Clause->Reg.Number);
1158 VerifySpace(Loc, Clause->Space);
1159
1160 if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
1161 // NumDescriptor could techincally be ~0u but that is reserved for
1162 // unbounded, so the diagnostic will not report that as a valid int
1163 // value
1164 ReportError(Loc, 1, 0xfffffffe);
1165 }
1166
1167 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
1168 Version, llvm::to_underlying(Clause->Type),
1169 llvm::to_underlying(Clause->Flags)))
1170 ReportFlagError(Loc);
1171 }
1172 }
1173
1174 using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
1175 using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
1176 using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
1177
1178 // 1. Collect RangeInfos
1179 llvm::SmallVector<InfoPairT> InfoPairs;
1180 for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
1181 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
1182 if (const auto *Descriptor =
1183 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1184 RangeInfo Info;
1185 Info.LowerBound = Descriptor->Reg.Number;
1186 Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1187
1188 Info.Class =
1189 llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
1190 Info.Space = Descriptor->Space;
1191 Info.Visibility = Descriptor->Visibility;
1192
1193 InfoPairs.push_back({Info, &RootSigElem});
1194 } else if (const auto *Constants =
1195 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1196 RangeInfo Info;
1197 Info.LowerBound = Constants->Reg.Number;
1198 Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1199
1200 Info.Class = llvm::dxil::ResourceClass::CBuffer;
1201 Info.Space = Constants->Space;
1202 Info.Visibility = Constants->Visibility;
1203
1204 InfoPairs.push_back({Info, &RootSigElem});
1205 } else if (const auto *Sampler =
1206 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1207 RangeInfo Info;
1208 Info.LowerBound = Sampler->Reg.Number;
1209 Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1210
1211 Info.Class = llvm::dxil::ResourceClass::Sampler;
1212 Info.Space = Sampler->Space;
1213 Info.Visibility = Sampler->Visibility;
1214
1215 InfoPairs.push_back({Info, &RootSigElem});
1216 } else if (const auto *Clause =
1217 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
1218 &Elem)) {
1219 RangeInfo Info;
1220 Info.LowerBound = Clause->Reg.Number;
1221 // Relevant error will have already been reported above and needs to be
1222 // fixed before we can conduct range analysis, so shortcut error return
1223 if (Clause->NumDescriptors == 0)
1224 return true;
1225 Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
1226 ? RangeInfo::Unbounded
1227 : Info.LowerBound + Clause->NumDescriptors -
1228 1; // use inclusive ranges []
1229
1230 Info.Class = Clause->Type;
1231 Info.Space = Clause->Space;
1232
1233 // Note: Clause does not hold the visibility this will need to
1234 InfoPairs.push_back({Info, &RootSigElem});
1235 } else if (const auto *Table =
1236 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1237 // Table holds the Visibility of all owned Clauses in Table, so iterate
1238 // owned Clauses and update their corresponding RangeInfo
1239 assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
1240 // The last Table->NumClauses elements of Infos are the owned Clauses
1241 // generated RangeInfo
1242 auto TableInfos =
1243 MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses);
1244 for (InfoPairT &Pair : TableInfos)
1245 Pair.first.Visibility = Table->Visibility;
1246 }
1247 }
1248
1249 // 2. Sort with the RangeInfo <operator to prepare it for findOverlapping
1250 llvm::sort(InfoPairs,
1251 [](InfoPairT A, InfoPairT B) { return A.first < B.first; });
1252
1253 llvm::SmallVector<RangeInfo> Infos;
1254 for (const InfoPairT &Pair : InfoPairs)
1255 Infos.push_back(Pair.first);
1256
1257 // Helpers to report diagnostics
1258 uint32_t DuplicateCounter = 0;
1259 using ElemPair = std::pair<const hlsl::RootSignatureElement *,
1260 const hlsl::RootSignatureElement *>;
1261 auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
1262 OverlappingRanges Overlap) -> ElemPair {
1263 // Given we sorted the InfoPairs (and by implication) Infos, and,
1264 // that Overlap.B is the item retrieved from the ResourceRange. Then it is
1265 // guarenteed that Overlap.B <= Overlap.A.
1266 //
1267 // So we will find Overlap.B first and then continue to find Overlap.A
1268 // after
1269 auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
1270 auto DistB = std::distance(Infos.begin(), InfoB);
1271 auto PairB = InfoPairs.begin();
1272 std::advance(PairB, DistB);
1273
1274 auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
1275 // Similarily, from the property that we have sorted the RangeInfos,
1276 // all duplicates will be processed one after the other. So
1277 // DuplicateCounter can be re-used for each set of duplicates we
1278 // encounter as we handle incoming errors
1279 DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
1280 auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter;
1281 auto PairA = PairB;
1282 std::advance(PairA, DistA);
1283
1284 return {PairA->second, PairB->second};
1285 };
1286
1287 auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
1288 auto Pair = GetElemPair(Overlap);
1289 const RangeInfo *Info = Overlap.A;
1290 const hlsl::RootSignatureElement *Elem = Pair.first;
1291 const RangeInfo *OInfo = Overlap.B;
1292
1293 auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
1294 ? OInfo->Visibility
1295 : Info->Visibility;
1296 this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1297 << llvm::to_underlying(Info->Class) << Info->LowerBound
1298 << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
1299 << Info->UpperBound << llvm::to_underlying(OInfo->Class)
1300 << OInfo->LowerBound
1301 << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
1302 << OInfo->UpperBound << Info->Space << CommonVis;
1303
1304 const hlsl::RootSignatureElement *OElem = Pair.second;
1305 this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
1306 };
1307
1308 // 3. Invoke find overlapping ranges
1309 llvm::SmallVector<OverlappingRanges> Overlaps =
1310 llvm::hlsl::rootsig::findOverlappingRanges(Infos);
1311 for (OverlappingRanges Overlap : Overlaps)
1312 ReportOverlap(Overlap);
1313
1314 return Overlaps.size() != 0;
1315 }
1316
handleRootSignatureAttr(Decl * D,const ParsedAttr & AL)1317 void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1318 if (AL.getNumArgs() != 1) {
1319 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
1320 return;
1321 }
1322
1323 IdentifierInfo *Ident = AL.getArgAsIdent(0)->getIdentifierInfo();
1324 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1325 if (RS->getSignatureIdent() != Ident) {
1326 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
1327 return;
1328 }
1329
1330 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS;
1331 return;
1332 }
1333
1334 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1335 if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
1336 if (auto *SignatureDecl =
1337 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
1338 D->addAttr(::new (getASTContext()) RootSignatureAttr(
1339 getASTContext(), AL, Ident, SignatureDecl));
1340 }
1341 }
1342
handleNumThreadsAttr(Decl * D,const ParsedAttr & AL)1343 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
1344 llvm::VersionTuple SMVersion =
1345 getASTContext().getTargetInfo().getTriple().getOSVersion();
1346 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
1347 llvm::Triple::dxil;
1348
1349 uint32_t ZMax = 1024;
1350 uint32_t ThreadMax = 1024;
1351 if (IsDXIL && SMVersion.getMajor() <= 4) {
1352 ZMax = 1;
1353 ThreadMax = 768;
1354 } else if (IsDXIL && SMVersion.getMajor() == 5) {
1355 ZMax = 64;
1356 ThreadMax = 1024;
1357 }
1358
1359 uint32_t X;
1360 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
1361 return;
1362 if (X > 1024) {
1363 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1364 diag::err_hlsl_numthreads_argument_oor)
1365 << 0 << 1024;
1366 return;
1367 }
1368 uint32_t Y;
1369 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
1370 return;
1371 if (Y > 1024) {
1372 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1373 diag::err_hlsl_numthreads_argument_oor)
1374 << 1 << 1024;
1375 return;
1376 }
1377 uint32_t Z;
1378 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
1379 return;
1380 if (Z > ZMax) {
1381 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
1382 diag::err_hlsl_numthreads_argument_oor)
1383 << 2 << ZMax;
1384 return;
1385 }
1386
1387 if (X * Y * Z > ThreadMax) {
1388 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
1389 return;
1390 }
1391
1392 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
1393 if (NewAttr)
1394 D->addAttr(NewAttr);
1395 }
1396
isValidWaveSizeValue(unsigned Value)1397 static bool isValidWaveSizeValue(unsigned Value) {
1398 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
1399 }
1400
handleWaveSizeAttr(Decl * D,const ParsedAttr & AL)1401 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
1402 // validate that the wavesize argument is a power of 2 between 4 and 128
1403 // inclusive
1404 unsigned SpelledArgsCount = AL.getNumArgs();
1405 if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
1406 return;
1407
1408 uint32_t Min;
1409 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
1410 return;
1411
1412 uint32_t Max = 0;
1413 if (SpelledArgsCount > 1 &&
1414 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
1415 return;
1416
1417 uint32_t Preferred = 0;
1418 if (SpelledArgsCount > 2 &&
1419 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
1420 return;
1421
1422 if (SpelledArgsCount > 2) {
1423 if (!isValidWaveSizeValue(Preferred)) {
1424 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1425 diag::err_attribute_power_of_two_in_range)
1426 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
1427 << Preferred;
1428 return;
1429 }
1430 // Preferred not in range.
1431 if (Preferred < Min || Preferred > Max) {
1432 Diag(AL.getArgAsExpr(2)->getExprLoc(),
1433 diag::err_attribute_power_of_two_in_range)
1434 << AL << Min << Max << Preferred;
1435 return;
1436 }
1437 } else if (SpelledArgsCount > 1) {
1438 if (!isValidWaveSizeValue(Max)) {
1439 Diag(AL.getArgAsExpr(1)->getExprLoc(),
1440 diag::err_attribute_power_of_two_in_range)
1441 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
1442 return;
1443 }
1444 if (Max < Min) {
1445 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
1446 return;
1447 } else if (Max == Min) {
1448 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
1449 }
1450 } else {
1451 if (!isValidWaveSizeValue(Min)) {
1452 Diag(AL.getArgAsExpr(0)->getExprLoc(),
1453 diag::err_attribute_power_of_two_in_range)
1454 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
1455 return;
1456 }
1457 }
1458
1459 HLSLWaveSizeAttr *NewAttr =
1460 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
1461 if (NewAttr)
1462 D->addAttr(NewAttr);
1463 }
1464
handleVkExtBuiltinInputAttr(Decl * D,const ParsedAttr & AL)1465 void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1466 uint32_t ID;
1467 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
1468 return;
1469 D->addAttr(::new (getASTContext())
1470 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1471 }
1472
handleVkConstantIdAttr(Decl * D,const ParsedAttr & AL)1473 void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1474 uint32_t Id;
1475 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1476 return;
1477 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1478 if (NewAttr)
1479 D->addAttr(NewAttr);
1480 }
1481
diagnoseInputIDType(QualType T,const ParsedAttr & AL)1482 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1483 const auto *VT = T->getAs<VectorType>();
1484
1485 if (!T->hasUnsignedIntegerRepresentation() ||
1486 (VT && VT->getNumElements() > 3)) {
1487 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1488 << AL << "uint/uint2/uint3";
1489 return false;
1490 }
1491
1492 return true;
1493 }
1494
handleSV_DispatchThreadIDAttr(Decl * D,const ParsedAttr & AL)1495 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1496 auto *VD = cast<ValueDecl>(D);
1497 if (!diagnoseInputIDType(VD->getType(), AL))
1498 return;
1499
1500 D->addAttr(::new (getASTContext())
1501 HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
1502 }
1503
diagnosePositionType(QualType T,const ParsedAttr & AL)1504 bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1505 const auto *VT = T->getAs<VectorType>();
1506
1507 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1508 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
1509 << AL << "float/float1/float2/float3/float4";
1510 return false;
1511 }
1512
1513 return true;
1514 }
1515
handleSV_PositionAttr(Decl * D,const ParsedAttr & AL)1516 void SemaHLSL::handleSV_PositionAttr(Decl *D, const ParsedAttr &AL) {
1517 auto *VD = cast<ValueDecl>(D);
1518 if (!diagnosePositionType(VD->getType(), AL))
1519 return;
1520
1521 D->addAttr(::new (getASTContext()) HLSLSV_PositionAttr(getASTContext(), AL));
1522 }
1523
handleSV_GroupThreadIDAttr(Decl * D,const ParsedAttr & AL)1524 void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1525 auto *VD = cast<ValueDecl>(D);
1526 if (!diagnoseInputIDType(VD->getType(), AL))
1527 return;
1528
1529 D->addAttr(::new (getASTContext())
1530 HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
1531 }
1532
handleSV_GroupIDAttr(Decl * D,const ParsedAttr & AL)1533 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
1534 auto *VD = cast<ValueDecl>(D);
1535 if (!diagnoseInputIDType(VD->getType(), AL))
1536 return;
1537
1538 D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
1539 }
1540
handlePackOffsetAttr(Decl * D,const ParsedAttr & AL)1541 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
1542 if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
1543 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
1544 << AL << "shader constant in a constant buffer";
1545 return;
1546 }
1547
1548 uint32_t SubComponent;
1549 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
1550 return;
1551 uint32_t Component;
1552 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
1553 return;
1554
1555 QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
1556 // Check if T is an array or struct type.
1557 // TODO: mark matrix type as aggregate type.
1558 bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
1559
1560 // Check Component is valid for T.
1561 if (Component) {
1562 unsigned Size = getASTContext().getTypeSize(T);
1563 if (IsAggregateTy || Size > 128) {
1564 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1565 return;
1566 } else {
1567 // Make sure Component + sizeof(T) <= 4.
1568 if ((Component * 32 + Size) > 128) {
1569 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
1570 return;
1571 }
1572 QualType EltTy = T;
1573 if (const auto *VT = T->getAs<VectorType>())
1574 EltTy = VT->getElementType();
1575 unsigned Align = getASTContext().getTypeAlign(EltTy);
1576 if (Align > 32 && Component == 1) {
1577 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
1578 // So we only need to check Component 1 here.
1579 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
1580 << Align << EltTy;
1581 return;
1582 }
1583 }
1584 }
1585
1586 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
1587 getASTContext(), AL, SubComponent, Component));
1588 }
1589
handleShaderAttr(Decl * D,const ParsedAttr & AL)1590 void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
1591 StringRef Str;
1592 SourceLocation ArgLoc;
1593 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
1594 return;
1595
1596 llvm::Triple::EnvironmentType ShaderType;
1597 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
1598 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
1599 << AL << Str << ArgLoc;
1600 return;
1601 }
1602
1603 // FIXME: check function match the shader stage.
1604
1605 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1606 if (NewAttr)
1607 D->addAttr(NewAttr);
1608 }
1609
CreateHLSLAttributedResourceType(Sema & S,QualType Wrapped,ArrayRef<const Attr * > AttrList,QualType & ResType,HLSLAttributedResourceLocInfo * LocInfo)1610 bool clang::CreateHLSLAttributedResourceType(
1611 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
1612 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
1613 assert(AttrList.size() && "expected list of resource attributes");
1614
1615 QualType ContainedTy = QualType();
1616 TypeSourceInfo *ContainedTyInfo = nullptr;
1617 SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
1618 SourceLocation LocEnd = AttrList[0]->getRange().getEnd();
1619
1620 HLSLAttributedResourceType::Attributes ResAttrs;
1621
1622 bool HasResourceClass = false;
1623 for (const Attr *A : AttrList) {
1624 if (!A)
1625 continue;
1626 LocEnd = A->getRange().getEnd();
1627 switch (A->getKind()) {
1628 case attr::HLSLResourceClass: {
1629 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
1630 if (HasResourceClass) {
1631 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
1632 ? diag::warn_duplicate_attribute_exact
1633 : diag::warn_duplicate_attribute)
1634 << A;
1635 return false;
1636 }
1637 ResAttrs.ResourceClass = RC;
1638 HasResourceClass = true;
1639 break;
1640 }
1641 case attr::HLSLROV:
1642 if (ResAttrs.IsROV) {
1643 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1644 return false;
1645 }
1646 ResAttrs.IsROV = true;
1647 break;
1648 case attr::HLSLRawBuffer:
1649 if (ResAttrs.RawBuffer) {
1650 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1651 return false;
1652 }
1653 ResAttrs.RawBuffer = true;
1654 break;
1655 case attr::HLSLContainedType: {
1656 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
1657 QualType Ty = CTAttr->getType();
1658 if (!ContainedTy.isNull()) {
1659 S.Diag(A->getLocation(), ContainedTy == Ty
1660 ? diag::warn_duplicate_attribute_exact
1661 : diag::warn_duplicate_attribute)
1662 << A;
1663 return false;
1664 }
1665 ContainedTy = Ty;
1666 ContainedTyInfo = CTAttr->getTypeLoc();
1667 break;
1668 }
1669 default:
1670 llvm_unreachable("unhandled resource attribute type");
1671 }
1672 }
1673
1674 if (!HasResourceClass) {
1675 S.Diag(AttrList.back()->getRange().getEnd(),
1676 diag::err_hlsl_missing_resource_class);
1677 return false;
1678 }
1679
1680 ResType = S.getASTContext().getHLSLAttributedResourceType(
1681 Wrapped, ContainedTy, ResAttrs);
1682
1683 if (LocInfo && ContainedTyInfo) {
1684 LocInfo->Range = SourceRange(LocBegin, LocEnd);
1685 LocInfo->ContainedTyInfo = ContainedTyInfo;
1686 }
1687 return true;
1688 }
1689
// Validates and creates an HLSL attribute that is applied as type attribute on
// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
// the end of the declaration they are applied to the declaration type by
// wrapping it in HLSLAttributedResourceType.
//
// T is the type the attribute is written on and AL is the parsed attribute.
// Returns true when the attribute was created and queued in
// HLSLResourcesTypeAttrs (consumed by ProcessResourceTypeAttributes); returns
// false after emitting a diagnostic otherwise.
bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
  // only allow resource type attributes on intangible types
  if (!T->isHLSLResourceType()) {
    Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type)
        << AL << getASTContext().HLSLResourceTy;
    return false;
  }

  // validate number of arguments
  if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs()))
    return false;

  // The attribute created by the switch below; every case either sets it or
  // returns early.
  Attr *A = nullptr;

  // Common attribute info shared by all attributes created below: location
  // and scope are taken from AL; the kind is recorded as
  // NoSemaHandlerAttribute with an AS_CXX11 syntax form.
  AttributeCommonInfo ACI(
      AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
      AttributeCommonInfo::NoSemaHandlerAttribute,
      {
          AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
          false /*IsRegularKeywordAttribute*/
      });

  switch (AL.getKind()) {
  case ParsedAttr::AT_HLSLResourceClass: {
    // The single argument must be an identifier naming a resource class.
    if (!AL.isArgIdent(0)) {
      Diag(AL.getLoc(), diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return false;
    }

    IdentifierLoc *Loc = AL.getArgAsIdent(0);
    StringRef Identifier = Loc->getIdentifierInfo()->getName();
    SourceLocation ArgLoc = Loc->getLoc();

    // Validate resource class value
    ResourceClass RC;
    if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
      Diag(ArgLoc, diag::warn_attribute_type_not_supported)
          << "ResourceClass" << Identifier;
      return false;
    }
    A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI);
    break;
  }

  // Flag attributes: no argument processing is needed.
  case ParsedAttr::AT_HLSLROV:
    A = HLSLROVAttr::Create(getASTContext(), ACI);
    break;

  case ParsedAttr::AT_HLSLRawBuffer:
    A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
    break;

  case ParsedAttr::AT_HLSLContainedType: {
    // The argument is a type, which must be complete.
    if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
      Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;
      return false;
    }

    TypeSourceInfo *TSI = nullptr;
    QualType QT = SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI);
    assert(TSI && "no type source info for attribute argument");
    if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT,
                                    diag::err_incomplete_type))
      return false;
    A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI);
    break;
  }

  default:
    llvm_unreachable("unhandled HLSL attribute");
  }

  // Queue the attribute; it is applied to the type later by
  // ProcessResourceTypeAttributes.
  HLSLResourcesTypeAttrs.emplace_back(A);
  return true;
}
1770
1771 // Combines all resource type attributes and creates HLSLAttributedResourceType.
ProcessResourceTypeAttributes(QualType CurrentType)1772 QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
1773 if (!HLSLResourcesTypeAttrs.size())
1774 return CurrentType;
1775
1776 QualType QT = CurrentType;
1777 HLSLAttributedResourceLocInfo LocInfo;
1778 if (CreateHLSLAttributedResourceType(SemaRef, CurrentType,
1779 HLSLResourcesTypeAttrs, QT, &LocInfo)) {
1780 const HLSLAttributedResourceType *RT =
1781 cast<HLSLAttributedResourceType>(QT.getTypePtr());
1782
1783 // Temporarily store TypeLoc information for the new type.
1784 // It will be transferred to HLSLAttributesResourceTypeLoc
1785 // shortly after the type is created by TypeSpecLocFiller which
1786 // will call the TakeLocForHLSLAttribute method below.
1787 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo));
1788 }
1789 HLSLResourcesTypeAttrs.clear();
1790 return QT;
1791 }
1792
1793 // Returns source location for the HLSLAttributedResourceType
1794 HLSLAttributedResourceLocInfo
TakeLocForHLSLAttribute(const HLSLAttributedResourceType * RT)1795 SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
1796 HLSLAttributedResourceLocInfo LocInfo = {};
1797 auto I = LocsForHLSLAttributedResources.find(RT);
1798 if (I != LocsForHLSLAttributedResources.end()) {
1799 LocInfo = I->second;
1800 LocsForHLSLAttributedResources.erase(I);
1801 return LocInfo;
1802 }
1803 LocInfo.Range = SourceRange();
1804 return LocInfo;
1805 }
1806
1807 // Walks though the global variable declaration, collects all resource binding
1808 // requirements and adds them to Bindings
collectResourceBindingsOnUserRecordDecl(const VarDecl * VD,const RecordType * RT)1809 void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
1810 const RecordType *RT) {
1811 const RecordDecl *RD = RT->getDecl();
1812 for (FieldDecl *FD : RD->fields()) {
1813 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
1814
1815 // Unwrap arrays
1816 // FIXME: Calculate array size while unwrapping
1817 assert(!Ty->isIncompleteArrayType() &&
1818 "incomplete arrays inside user defined types are not supported");
1819 while (Ty->isConstantArrayType()) {
1820 const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
1821 Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
1822 }
1823
1824 if (!Ty->isRecordType())
1825 continue;
1826
1827 if (const HLSLAttributedResourceType *AttrResType =
1828 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
1829 // Add a new DeclBindingInfo to Bindings if it does not already exist
1830 ResourceClass RC = AttrResType->getAttrs().ResourceClass;
1831 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC);
1832 if (!DBI)
1833 Bindings.addDeclBindingInfo(VD, RC);
1834 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) {
1835 // Recursively scan embedded struct or class; it would be nice to do this
1836 // without recursion, but tricky to correctly calculate the size of the
1837 // binding, which is something we are probably going to need to do later
1838 // on. Hopefully nesting of structs in structs too many levels is
1839 // unlikely.
1840 collectResourceBindingsOnUserRecordDecl(VD, RT);
1841 }
1842 }
1843 }
1844
// Diagnose localized register binding errors for a single binding; does not
// diagnose resource binding on user record types, that will be done later
// in processResourceBindingOnDecl based on the information collected in
// collectResourceBindingsOnVarDecl.
// Returns false if the register binding is not valid.
//
// ArgLoc is the location of the register argument (used by most of the
// diagnostics below). D is the annotated declaration, expected to be a VarDecl
// or an HLSLBufferDecl. RegType is the register type parsed from the slot.
// SpecifiedSpace is true when the annotation explicitly named a register
// space.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                         Decl *D, RegisterType RegType,
                                         bool SpecifiedSpace) {
  // Integer form of RegType, streamed into the diagnostics below.
  int RegTypeNum = static_cast<int>(RegType);

  // check if the decl type is groupshared
  // (a groupshared declaration is rejected for every register type)
  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    return false;
  }

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  // A cbuffer must use the register type of ResourceClass::CBuffer and a
  // tbuffer the register type of ResourceClass::SRV.
  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                     : ResourceClass::SRV;
    if (RegType == getRegisterType(RC))
      return true;

    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Samplers, UAVs, and SRVs are VarDecl types
  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
  VarDecl *VD = cast<VarDecl>(D);

  // Resource
  // If the variable's type has a resource handle, the register type must
  // match the register type of the handle's resource class.
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(
              VD->getType().getTypePtr())) {
    if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass))
      return true;

    S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Look through (possibly nested) array types to the element type.
  const clang::Type *Ty = VD->getType().getTypePtr();
  while (Ty->isArrayType())
    Ty = Ty->getArrayElementTypeNoTypeQual();

  // Basic types
  if (Ty->isArithmeticType() || Ty->isVectorType()) {
    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
    // An explicit space is an error outside a cbuffer/tbuffer. Note that
    // validation continues after this diagnostic.
    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
      S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);

    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) ||
                                  Ty->isFloatingType() || Ty->isVectorType())) {
      // Register annotation on default constant buffer declaration ($Globals)
      // Only RegisterType::C is accepted. CBuffer gets a deprecation warning
      // and anything else is an error; in both cases the binding is rejected.
      if (RegType == RegisterType::CBuffer)
        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
      else if (RegType != RegisterType::C)
        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
      else
        return true;
    } else {
      // Declared inside a cbuffer/tbuffer, or an arithmetic type not matched
      // above: RegisterType::C only warns (warn_hlsl_register_type_c_packoffset),
      // other register types are a mismatch error. The binding is rejected
      // either way.
      if (RegType == RegisterType::C)
        S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
      else
        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    }
    return false;
  }
  if (Ty->isRecordType())
    // RecordTypes will be diagnosed in processResourceBindingOnDecl
    // that is called from ActOnVariableDeclarator
    return true;

  // Anything else is an error
  S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
  return false;
}
1925
ValidateMultipleRegisterAnnotations(Sema & S,Decl * TheDecl,RegisterType regType)1926 static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
1927 RegisterType regType) {
1928 // make sure that there are no two register annotations
1929 // applied to the decl with the same register type
1930 bool RegisterTypesDetected[5] = {false};
1931 RegisterTypesDetected[static_cast<int>(regType)] = true;
1932
1933 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
1934 if (HLSLResourceBindingAttr *attr =
1935 dyn_cast<HLSLResourceBindingAttr>(*it)) {
1936
1937 RegisterType otherRegType = attr->getRegisterType();
1938 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
1939 int otherRegTypeNum = static_cast<int>(otherRegType);
1940 S.Diag(TheDecl->getLocation(),
1941 diag::err_hlsl_duplicate_register_annotation)
1942 << otherRegTypeNum;
1943 return false;
1944 }
1945 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
1946 }
1947 }
1948 return true;
1949 }
1950
DiagnoseHLSLRegisterAttribute(Sema & S,SourceLocation & ArgLoc,Decl * D,RegisterType RegType,bool SpecifiedSpace)1951 static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
1952 Decl *D, RegisterType RegType,
1953 bool SpecifiedSpace) {
1954
1955 // exactly one of these two types should be set
1956 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
1957 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
1958 "expecting VarDecl or HLSLBufferDecl");
1959
1960 // check if the declaration contains resource matching the register type
1961 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
1962 return false;
1963
1964 // next, if multiple register annotations exist, check that none conflict.
1965 return ValidateMultipleRegisterAnnotations(S, D, RegType);
1966 }
1967
// Handles the register(...) annotation. Accepted argument forms (all
// arguments must be identifiers):
//   register(<slot>)          - space defaults to "space0"
//   register(<space>)         - an identifier starting with "space"; no slot
//   register(<slot>, <space>)
// The slot's first character selects the register type (convertToRegisterType)
// and the remaining characters must be a decimal number. The space must be
// "space" followed by a decimal number. On success, an HLSLResourceBindingAttr
// recording the register type, optional slot number and space number is
// attached to TheDecl. Otherwise a diagnostic is emitted and nothing is added.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  // Variable types must be complete before the binding can be validated.
  if (isa<VarDecl>(TheDecl)) {
    if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
                                    cast<ValueDecl>(TheDecl)->getType(),
                                    diag::err_incomplete_type))
      return;
  }

  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  if (!AL.isArgIdent(0)) {
    Diag(AL.getLoc(), diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(0);

  // Two arguments: register(<slot>, <space>).
  if (AL.getNumArgs() == 2) {
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(1)) {
      Diag(AL.getLoc(), diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    // One argument: either a space-only annotation or a slot in the default
    // space. SpaceLoc stays invalid in the latter case, which is how an
    // implicit space is told apart from an explicit one below.
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with("space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  // RegType is only meaningful when SlotNum is set. For space-only
  // annotations it keeps this placeholder value.
  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    if (!convertToRegisterType(Slot, &RegType)) {
      Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
      return;
    }
    // Register type I is deprecated: warn and drop the annotation.
    if (RegType == RegisterType::I) {
      Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    StringRef SlotNumStr = Slot.substr(1);
    unsigned N;
    if (SlotNumStr.getAsInteger(10, N)) {
      Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
      return;
    }
    SlotNum = N;
  }

  // Validate space
  if (!Space.starts_with("space")) {
    Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(5);
  if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
    Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl
  // (the last argument reports whether the space was written explicitly)
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
                                       !SpaceLoc.isInvalid()))
      return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
  if (NewAttr) {
    NewAttr->setBinding(RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(NewAttr);
  }
}
2057
handleParamModifierAttr(Decl * D,const ParsedAttr & AL)2058 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
2059 HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
2060 D, AL,
2061 static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
2062 if (NewAttr)
2063 D->addAttr(NewAttr);
2064 }
2065
2066 namespace {
2067
2068 /// This class implements HLSL availability diagnostics for default
2069 /// and relaxed mode
2070 ///
2071 /// The goal of this diagnostic is to emit an error or warning when an
2072 /// unavailable API is found in code that is reachable from the shader
2073 /// entry function or from an exported function (when compiling a shader
2074 /// library).
2075 ///
2076 /// This is done by traversing the AST of all shader entry point functions
2077 /// and of all exported functions, and any functions that are referenced
2078 /// from this AST. In other words, any functions that are reachable from
2079 /// the entry points.
2080 class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
2081 Sema &SemaRef;
2082
2083 // Stack of functions to be scaned
2084 llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;
2085
2086 // Tracks which environments functions have been scanned in.
2087 //
2088 // Maps FunctionDecl to an unsigned number that represents the set of shader
2089 // environments the function has been scanned for.
2090 // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
2091 // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
2092 // (verified by static_asserts in Triple.cpp), we can use it to index
2093 // individual bits in the set, as long as we shift the values to start with 0
2094 // by subtracting the value of llvm::Triple::Pixel first.
2095 //
2096 // The N'th bit in the set will be set if the function has been scanned
2097 // in shader environment whose llvm::Triple::EnvironmentType integer value
2098 // equals (llvm::Triple::Pixel + N).
2099 //
2100 // For example, if a function has been scanned in compute and pixel stage
2101 // environment, the value will be 0x21 (100001 binary) because:
2102 //
2103 // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
2104 // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
2105 //
2106 // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
2107 // been scanned in any environment.
2108 llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
2109
2110 // Do not access these directly, use the get/set methods below to make
2111 // sure the values are in sync
2112 llvm::Triple::EnvironmentType CurrentShaderEnvironment;
2113 unsigned CurrentShaderStageBit;
2114
2115 // True if scanning a function that was already scanned in a different
2116 // shader stage context, and therefore we should not report issues that
2117 // depend only on shader model version because they would be duplicate.
2118 bool ReportOnlyShaderStageIssues;
2119
2120 // Helper methods for dealing with current stage context / environment
SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType)2121 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
2122 static_assert(sizeof(unsigned) >= 4);
2123 assert(HLSLShaderAttr::isValidShaderType(ShaderType));
2124 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
2125 "ShaderType is too big for this bitmap"); // 31 is reserved for
2126 // "unknown"
2127
2128 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
2129 CurrentShaderEnvironment = ShaderType;
2130 CurrentShaderStageBit = (1 << bitmapIndex);
2131 }
2132
SetUnknownShaderStageContext()2133 void SetUnknownShaderStageContext() {
2134 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
2135 CurrentShaderStageBit = (1 << 31);
2136 }
2137
GetCurrentShaderEnvironment() const2138 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
2139 return CurrentShaderEnvironment;
2140 }
2141
InUnknownShaderStageContext() const2142 bool InUnknownShaderStageContext() const {
2143 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
2144 }
2145
2146 // Helper methods for dealing with shader stage bitmap
AddToScannedFunctions(const FunctionDecl * FD)2147 void AddToScannedFunctions(const FunctionDecl *FD) {
2148 unsigned &ScannedStages = ScannedDecls[FD];
2149 ScannedStages |= CurrentShaderStageBit;
2150 }
2151
GetScannedStages(const FunctionDecl * FD)2152 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }
2153
WasAlreadyScannedInCurrentStage(const FunctionDecl * FD)2154 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
2155 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
2156 }
2157
WasAlreadyScannedInCurrentStage(unsigned ScannerStages)2158 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
2159 return ScannerStages & CurrentShaderStageBit;
2160 }
2161
NeverBeenScanned(unsigned ScannedStages)2162 static bool NeverBeenScanned(unsigned ScannedStages) {
2163 return ScannedStages == 0;
2164 }
2165
2166 // Scanning methods
2167 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
2168 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
2169 SourceRange Range);
2170 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
2171 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
2172
2173 public:
DiagnoseHLSLAvailability(Sema & SemaRef)2174 DiagnoseHLSLAvailability(Sema &SemaRef)
2175 : SemaRef(SemaRef),
2176 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
2177 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}
2178
2179 // AST traversal methods
2180 void RunOnTranslationUnit(const TranslationUnitDecl *TU);
2181 void RunOnFunction(const FunctionDecl *FD);
2182
VisitDeclRefExpr(DeclRefExpr * DRE)2183 bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
2184 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
2185 if (FD)
2186 HandleFunctionOrMethodRef(FD, DRE);
2187 return true;
2188 }
2189
VisitMemberExpr(MemberExpr * ME)2190 bool VisitMemberExpr(MemberExpr *ME) override {
2191 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
2192 if (FD)
2193 HandleFunctionOrMethodRef(FD, ME);
2194 return true;
2195 }
2196 };
2197
HandleFunctionOrMethodRef(FunctionDecl * FD,Expr * RefExpr)2198 void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
2199 Expr *RefExpr) {
2200 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
2201 "expected DeclRefExpr or MemberExpr");
2202
2203 // has a definition -> add to stack to be scanned
2204 const FunctionDecl *FDWithBody = nullptr;
2205 if (FD->hasBody(FDWithBody)) {
2206 if (!WasAlreadyScannedInCurrentStage(FDWithBody))
2207 DeclsToScan.push_back(FDWithBody);
2208 return;
2209 }
2210
2211 // no body -> diagnose availability
2212 const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
2213 if (AA)
2214 CheckDeclAvailability(
2215 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
2216 }
2217
RunOnTranslationUnit(const TranslationUnitDecl * TU)2218 void DiagnoseHLSLAvailability::RunOnTranslationUnit(
2219 const TranslationUnitDecl *TU) {
2220
2221 // Iterate over all shader entry functions and library exports, and for those
2222 // that have a body (definiton), run diag scan on each, setting appropriate
2223 // shader environment context based on whether it is a shader entry function
2224 // or an exported function. Exported functions can be in namespaces and in
2225 // export declarations so we need to scan those declaration contexts as well.
2226 llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
2227 DeclContextsToScan.push_back(TU);
2228
2229 while (!DeclContextsToScan.empty()) {
2230 const DeclContext *DC = DeclContextsToScan.pop_back_val();
2231 for (auto &D : DC->decls()) {
2232 // do not scan implicit declaration generated by the implementation
2233 if (D->isImplicit())
2234 continue;
2235
2236 // for namespace or export declaration add the context to the list to be
2237 // scanned later
2238 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
2239 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
2240 continue;
2241 }
2242
2243 // skip over other decls or function decls without body
2244 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
2245 if (!FD || !FD->isThisDeclarationADefinition())
2246 continue;
2247
2248 // shader entry point
2249 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
2250 SetShaderStageContext(ShaderAttr->getType());
2251 RunOnFunction(FD);
2252 continue;
2253 }
2254 // exported library function
2255 // FIXME: replace this loop with external linkage check once issue #92071
2256 // is resolved
2257 bool isExport = FD->isInExportDeclContext();
2258 if (!isExport) {
2259 for (const auto *Redecl : FD->redecls()) {
2260 if (Redecl->isInExportDeclContext()) {
2261 isExport = true;
2262 break;
2263 }
2264 }
2265 }
2266 if (isExport) {
2267 SetUnknownShaderStageContext();
2268 RunOnFunction(FD);
2269 continue;
2270 }
2271 }
2272 }
2273 }
2274
RunOnFunction(const FunctionDecl * FD)2275 void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
2276 assert(DeclsToScan.empty() && "DeclsToScan should be empty");
2277 DeclsToScan.push_back(FD);
2278
2279 while (!DeclsToScan.empty()) {
2280 // Take one decl from the stack and check it by traversing its AST.
2281 // For any CallExpr found during the traversal add it's callee to the top of
2282 // the stack to be processed next. Functions already processed are stored in
2283 // ScannedDecls.
2284 const FunctionDecl *FD = DeclsToScan.pop_back_val();
2285
2286 // Decl was already scanned
2287 const unsigned ScannedStages = GetScannedStages(FD);
2288 if (WasAlreadyScannedInCurrentStage(ScannedStages))
2289 continue;
2290
2291 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
2292
2293 AddToScannedFunctions(FD);
2294 TraverseStmt(FD->getBody());
2295 }
2296 }
2297
HasMatchingEnvironmentOrNone(const AvailabilityAttr * AA)2298 bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2299 const AvailabilityAttr *AA) {
2300 IdentifierInfo *IIEnvironment = AA->getEnvironment();
2301 if (!IIEnvironment)
2302 return true;
2303
2304 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2305 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2306 return false;
2307
2308 llvm::Triple::EnvironmentType AttrEnv =
2309 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
2310
2311 return CurrentEnv == AttrEnv;
2312 }
2313
2314 const AvailabilityAttr *
FindAvailabilityAttr(const Decl * D)2315 DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2316 AvailabilityAttr const *PartialMatch = nullptr;
2317 // Check each AvailabilityAttr to find the one for this platform.
2318 // For multiple attributes with the same platform try to find one for this
2319 // environment.
2320 for (const auto *A : D->attrs()) {
2321 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
2322 StringRef AttrPlatform = Avail->getPlatform()->getName();
2323 StringRef TargetPlatform =
2324 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2325
2326 // Match the platform name.
2327 if (AttrPlatform == TargetPlatform) {
2328 // Find the best matching attribute for this environment
2329 if (HasMatchingEnvironmentOrNone(Avail))
2330 return Avail;
2331 PartialMatch = Avail;
2332 }
2333 }
2334 }
2335 return PartialMatch;
2336 }
2337
// Check availability against target shader model version and current shader
// stage and emit diagnostic
//
// D is the referenced declaration, AA is its availability attribute for the
// target platform, and Range is the source range of the reference.
// warn_hlsl_availability is emitted when the environment matches but the
// target version is older than AA's introduced version.
// warn_hlsl_availability_unavailable is emitted when the environment does not
// match. Either warning is followed by
// note_partial_availability_specified_here at D.
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have environment -> it depends only
    // on shader model version and not on the specific shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicate)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has environment -> we need to know
    // the current stage context to properly diagnose it.
    if (InUnknownShaderStageContext())
      return;
  }

  // Check introduced version and if environment matches
  // (nothing to report if both the version and the environment are OK)
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());

  // Environment named by the attribute (empty if none); UseEnvironment
  // selects the environment-specific wording in the diagnostics.
  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  if (EnvironmentMatches) {
    SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}
2405
2406 } // namespace
2407
ActOnEndOfTranslationUnit(TranslationUnitDecl * TU)2408 void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
2409 // process default CBuffer - create buffer layout struct and invoke codegenCGH
2410 if (!DefaultCBufferDecls.empty()) {
2411 HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
2412 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(),
2413 DefaultCBufferDecls);
2414 addImplicitBindingAttrToBuffer(SemaRef, DefaultCBuffer,
2415 getNextImplicitBindingOrderID());
2416 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
2417 createHostLayoutStructForBuffer(SemaRef, DefaultCBuffer);
2418
2419 // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
2420 for (const Decl *VD : DefaultCBufferDecls) {
2421 const HLSLResourceBindingAttr *RBA =
2422 VD->getAttr<HLSLResourceBindingAttr>();
2423 if (RBA && RBA->hasRegisterSlot() &&
2424 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
2425 DefaultCBuffer->setHasValidPackoffset(true);
2426 break;
2427 }
2428 }
2429
2430 DeclGroupRef DG(DefaultCBuffer);
2431 SemaRef.Consumer.HandleTopLevelDecl(DG);
2432 }
2433 diagnoseAvailabilityViolations(TU);
2434 }
2435
diagnoseAvailabilityViolations(TranslationUnitDecl * TU)2436 void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
2437 // Skip running the diagnostics scan if the diagnostic mode is
2438 // strict (-fhlsl-strict-availability) and the target shader stage is known
2439 // because all relevant diagnostics were already emitted in the
2440 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
2441 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
2442 if (SemaRef.getLangOpts().HLSLStrictAvailability &&
2443 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
2444 return;
2445
2446 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
2447 }
2448
CheckAllArgsHaveSameType(Sema * S,CallExpr * TheCall)2449 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2450 assert(TheCall->getNumArgs() > 1);
2451 QualType ArgTy0 = TheCall->getArg(0)->getType();
2452
2453 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2454 if (!S->getASTContext().hasSameUnqualifiedType(
2455 ArgTy0, TheCall->getArg(I)->getType())) {
2456 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
2457 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2458 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
2459 TheCall->getArg(N - 1)->getEndLoc());
2460 return true;
2461 }
2462 }
2463 return false;
2464 }
2465
CheckArgTypeMatches(Sema * S,Expr * Arg,QualType ExpectedType)2466 static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
2467 QualType ArgType = Arg->getType();
2468 if (!S->getASTContext().hasSameUnqualifiedType(ArgType, ExpectedType)) {
2469 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
2470 << ArgType << ExpectedType << 1 << 0 << 0;
2471 return true;
2472 }
2473 return false;
2474 }
2475
CheckAllArgTypesAreCorrect(Sema * S,CallExpr * TheCall,llvm::function_ref<bool (Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)> Check)2476 static bool CheckAllArgTypesAreCorrect(
2477 Sema *S, CallExpr *TheCall,
2478 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2479 clang::QualType PassedType)>
2480 Check) {
2481 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2482 Expr *Arg = TheCall->getArg(I);
2483 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2484 return true;
2485 }
2486 return false;
2487 }
2488
CheckFloatOrHalfRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2489 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2490 int ArgOrdinal,
2491 clang::QualType PassedType) {
2492 clang::QualType BaseType =
2493 PassedType->isVectorType()
2494 ? PassedType->castAs<clang::VectorType>()->getElementType()
2495 : PassedType;
2496 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
2497 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2498 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
2499 << /* half or float */ 2 << PassedType;
2500 return false;
2501 }
2502
CheckModifiableLValue(Sema * S,CallExpr * TheCall,unsigned ArgIndex)2503 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
2504 unsigned ArgIndex) {
2505 auto *Arg = TheCall->getArg(ArgIndex);
2506 SourceLocation OrigLoc = Arg->getExprLoc();
2507 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
2508 Expr::MLV_Valid)
2509 return false;
2510 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
2511 return true;
2512 }
2513
CheckNoDoubleVectors(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2514 static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
2515 clang::QualType PassedType) {
2516 const auto *VecTy = PassedType->getAs<VectorType>();
2517 if (!VecTy)
2518 return false;
2519
2520 if (VecTy->getElementType()->isDoubleType())
2521 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2522 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
2523 << PassedType;
2524 return false;
2525 }
2526
CheckFloatingOrIntRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2527 static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
2528 int ArgOrdinal,
2529 clang::QualType PassedType) {
2530 if (!PassedType->hasIntegerRepresentation() &&
2531 !PassedType->hasFloatingRepresentation())
2532 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2533 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
2534 << /* fp */ 1 << PassedType;
2535 return false;
2536 }
2537
CheckUnsignedIntVecRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2538 static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
2539 int ArgOrdinal,
2540 clang::QualType PassedType) {
2541 if (auto *VecTy = PassedType->getAs<VectorType>())
2542 if (VecTy->getElementType()->isUnsignedIntegerType())
2543 return false;
2544
2545 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2546 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
2547 << PassedType;
2548 }
2549
2550 // checks for unsigned ints of all sizes
CheckUnsignedIntRepresentation(Sema * S,SourceLocation Loc,int ArgOrdinal,clang::QualType PassedType)2551 static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
2552 int ArgOrdinal,
2553 clang::QualType PassedType) {
2554 if (!PassedType->hasUnsignedIntegerRepresentation())
2555 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2556 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
2557 << /* no fp */ 0 << PassedType;
2558 return false;
2559 }
2560
SetElementTypeAsReturnType(Sema * S,CallExpr * TheCall,QualType ReturnType)2561 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
2562 QualType ReturnType) {
2563 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
2564 if (VecTyA)
2565 ReturnType =
2566 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements());
2567
2568 TheCall->setType(ReturnType);
2569 }
2570
CheckScalarOrVector(Sema * S,CallExpr * TheCall,QualType Scalar,unsigned ArgIndex)2571 static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
2572 unsigned ArgIndex) {
2573 assert(TheCall->getNumArgs() >= ArgIndex);
2574 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2575 auto *VTy = ArgType->getAs<VectorType>();
2576 // not the scalar or vector<scalar>
2577 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
2578 (VTy &&
2579 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
2580 S->Diag(TheCall->getArg(0)->getBeginLoc(),
2581 diag::err_typecheck_expect_scalar_or_vector)
2582 << ArgType << Scalar;
2583 return true;
2584 }
2585 return false;
2586 }
2587
CheckAnyScalarOrVector(Sema * S,CallExpr * TheCall,unsigned ArgIndex)2588 static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
2589 unsigned ArgIndex) {
2590 assert(TheCall->getNumArgs() >= ArgIndex);
2591 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2592 auto *VTy = ArgType->getAs<VectorType>();
2593 // not the scalar or vector<scalar>
2594 if (!(ArgType->isScalarType() ||
2595 (VTy && VTy->getElementType()->isScalarType()))) {
2596 S->Diag(TheCall->getArg(0)->getBeginLoc(),
2597 diag::err_typecheck_expect_any_scalar_or_vector)
2598 << ArgType << 1;
2599 return true;
2600 }
2601 return false;
2602 }
2603
CheckWaveActive(Sema * S,CallExpr * TheCall)2604 static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
2605 QualType BoolType = S->getASTContext().BoolTy;
2606 assert(TheCall->getNumArgs() >= 1);
2607 QualType ArgType = TheCall->getArg(0)->getType();
2608 auto *VTy = ArgType->getAs<VectorType>();
2609 // is the bool or vector<bool>
2610 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) ||
2611 (VTy &&
2612 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) {
2613 S->Diag(TheCall->getArg(0)->getBeginLoc(),
2614 diag::err_typecheck_expect_any_scalar_or_vector)
2615 << ArgType << 0;
2616 return true;
2617 }
2618 return false;
2619 }
2620
CheckBoolSelect(Sema * S,CallExpr * TheCall)2621 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
2622 assert(TheCall->getNumArgs() == 3);
2623 Expr *Arg1 = TheCall->getArg(1);
2624 Expr *Arg2 = TheCall->getArg(2);
2625 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2626 S->Diag(TheCall->getBeginLoc(),
2627 diag::err_typecheck_call_different_arg_types)
2628 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2629 << Arg2->getSourceRange();
2630 return true;
2631 }
2632
2633 TheCall->setType(Arg1->getType());
2634 return false;
2635 }
2636
CheckVectorSelect(Sema * S,CallExpr * TheCall)2637 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
2638 assert(TheCall->getNumArgs() == 3);
2639 Expr *Arg1 = TheCall->getArg(1);
2640 QualType Arg1Ty = Arg1->getType();
2641 Expr *Arg2 = TheCall->getArg(2);
2642 QualType Arg2Ty = Arg2->getType();
2643
2644 QualType Arg1ScalarTy = Arg1Ty;
2645 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2646 Arg1ScalarTy = VTy->getElementType();
2647
2648 QualType Arg2ScalarTy = Arg2Ty;
2649 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2650 Arg2ScalarTy = VTy->getElementType();
2651
2652 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
2653 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
2654 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2655
2656 QualType Arg0Ty = TheCall->getArg(0)->getType();
2657 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2658 unsigned Arg1Length = Arg1Ty->isVectorType()
2659 ? Arg1Ty->getAs<VectorType>()->getNumElements()
2660 : 0;
2661 unsigned Arg2Length = Arg2Ty->isVectorType()
2662 ? Arg2Ty->getAs<VectorType>()->getNumElements()
2663 : 0;
2664 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2665 S->Diag(TheCall->getBeginLoc(),
2666 diag::err_typecheck_vector_lengths_not_equal)
2667 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
2668 << Arg1->getSourceRange();
2669 return true;
2670 }
2671
2672 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
2673 S->Diag(TheCall->getBeginLoc(),
2674 diag::err_typecheck_vector_lengths_not_equal)
2675 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
2676 << Arg2->getSourceRange();
2677 return true;
2678 }
2679
2680 TheCall->setType(
2681 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
2682 return false;
2683 }
2684
CheckResourceHandle(Sema * S,CallExpr * TheCall,unsigned ArgIndex,llvm::function_ref<bool (const HLSLAttributedResourceType * ResType)> Check=nullptr)2685 static bool CheckResourceHandle(
2686 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
2687 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
2688 nullptr) {
2689 assert(TheCall->getNumArgs() >= ArgIndex);
2690 QualType ArgType = TheCall->getArg(ArgIndex)->getType();
2691 const HLSLAttributedResourceType *ResTy =
2692 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
2693 if (!ResTy) {
2694 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
2695 diag::err_typecheck_expect_hlsl_resource)
2696 << ArgType;
2697 return true;
2698 }
2699 if (Check && Check(ResTy)) {
2700 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(),
2701 diag::err_invalid_hlsl_resource_type)
2702 << ArgType;
2703 return true;
2704 }
2705 return false;
2706 }
2707
2708 // Note: returning true in this case results in CheckBuiltinFunctionCall
2709 // returning an ExprError
/// HLSL-specific semantic checking for a call to builtin \p BuiltinID.
///
/// Each handled builtin has its own `case`. Depending on the builtin, the case
/// does some or all of the following:
///   * validates the argument count;
///   * validates argument types, emitting a diagnostic on failure;
///   * sets the call's result type (and, for
///     __builtin_hlsl_resource_getpointer, its value kind).
/// Builtins not listed in the switch are accepted unchanged.
///
/// \returns true if an error was diagnosed, false otherwise.
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
  switch (BuiltinID) {
  case Builtin::BI__builtin_hlsl_adduint64: {
    // Two arguments, each a vector of unsigned integers.
    if (SemaRef.checkArgCount(TheCall, 2))
      return true;

    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckUnsignedIntVecRepresentation))
      return true;

    // Non-null: the check above only passes for vector arguments.
    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
    // ensure arg integers are 32-bits
    uint64_t ElementBitCount = getASTContext()
                                   .getTypeSizeInChars(VTy->getElementType())
                                   .getQuantity() *
                               8;
    if (ElementBitCount != 32) {
      SemaRef.Diag(TheCall->getBeginLoc(),
                   diag::err_integer_incorrect_bit_count)
          << 32 << ElementBitCount;
      return true;
    }

    // ensure both args are vectors of total bit size of a multiple of 64;
    // only 2- and 4-element vectors are accepted.
    int NumElementsArg = VTy->getNumElements();
    if (NumElementsArg != 2 && NumElementsArg != 4) {
      SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count)
          << 1 /*a multiple of*/ << 64 << NumElementsArg * ElementBitCount;
      return true;
    }

    // ensure first arg and second arg have the same type
    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer: {
    // (resource handle, uint) -> pointer to the contained type in the
    // hlsl_device address space; the call expression is marked as an lvalue.
    if (SemaRef.checkArgCount(TheCall, 2) ||
        CheckResourceHandle(&SemaRef, TheCall, 0) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
                            SemaRef.getASTContext().UnsignedIntTy))
      return true;

    auto *ResourceTy =
        TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ContainedTy = ResourceTy->getContainedType();
    auto ReturnType =
        SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device);
    ReturnType = SemaRef.Context.getPointerType(ReturnType);
    TheCall->setType(ReturnType);
    TheCall->setValueKind(VK_LValue);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
    // Single resource-handle argument.
    if (SemaRef.checkArgCount(TheCall, 1) ||
        CheckResourceHandle(&SemaRef, TheCall, 0))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
    // Expected argument types, in order:
    //   handle, uint, uint, int, uint, const char *
    // NOTE(review): the meaning of each binding operand is defined by the
    // builtin's declaration, not here.
    ASTContext &AST = SemaRef.getASTContext();
    if (SemaRef.checkArgCount(TheCall, 6) ||
        CheckResourceHandle(&SemaRef, TheCall, 0) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.UnsignedIntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.IntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(5),
                            AST.getPointerType(AST.CharTy.withConst())))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
    // Expected argument types, in order:
    //   handle, uint, int, uint, uint, const char *
    // Note that the int operand sits in a different position than in the
    // explicit-binding builtin above.
    ASTContext &AST = SemaRef.getASTContext();
    if (SemaRef.checkArgCount(TheCall, 6) ||
        CheckResourceHandle(&SemaRef, TheCall, 0) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.IntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.UnsignedIntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(5),
                            AST.getPointerType(AST.CharTy.withConst())))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_and:
  case Builtin::BI__builtin_hlsl_or: {
    // Two arguments of the same bool scalar/vector type.
    if (SemaRef.checkArgCount(TheCall, 2))
      return true;
    if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
      return true;
    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_all:
  case Builtin::BI__builtin_hlsl_any: {
    // One argument: any scalar, or vector of scalars. The result type is not
    // changed here.
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;
    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_asdouble: {
    // Two uint scalar/vector arguments of identical type; the result is double
    // with the same scalar/vector shape.
    if (SemaRef.checkArgCount(TheCall, 2))
      return true;
    if (CheckScalarOrVector(
            &SemaRef, TheCall,
            /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
            /* arg index */ 0))
      return true;
    if (CheckScalarOrVector(
            &SemaRef, TheCall,
            /*only check for uint*/ SemaRef.Context.UnsignedIntTy,
            /* arg index */ 1))
      return true;
    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
      return true;

    SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clamp: {
    // Generic elementwise ternary checking, without argument-type
    // restriction.
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_dot: {
    // arg count is checked by BuiltinVectorToScalarMath
    if (SemaRef.BuiltinVectorToScalarMath(TheCall))
      return true;
    // Additionally reject vectors of double.
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
  case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
    // Integer scalar/vector argument; the result is uint with the argument's
    // scalar/vector shape.
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;

    const Expr *Arg = TheCall->getArg(0);
    QualType ArgTy = Arg->getType();
    QualType EltTy = ArgTy;

    QualType ResTy = SemaRef.Context.UnsignedIntTy;

    if (auto *VecTy = EltTy->getAs<VectorType>()) {
      EltTy = VecTy->getElementType();
      ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements());
    }

    if (!EltTy->isIntegerType()) {
      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
          << /* no fp */ 0 << ArgTy;
      return true;
    }

    TheCall->setType(ResTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_select: {
    // The condition (arg 0) is bool or a vector of bool. Dispatch on its shape:
    //   * scalar -> CheckBoolSelect
    //   * vector -> CheckVectorSelect
    // Both helpers set the result type.
    if (SemaRef.checkArgCount(TheCall, 3))
      return true;
    if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
      return true;
    QualType ArgTy = TheCall->getArg(0)->getType();
    if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
      return true;
    auto *VTy = ArgTy->getAs<VectorType>();
    if (VTy && VTy->getElementType()->isBooleanType() &&
        CheckVectorSelect(&SemaRef, TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_saturate:
  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
    // One argument with a floating representation (any fp width).
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;
    if (!TheCall->getArg(0)
             ->getType()
             ->hasFloatingRepresentation()) // half or float or double
      return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
                          diag::err_builtin_invalid_arg_type)
             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
             << /* fp */ 1 << TheCall->getArg(0)->getType();
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_degrees:
  case Builtin::BI__builtin_hlsl_elementwise_radians:
  case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
  case Builtin::BI__builtin_hlsl_elementwise_frac: {
    // One half/float (scalar or vector) argument; double is rejected.
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_isinf: {
    // One half/float argument; the result is bool with the argument's
    // scalar/vector shape.
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_lerp: {
    // Three half/float arguments of identical type.
    if (SemaRef.checkArgCount(TheCall, 3))
      return true;
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatOrHalfRepresentation))
      return true;
    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
      return true;
    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_mad: {
    // Generic elementwise ternary checking, without argument-type
    // restriction.
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    // One half/float argument; the result has the argument's type.
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatOrHalfRepresentation))
      return true;
    ExprResult A = TheCall->getArg(0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    // Integer or floating argument; the result is int with the argument's
    // scalar/vector shape.
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatingOrIntRepresentation))
      return true;
    SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_step: {
    // Two half/float arguments; the result takes argument 0's type.
    // NOTE(review): the two arguments are not checked for identical types.
    if (SemaRef.checkArgCount(TheCall, 2))
      return true;
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatOrHalfRepresentation))
      return true;

    ExprResult A = TheCall->getArg(0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_active_max:
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    // One non-bool scalar/vector argument; the result has the argument's
    // type.
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
      return true;
    if (CheckWaveActive(&SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  // Note these are llvm builtins that we want to catch invalid intrinsic
  // generation. Normal handling of these builitns will occur elsewhere.
  case Builtin::BI__builtin_elementwise_bitreverse: {
    // does not include a check for number of arguments
    // because that is done previously
    // HLSL restricts the operand to unsigned integer (scalar or vector).
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckUnsignedIntRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
    // (value, index): the index must be an integer; the value must be a
    // scalar/vector and determines the result type.
    if (SemaRef.checkArgCount(TheCall, 2))
      return true;

    // Ensure index parameter type can be interpreted as a uint
    ExprResult Index = TheCall->getArg(1);
    QualType ArgTyIndex = Index.get()->getType();
    if (!ArgTyIndex->isIntegerType()) {
      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
                   diag::err_typecheck_convert_incompatible)
          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
      return true;
    }

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
      return true;

    ExprResult Expr = TheCall->getArg(0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
    // Takes no arguments.
    if (SemaRef.checkArgCount(TheCall, 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
    // (double, uint, uint) scalar/vector arguments. Arguments 1 and 2 must
    // also be modifiable lvalues.
    if (SemaRef.checkArgCount(TheCall, 3))
      return true;

    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
        CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
                            1) ||
        CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
                            2))
      return true;

    if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
        CheckModifiableLValue(&SemaRef, TheCall, 2))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clip: {
    // One float scalar/vector argument.
    if (SemaRef.checkArgCount(TheCall, 1))
      return true;

    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0))
      return true;
    break;
  }
  // Generic elementwise math builtins: HLSL additionally restricts all
  // arguments to half/float (scalar or vector). Argument count is not checked
  // here.
  case Builtin::BI__builtin_elementwise_acos:
  case Builtin::BI__builtin_elementwise_asin:
  case Builtin::BI__builtin_elementwise_atan:
  case Builtin::BI__builtin_elementwise_atan2:
  case Builtin::BI__builtin_elementwise_ceil:
  case Builtin::BI__builtin_elementwise_cos:
  case Builtin::BI__builtin_elementwise_cosh:
  case Builtin::BI__builtin_elementwise_exp:
  case Builtin::BI__builtin_elementwise_exp2:
  case Builtin::BI__builtin_elementwise_exp10:
  case Builtin::BI__builtin_elementwise_floor:
  case Builtin::BI__builtin_elementwise_fmod:
  case Builtin::BI__builtin_elementwise_log:
  case Builtin::BI__builtin_elementwise_log2:
  case Builtin::BI__builtin_elementwise_log10:
  case Builtin::BI__builtin_elementwise_pow:
  case Builtin::BI__builtin_elementwise_roundeven:
  case Builtin::BI__builtin_elementwise_sin:
  case Builtin::BI__builtin_elementwise_sinh:
  case Builtin::BI__builtin_elementwise_sqrt:
  case Builtin::BI__builtin_elementwise_tan:
  case Builtin::BI__builtin_elementwise_tanh:
  case Builtin::BI__builtin_elementwise_trunc: {
    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                   CheckFloatOrHalfRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
    // The handle must be a UAV raw buffer with a contained type. The offset
    // must be an int integer-constant-expression equal to +1 or -1.
    auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
      return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
               ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
    };
    if (SemaRef.checkArgCount(TheCall, 2) ||
        CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy) ||
        CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
                            SemaRef.getASTContext().IntTy))
      return true;
    Expr *OffsetExpr = TheCall->getArg(1);
    std::optional<llvm::APSInt> Offset =
        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
    if (!Offset.has_value() || std::abs(Offset->getExtValue()) != 1) {
      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
          << 1;
      return true;
    }
    break;
  }
  }
  return false;
}
3128
BuildFlattenedTypeList(QualType BaseTy,llvm::SmallVectorImpl<QualType> & List)3129 static void BuildFlattenedTypeList(QualType BaseTy,
3130 llvm::SmallVectorImpl<QualType> &List) {
3131 llvm::SmallVector<QualType, 16> WorkList;
3132 WorkList.push_back(BaseTy);
3133 while (!WorkList.empty()) {
3134 QualType T = WorkList.pop_back_val();
3135 T = T.getCanonicalType().getUnqualifiedType();
3136 assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
3137 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
3138 llvm::SmallVector<QualType, 16> ElementFields;
3139 // Generally I've avoided recursion in this algorithm, but arrays of
3140 // structs could be time-consuming to flatten and churn through on the
3141 // work list. Hopefully nesting arrays of structs containing arrays
3142 // of structs too many levels deep is unlikely.
3143 BuildFlattenedTypeList(AT->getElementType(), ElementFields);
3144 // Repeat the element's field list n times.
3145 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
3146 llvm::append_range(List, ElementFields);
3147 continue;
3148 }
3149 // Vectors can only have element types that are builtin types, so this can
3150 // add directly to the list instead of to the WorkList.
3151 if (const auto *VT = dyn_cast<VectorType>(T)) {
3152 List.insert(List.end(), VT->getNumElements(), VT->getElementType());
3153 continue;
3154 }
3155 if (const auto *RT = dyn_cast<RecordType>(T)) {
3156 const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
3157 assert(RD && "HLSL record types should all be CXXRecordDecls!");
3158
3159 if (RD->isStandardLayout())
3160 RD = RD->getStandardLayoutBaseWithFields();
3161
3162 // For types that we shouldn't decompose (unions and non-aggregates), just
3163 // add the type itself to the list.
3164 if (RD->isUnion() || !RD->isAggregate()) {
3165 List.push_back(T);
3166 continue;
3167 }
3168
3169 llvm::SmallVector<QualType, 16> FieldTypes;
3170 for (const auto *FD : RD->fields())
3171 FieldTypes.push_back(FD->getType());
3172 // Reverse the newly added sub-range.
3173 std::reverse(FieldTypes.begin(), FieldTypes.end());
3174 llvm::append_range(WorkList, FieldTypes);
3175
3176 // If this wasn't a standard layout type we may also have some base
3177 // classes to deal with.
3178 if (!RD->isStandardLayout()) {
3179 FieldTypes.clear();
3180 for (const auto &Base : RD->bases())
3181 FieldTypes.push_back(Base.getType());
3182 std::reverse(FieldTypes.begin(), FieldTypes.end());
3183 llvm::append_range(WorkList, FieldTypes);
3184 }
3185 continue;
3186 }
3187 List.push_back(T);
3188 }
3189 }
3190
IsTypedResourceElementCompatible(clang::QualType QT)3191 bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
3192 // null and array types are not allowed.
3193 if (QT.isNull() || QT->isArrayType())
3194 return false;
3195
3196 // UDT types are not allowed
3197 if (QT->isRecordType())
3198 return false;
3199
3200 if (QT->isBooleanType() || QT->isEnumeralType())
3201 return false;
3202
3203 // the only other valid builtin types are scalars or vectors
3204 if (QT->isArithmeticType()) {
3205 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
3206 return false;
3207 return true;
3208 }
3209
3210 if (const VectorType *VT = QT->getAs<VectorType>()) {
3211 int ArraySize = VT->getNumElements();
3212
3213 if (ArraySize > 4)
3214 return false;
3215
3216 QualType ElTy = VT->getElementType();
3217 if (ElTy->isBooleanType())
3218 return false;
3219
3220 if (SemaRef.Context.getTypeSize(QT) / 8 > 16)
3221 return false;
3222 return true;
3223 }
3224
3225 return false;
3226 }
3227
IsScalarizedLayoutCompatible(QualType T1,QualType T2) const3228 bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
3229 if (T1.isNull() || T2.isNull())
3230 return false;
3231
3232 T1 = T1.getCanonicalType().getUnqualifiedType();
3233 T2 = T2.getCanonicalType().getUnqualifiedType();
3234
3235 // If both types are the same canonical type, they're obviously compatible.
3236 if (SemaRef.getASTContext().hasSameType(T1, T2))
3237 return true;
3238
3239 llvm::SmallVector<QualType, 16> T1Types;
3240 BuildFlattenedTypeList(T1, T1Types);
3241 llvm::SmallVector<QualType, 16> T2Types;
3242 BuildFlattenedTypeList(T2, T2Types);
3243
3244 // Check the flattened type list
3245 return llvm::equal(T1Types, T2Types,
3246 [this](QualType LHS, QualType RHS) -> bool {
3247 return SemaRef.IsLayoutCompatible(LHS, RHS);
3248 });
3249 }
3250
CheckCompatibleParameterABI(FunctionDecl * New,FunctionDecl * Old)3251 bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
3252 FunctionDecl *Old) {
3253 if (New->getNumParams() != Old->getNumParams())
3254 return true;
3255
3256 bool HadError = false;
3257
3258 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
3259 ParmVarDecl *NewParam = New->getParamDecl(i);
3260 ParmVarDecl *OldParam = Old->getParamDecl(i);
3261
3262 // HLSL parameter declarations for inout and out must match between
3263 // declarations. In HLSL inout and out are ambiguous at the call site,
3264 // but have different calling behavior, so you cannot overload a
3265 // method based on a difference between inout and out annotations.
3266 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
3267 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
3268 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
3269 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
3270
3271 if (NSpellingIdx != OSpellingIdx) {
3272 SemaRef.Diag(NewParam->getLocation(),
3273 diag::err_hlsl_param_qualifier_mismatch)
3274 << NDAttr << NewParam;
3275 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as)
3276 << ODAttr;
3277 HadError = true;
3278 }
3279 }
3280 return HadError;
3281 }
3282
3283 // Generally follows PerformScalarCast, with cases reordered for
3284 // clarity of what types are supported
CanPerformScalarCast(QualType SrcTy,QualType DestTy)3285 bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
3286
3287 if (!SrcTy->isScalarType() || !DestTy->isScalarType())
3288 return false;
3289
3290 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
3291 return true;
3292
3293 switch (SrcTy->getScalarTypeKind()) {
3294 case Type::STK_Bool: // casting from bool is like casting from an integer
3295 case Type::STK_Integral:
3296 switch (DestTy->getScalarTypeKind()) {
3297 case Type::STK_Bool:
3298 case Type::STK_Integral:
3299 case Type::STK_Floating:
3300 return true;
3301 case Type::STK_CPointer:
3302 case Type::STK_ObjCObjectPointer:
3303 case Type::STK_BlockPointer:
3304 case Type::STK_MemberPointer:
3305 llvm_unreachable("HLSL doesn't support pointers.");
3306 case Type::STK_IntegralComplex:
3307 case Type::STK_FloatingComplex:
3308 llvm_unreachable("HLSL doesn't support complex types.");
3309 case Type::STK_FixedPoint:
3310 llvm_unreachable("HLSL doesn't support fixed point types.");
3311 }
3312 llvm_unreachable("Should have returned before this");
3313
3314 case Type::STK_Floating:
3315 switch (DestTy->getScalarTypeKind()) {
3316 case Type::STK_Floating:
3317 case Type::STK_Bool:
3318 case Type::STK_Integral:
3319 return true;
3320 case Type::STK_FloatingComplex:
3321 case Type::STK_IntegralComplex:
3322 llvm_unreachable("HLSL doesn't support complex types.");
3323 case Type::STK_FixedPoint:
3324 llvm_unreachable("HLSL doesn't support fixed point types.");
3325 case Type::STK_CPointer:
3326 case Type::STK_ObjCObjectPointer:
3327 case Type::STK_BlockPointer:
3328 case Type::STK_MemberPointer:
3329 llvm_unreachable("HLSL doesn't support pointers.");
3330 }
3331 llvm_unreachable("Should have returned before this");
3332
3333 case Type::STK_MemberPointer:
3334 case Type::STK_CPointer:
3335 case Type::STK_BlockPointer:
3336 case Type::STK_ObjCObjectPointer:
3337 llvm_unreachable("HLSL doesn't support pointers.");
3338
3339 case Type::STK_FixedPoint:
3340 llvm_unreachable("HLSL doesn't support fixed point types.");
3341
3342 case Type::STK_FloatingComplex:
3343 case Type::STK_IntegralComplex:
3344 llvm_unreachable("HLSL doesn't support complex types.");
3345 }
3346
3347 llvm_unreachable("Unhandled scalar cast");
3348 }
3349
3350 // Detect if a type contains a bitfield. Will be removed when
3351 // bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
ContainsBitField(QualType BaseTy)3352 bool SemaHLSL::ContainsBitField(QualType BaseTy) {
3353 llvm::SmallVector<QualType, 16> WorkList;
3354 WorkList.push_back(BaseTy);
3355 while (!WorkList.empty()) {
3356 QualType T = WorkList.pop_back_val();
3357 T = T.getCanonicalType().getUnqualifiedType();
3358 // only check aggregate types
3359 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
3360 WorkList.push_back(AT->getElementType());
3361 continue;
3362 }
3363 if (const auto *RT = dyn_cast<RecordType>(T)) {
3364 const RecordDecl *RD = RT->getDecl();
3365 if (RD->isUnion())
3366 continue;
3367
3368 const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(RD);
3369
3370 if (CXXD && CXXD->isStandardLayout())
3371 RD = CXXD->getStandardLayoutBaseWithFields();
3372
3373 for (const auto *FD : RD->fields()) {
3374 if (FD->isBitField())
3375 return true;
3376 WorkList.push_back(FD->getType());
3377 }
3378 continue;
3379 }
3380 }
3381 return false;
3382 }
3383
3384 // Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
3385 // Src is a scalar or a vector of length 1
3386 // Or if Dest is a vector and Src is a vector of length 1
CanPerformAggregateSplatCast(Expr * Src,QualType DestTy)3387 bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
3388
3389 QualType SrcTy = Src->getType();
3390 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
3391 // going to be a vector splat from a scalar.
3392 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
3393 DestTy->isScalarType())
3394 return false;
3395
3396 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
3397
3398 // Src isn't a scalar or a vector of length 1
3399 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
3400 return false;
3401
3402 if (SrcVecTy)
3403 SrcTy = SrcVecTy->getElementType();
3404
3405 if (ContainsBitField(DestTy))
3406 return false;
3407
3408 llvm::SmallVector<QualType> DestTypes;
3409 BuildFlattenedTypeList(DestTy, DestTypes);
3410
3411 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
3412 if (DestTypes[I]->isUnionType())
3413 return false;
3414 if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
3415 return false;
3416 }
3417 return true;
3418 }
3419
3420 // Can we perform an HLSL Elementwise cast?
3421 // TODO: update this code when matrices are added; see issue #88060
CanPerformElementwiseCast(Expr * Src,QualType DestTy)3422 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
3423
3424 // Don't handle casts where LHS and RHS are any combination of scalar/vector
3425 // There must be an aggregate somewhere
3426 QualType SrcTy = Src->getType();
3427 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
3428 return false;
3429
3430 if (SrcTy->isVectorType() &&
3431 (DestTy->isScalarType() || DestTy->isVectorType()))
3432 return false;
3433
3434 if (ContainsBitField(DestTy) || ContainsBitField(SrcTy))
3435 return false;
3436
3437 llvm::SmallVector<QualType> DestTypes;
3438 BuildFlattenedTypeList(DestTy, DestTypes);
3439 llvm::SmallVector<QualType> SrcTypes;
3440 BuildFlattenedTypeList(SrcTy, SrcTypes);
3441
3442 // Usually the size of SrcTypes must be greater than or equal to the size of
3443 // DestTypes.
3444 if (SrcTypes.size() < DestTypes.size())
3445 return false;
3446
3447 unsigned SrcSize = SrcTypes.size();
3448 unsigned DstSize = DestTypes.size();
3449 unsigned I;
3450 for (I = 0; I < DstSize && I < SrcSize; I++) {
3451 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
3452 return false;
3453 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
3454 return false;
3455 }
3456 }
3457
3458 // check the rest of the source type for unions.
3459 for (; I < SrcSize; I++) {
3460 if (SrcTypes[I]->isUnionType())
3461 return false;
3462 }
3463 return true;
3464 }
3465
/// Builds the call-argument expression for a parameter that carries an
/// HLSLParamModifierAttr.
///
/// If the modifier uses the ordinary ABI, \p Arg is returned unchanged.
/// Otherwise \p Arg must be an lvalue, and it must not change scalar-ness when
/// converted to the parameter type. The result is an HLSLOutArgExpr that
/// bundles:
///  - an OpaqueValueExpr wrapping the original argument,
///  - an OpaqueValueExpr wrapping the value copy-initialized from it, and
///  - the `=` writeback expression assigning the second back to the first.
///
/// \returns the expression to pass, or ExprError() on failure. The lvalue and
/// scalar-extension failures emit their diagnostics here.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  // Modifiers with the ordinary ABI need no copy-in/copy-out.
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  // Any other ABI here is either inout or out; IsInOut selects the wording of
  // the diagnostics below.
  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  if (!Arg->isLValue()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  // The parameter's type with any reference stripped.
  QualType Ty = Param->getType().getNonLValueExprType(Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Wrap the caller's lvalue so the same node can be used both as the source
  // of the copy-in and as the destination of the writeback.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Ctx, Ty, false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  // NOTE(review): Ty already had getNonLValueExprType applied above, so this
  // looks like a no-op; confirm before removing.
  Ty = Ty.getNonLValueExprType(Ctx);
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(),
                           tok::equal, ArgOpV, OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut);

  return ExprResult(OutExpr);
}
3523
getInoutParameterType(QualType Ty)3524 QualType SemaHLSL::getInoutParameterType(QualType Ty) {
3525 // If HLSL gains support for references, all the cites that use this will need
3526 // to be updated with semantic checking to produce errors for
3527 // pointers/references.
3528 assert(!Ty->isReferenceType() &&
3529 "Pointer and reference types cannot be inout or out parameters");
3530 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty);
3531 Ty.addRestrict();
3532 return Ty;
3533 }
3534
IsDefaultBufferConstantDecl(VarDecl * VD)3535 static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
3536 QualType QT = VD->getType();
3537 return VD->getDeclContext()->isTranslationUnit() &&
3538 QT.getAddressSpace() == LangAS::Default &&
3539 VD->getStorageClass() != SC_Static &&
3540 !VD->hasAttr<HLSLVkConstantIdAttr>() &&
3541 !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
3542 }
3543
deduceAddressSpace(VarDecl * Decl)3544 void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
3545 // The variable already has an address space (groupshared for ex).
3546 if (Decl->getType().hasAddressSpace())
3547 return;
3548
3549 if (Decl->getType()->isDependentType())
3550 return;
3551
3552 QualType Type = Decl->getType();
3553
3554 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
3555 LangAS ImplAS = LangAS::hlsl_input;
3556 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
3557 Decl->setType(Type);
3558 return;
3559 }
3560
3561 if (Type->isSamplerT() || Type->isVoidType())
3562 return;
3563
3564 // Resource handles.
3565 if (isResourceRecordTypeOrArrayOf(Type->getUnqualifiedDesugaredType()))
3566 return;
3567
3568 // Only static globals belong to the Private address space.
3569 // Non-static globals belongs to the cbuffer.
3570 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
3571 return;
3572
3573 LangAS ImplAS = LangAS::hlsl_private;
3574 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS);
3575 Decl->setType(Type);
3576 }
3577
/// HLSL-specific processing of a newly declared variable.
///
/// For variables with global storage this:
///  - requires a complete (base element) type, marking the decl invalid
///    otherwise;
///  - moves default-constant-buffer candidates into `hlsl_constant` and
///    records them in DefaultCBufferDecls;
///  - collects resource bindings for intangible types;
///  - makes resource variables (and arrays of them) and `vk::constant_id`
///    variables `static`;
///  - processes explicit register bindings.
/// Every path finishes by calling deduceAddressSpace().
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            VD->getLocation(),
            SemaRef.getASTContext().getBaseElementType(VD->getType()),
            diag::err_typecheck_decl_incomplete_type)) {
      VD->setInvalidDecl();
      // Still deduce an address space for the invalid decl before bailing.
      deduceAddressSpace(VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          VD->getType(), LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(VD);

    // Strip array levels so arrays of resources are treated like resources.
    const Type *VarType = VD->getType().getTypePtr();
    while (VarType->isArrayType())
      VarType = VarType->getArrayElementTypeNoTypeQual();
    if (VarType->isHLSLResourceRecord() ||
        VD->hasAttr<HLSLVkConstantIdAttr>()) {
      // Make the variable for resources static. The global externally visible
      // storage is accessed through the handle, which is a member. The variable
      // itself is not externally visible.
      // vk::constant_id variables are made static here as well. Unless they
      // already carry an address space, the deduceAddressSpace() call below
      // then places them in hlsl_private.
      VD->setStorageClass(StorageClass::SC_Static);
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(VD);
  }

  deduceAddressSpace(VD);
}
3622
initVarDeclWithCtor(Sema & S,VarDecl * VD,MutableArrayRef<Expr * > Args)3623 static bool initVarDeclWithCtor(Sema &S, VarDecl *VD,
3624 MutableArrayRef<Expr *> Args) {
3625 InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
3626 InitializationKind Kind = InitializationKind::CreateDirect(
3627 VD->getLocation(), SourceLocation(), SourceLocation());
3628
3629 InitializationSequence InitSeq(S, Entity, Kind, Args);
3630 if (InitSeq.Failed())
3631 return false;
3632
3633 ExprResult Init = InitSeq.Perform(S, Entity, Kind, Args);
3634 if (!Init.get())
3635 return false;
3636
3637 VD->setInit(S.MaybeCreateExprWithCleanups(Init.get()));
3638 VD->setInitStyle(VarDecl::CallInit);
3639 S.CheckCompleteVariableDeclaration(VD);
3640 return true;
3641 }
3642
initGlobalResourceDecl(VarDecl * VD)3643 bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
3644 std::optional<uint32_t> RegisterSlot;
3645 uint32_t SpaceNo = 0;
3646 HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
3647 if (RBA) {
3648 if (RBA->hasRegisterSlot())
3649 RegisterSlot = RBA->getSlotNumber();
3650 SpaceNo = RBA->getSpaceNumber();
3651 }
3652
3653 ASTContext &AST = SemaRef.getASTContext();
3654 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy);
3655 uint64_t IntTySize = AST.getTypeSize(AST.IntTy);
3656 IntegerLiteral *RangeSize = IntegerLiteral::Create(
3657 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation());
3658 IntegerLiteral *Index = IntegerLiteral::Create(
3659 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation());
3660 IntegerLiteral *Space =
3661 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, SpaceNo),
3662 AST.UnsignedIntTy, SourceLocation());
3663 StringRef VarName = VD->getName();
3664 StringLiteral *Name = StringLiteral::Create(
3665 AST, VarName, StringLiteralKind::Ordinary, false,
3666 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()),
3667 SourceLocation());
3668
3669 // resource with explicit binding
3670 if (RegisterSlot.has_value()) {
3671 IntegerLiteral *RegSlot = IntegerLiteral::Create(
3672 AST, llvm::APInt(UIntTySize, RegisterSlot.value()), AST.UnsignedIntTy,
3673 SourceLocation());
3674 Expr *Args[] = {RegSlot, Space, RangeSize, Index, Name};
3675 return initVarDeclWithCtor(SemaRef, VD, Args);
3676 }
3677
3678 // resource with implicit binding
3679 IntegerLiteral *OrderId = IntegerLiteral::Create(
3680 AST, llvm::APInt(UIntTySize, getNextImplicitBindingOrderID()),
3681 AST.UnsignedIntTy, SourceLocation());
3682 Expr *Args[] = {Space, RangeSize, Index, OrderId, Name};
3683 return initVarDeclWithCtor(SemaRef, VD, Args);
3684 }
3685
3686 // Returns true if the initialization has been handled.
3687 // Returns false to use default initialization.
ActOnUninitializedVarDecl(VarDecl * VD)3688 bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
3689 // Objects in the hlsl_constant address space are initialized
3690 // externally, so don't synthesize an implicit initializer.
3691 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
3692 return true;
3693
3694 // Initialize resources
3695 if (!isResourceRecordTypeOrArrayOf(VD))
3696 return false;
3697
3698 // FIXME: We currectly support only simple resources - no arrays of resources
3699 // or resources in user defined structs.
3700 // (llvm/llvm-project#133835, llvm/llvm-project#133837)
3701 // Initialize resources at the global scope
3702 if (VD->hasGlobalStorage() && VD->getType()->isHLSLResourceRecord())
3703 return initGlobalResourceDecl(VD);
3704
3705 return false;
3706 }
3707
3708 // Walks though the global variable declaration, collects all resource binding
3709 // requirements and adds them to Bindings
collectResourceBindingsOnVarDecl(VarDecl * VD)3710 void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
3711 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
3712 "expected global variable that contains HLSL resource");
3713
3714 // Cbuffers and Tbuffers are HLSLBufferDecl types
3715 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) {
3716 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer()
3717 ? ResourceClass::CBuffer
3718 : ResourceClass::SRV);
3719 return;
3720 }
3721
3722 // Unwrap arrays
3723 // FIXME: Calculate array size while unwrapping
3724 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
3725 while (Ty->isConstantArrayType()) {
3726 const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
3727 Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
3728 }
3729
3730 // Resource (or array of resources)
3731 if (const HLSLAttributedResourceType *AttrResType =
3732 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) {
3733 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass);
3734 return;
3735 }
3736
3737 // User defined record type
3738 if (const RecordType *RT = dyn_cast<RecordType>(Ty))
3739 collectResourceBindingsOnUserRecordDecl(VD, RT);
3740 }
3741
3742 // Walks though the explicit resource binding attributes on the declaration,
3743 // and makes sure there is a resource that matched the binding and updates
3744 // DeclBindingInfoLists
processExplicitBindingsOnDecl(VarDecl * VD)3745 void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
3746 assert(VD->hasGlobalStorage() && "expected global variable");
3747
3748 bool HasBinding = false;
3749 for (Attr *A : VD->attrs()) {
3750 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
3751 if (!RBA || !RBA->hasRegisterSlot())
3752 continue;
3753 HasBinding = true;
3754
3755 RegisterType RT = RBA->getRegisterType();
3756 assert(RT != RegisterType::I && "invalid or obsolete register type should "
3757 "never have an attribute created");
3758
3759 if (RT == RegisterType::C) {
3760 if (Bindings.hasBindingInfoForDecl(VD))
3761 SemaRef.Diag(VD->getLocation(),
3762 diag::warn_hlsl_user_defined_type_missing_member)
3763 << static_cast<int>(RT);
3764 continue;
3765 }
3766
3767 // Find DeclBindingInfo for this binding and update it, or report error
3768 // if it does not exist (user type does to contain resources with the
3769 // expected resource class).
3770 ResourceClass RC = getResourceClass(RT);
3771 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) {
3772 // update binding info
3773 BI->setBindingAttribute(RBA, BindingType::Explicit);
3774 } else {
3775 SemaRef.Diag(VD->getLocation(),
3776 diag::warn_hlsl_user_defined_type_missing_member)
3777 << static_cast<int>(RT);
3778 }
3779 }
3780
3781 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
3782 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding);
3783 }
3784 namespace {
3785 class InitListTransformer {
3786 Sema &S;
3787 ASTContext &Ctx;
3788 QualType InitTy;
3789 QualType *DstIt = nullptr;
3790 Expr **ArgIt = nullptr;
3791 // Is wrapping the destination type iterator required? This is only used for
3792 // incomplete array types where we loop over the destination type since we
3793 // don't know the full number of elements from the declaration.
3794 bool Wrap;
3795
castInitializer(Expr * E)3796 bool castInitializer(Expr *E) {
3797 assert(DstIt && "This should always be something!");
3798 if (DstIt == DestTypes.end()) {
3799 if (!Wrap) {
3800 ArgExprs.push_back(E);
3801 // This is odd, but it isn't technically a failure due to conversion, we
3802 // handle mismatched counts of arguments differently.
3803 return true;
3804 }
3805 DstIt = DestTypes.begin();
3806 }
3807 InitializedEntity Entity = InitializedEntity::InitializeParameter(
3808 Ctx, *DstIt, /* Consumed (ObjC) */ false);
3809 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E);
3810 if (Res.isInvalid())
3811 return false;
3812 Expr *Init = Res.get();
3813 ArgExprs.push_back(Init);
3814 DstIt++;
3815 return true;
3816 }
3817
buildInitializerListImpl(Expr * E)3818 bool buildInitializerListImpl(Expr *E) {
3819 // If this is an initialization list, traverse the sub initializers.
3820 if (auto *Init = dyn_cast<InitListExpr>(E)) {
3821 for (auto *SubInit : Init->inits())
3822 if (!buildInitializerListImpl(SubInit))
3823 return false;
3824 return true;
3825 }
3826
3827 // If this is a scalar type, just enqueue the expression.
3828 QualType Ty = E->getType();
3829
3830 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
3831 return castInitializer(E);
3832
3833 if (auto *VecTy = Ty->getAs<VectorType>()) {
3834 uint64_t Size = VecTy->getNumElements();
3835
3836 QualType SizeTy = Ctx.getSizeType();
3837 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
3838 for (uint64_t I = 0; I < Size; ++I) {
3839 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
3840 SizeTy, SourceLocation());
3841
3842 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
3843 E, E->getBeginLoc(), Idx, E->getEndLoc());
3844 if (ElExpr.isInvalid())
3845 return false;
3846 if (!castInitializer(ElExpr.get()))
3847 return false;
3848 }
3849 return true;
3850 }
3851
3852 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) {
3853 uint64_t Size = ArrTy->getZExtSize();
3854 QualType SizeTy = Ctx.getSizeType();
3855 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy);
3856 for (uint64_t I = 0; I < Size; ++I) {
3857 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I),
3858 SizeTy, SourceLocation());
3859 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
3860 E, E->getBeginLoc(), Idx, E->getEndLoc());
3861 if (ElExpr.isInvalid())
3862 return false;
3863 if (!buildInitializerListImpl(ElExpr.get()))
3864 return false;
3865 }
3866 return true;
3867 }
3868
3869 if (auto *RTy = Ty->getAs<RecordType>()) {
3870 llvm::SmallVector<const RecordType *> RecordTypes;
3871 RecordTypes.push_back(RTy);
3872 while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
3873 CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
3874 assert(D->getNumBases() == 1 &&
3875 "HLSL doesn't support multiple inheritance");
3876 RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
3877 }
3878 while (!RecordTypes.empty()) {
3879 const RecordType *RT = RecordTypes.pop_back_val();
3880 for (auto *FD : RT->getDecl()->fields()) {
3881 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
3882 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
3883 ExprResult Res = S.BuildFieldReferenceExpr(
3884 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo);
3885 if (Res.isInvalid())
3886 return false;
3887 if (!buildInitializerListImpl(Res.get()))
3888 return false;
3889 }
3890 }
3891 }
3892 return true;
3893 }
3894
generateInitListsImpl(QualType Ty)3895 Expr *generateInitListsImpl(QualType Ty) {
3896 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
3897 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
3898 return *(ArgIt++);
3899
3900 llvm::SmallVector<Expr *> Inits;
3901 assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
3902 Ty = Ty.getDesugaredType(Ctx);
3903 if (Ty->isVectorType() || Ty->isConstantArrayType()) {
3904 QualType ElTy;
3905 uint64_t Size = 0;
3906 if (auto *ATy = Ty->getAs<VectorType>()) {
3907 ElTy = ATy->getElementType();
3908 Size = ATy->getNumElements();
3909 } else {
3910 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
3911 ElTy = VTy->getElementType();
3912 Size = VTy->getZExtSize();
3913 }
3914 for (uint64_t I = 0; I < Size; ++I)
3915 Inits.push_back(generateInitListsImpl(ElTy));
3916 }
3917 if (auto *RTy = Ty->getAs<RecordType>()) {
3918 llvm::SmallVector<const RecordType *> RecordTypes;
3919 RecordTypes.push_back(RTy);
3920 while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
3921 CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
3922 assert(D->getNumBases() == 1 &&
3923 "HLSL doesn't support multiple inheritance");
3924 RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
3925 }
3926 while (!RecordTypes.empty()) {
3927 const RecordType *RT = RecordTypes.pop_back_val();
3928 for (auto *FD : RT->getDecl()->fields()) {
3929 Inits.push_back(generateInitListsImpl(FD->getType()));
3930 }
3931 }
3932 }
3933 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
3934 Inits, Inits.back()->getEndLoc());
3935 NewInit->setType(Ty);
3936 return NewInit;
3937 }
3938
3939 public:
3940 llvm::SmallVector<QualType, 16> DestTypes;
3941 llvm::SmallVector<Expr *, 16> ArgExprs;
InitListTransformer(Sema & SemaRef,const InitializedEntity & Entity)3942 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
3943 : S(SemaRef), Ctx(SemaRef.getASTContext()),
3944 Wrap(Entity.getType()->isIncompleteArrayType()) {
3945 InitTy = Entity.getType().getNonReferenceType();
3946 // When we're generating initializer lists for incomplete array types we
3947 // need to wrap around both when building the initializers and when
3948 // generating the final initializer lists.
3949 if (Wrap) {
3950 assert(InitTy->isIncompleteArrayType());
3951 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy);
3952 InitTy = IAT->getElementType();
3953 }
3954 BuildFlattenedTypeList(InitTy, DestTypes);
3955 DstIt = DestTypes.begin();
3956 }
3957
buildInitializerList(Expr * E)3958 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }
3959
generateInitLists()3960 Expr *generateInitLists() {
3961 assert(!ArgExprs.empty() &&
3962 "Call buildInitializerList to generate argument expressions.");
3963 ArgIt = ArgExprs.begin();
3964 if (!Wrap)
3965 return generateInitListsImpl(InitTy);
3966 llvm::SmallVector<Expr *> Inits;
3967 while (ArgIt != ArgExprs.end())
3968 Inits.push_back(generateInitListsImpl(InitTy));
3969
3970 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
3971 Inits, Inits.back()->getEndLoc());
3972 llvm::APInt ArySize(64, Inits.size());
3973 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr,
3974 ArraySizeModifier::Normal, 0));
3975 return NewInit;
3976 }
3977 };
3978 } // namespace
3979
transformInitList(const InitializedEntity & Entity,InitListExpr * Init)3980 bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
3981 InitListExpr *Init) {
3982 // If the initializer is a scalar, just return it.
3983 if (Init->getType()->isScalarType())
3984 return true;
3985 ASTContext &Ctx = SemaRef.getASTContext();
3986 InitListTransformer ILT(SemaRef, Entity);
3987
3988 for (unsigned I = 0; I < Init->getNumInits(); ++I) {
3989 Expr *E = Init->getInit(I);
3990 if (E->HasSideEffects(Ctx)) {
3991 QualType Ty = E->getType();
3992 if (Ty->isRecordType())
3993 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
3994 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
3995 E->getObjectKind(), E);
3996 Init->setInit(I, E);
3997 }
3998 if (!ILT.buildInitializerList(E))
3999 return false;
4000 }
4001 size_t ExpectedSize = ILT.DestTypes.size();
4002 size_t ActualSize = ILT.ArgExprs.size();
4003 // For incomplete arrays it is completely arbitrary to choose whether we think
4004 // the user intended fewer or more elements. This implementation assumes that
4005 // the user intended more, and errors that there are too few initializers to
4006 // complete the final element.
4007 if (Entity.getType()->isIncompleteArrayType())
4008 ExpectedSize =
4009 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;
4010
4011 // An initializer list might be attempting to initialize a reference or
4012 // rvalue-reference. When checking the initializer we should look through
4013 // the reference.
4014 QualType InitTy = Entity.getType().getNonReferenceType();
4015 if (InitTy.hasAddressSpace())
4016 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy);
4017 if (ExpectedSize != ActualSize) {
4018 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
4019 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
4020 << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
4021 return false;
4022 }
4023
4024 // generateInitListsImpl will always return an InitListExpr here, because the
4025 // scalar case is handled above.
4026 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists());
4027 Init->resizeInits(Ctx, NewInit->getNumInits());
4028 for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
4029 Init->updateInit(Ctx, I, NewInit->getInit(I));
4030 return true;
4031 }
4032
handleInitialization(VarDecl * VDecl,Expr * & Init)4033 bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
4034 const HLSLVkConstantIdAttr *ConstIdAttr =
4035 VDecl->getAttr<HLSLVkConstantIdAttr>();
4036 if (!ConstIdAttr)
4037 return true;
4038
4039 ASTContext &Context = SemaRef.getASTContext();
4040
4041 APValue InitValue;
4042 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
4043 Diag(VDecl->getLocation(), diag::err_specialization_const);
4044 VDecl->setInvalidDecl();
4045 return false;
4046 }
4047
4048 Builtin::ID BID =
4049 getSpecConstBuiltinId(VDecl->getType()->getUnqualifiedDesugaredType());
4050
4051 // Argument 1: The ID from the attribute
4052 int ConstantID = ConstIdAttr->getId();
4053 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
4054 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
4055 ConstIdAttr->getLocation());
4056
4057 SmallVector<Expr *, 2> Args = {IdExpr, Init};
4058 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
4059 if (C->getType()->getCanonicalTypeUnqualified() !=
4060 VDecl->getType()->getCanonicalTypeUnqualified()) {
4061 C = SemaRef
4062 .BuildCStyleCastExpr(SourceLocation(),
4063 Context.getTrivialTypeSourceInfo(
4064 Init->getType(), Init->getExprLoc()),
4065 SourceLocation(), C)
4066 .get();
4067 }
4068 Init = C;
4069 return true;
4070 }
4071