1 //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This implements Semantic Analysis for HLSL constructs. 9 //===----------------------------------------------------------------------===// 10 11 #include "clang/Sema/SemaHLSL.h" 12 #include "clang/AST/ASTConsumer.h" 13 #include "clang/AST/ASTContext.h" 14 #include "clang/AST/Attr.h" 15 #include "clang/AST/Attrs.inc" 16 #include "clang/AST/Decl.h" 17 #include "clang/AST/DeclBase.h" 18 #include "clang/AST/DeclCXX.h" 19 #include "clang/AST/DeclarationName.h" 20 #include "clang/AST/DynamicRecursiveASTVisitor.h" 21 #include "clang/AST/Expr.h" 22 #include "clang/AST/Type.h" 23 #include "clang/AST/TypeLoc.h" 24 #include "clang/Basic/Builtins.h" 25 #include "clang/Basic/DiagnosticSema.h" 26 #include "clang/Basic/IdentifierTable.h" 27 #include "clang/Basic/LLVM.h" 28 #include "clang/Basic/SourceLocation.h" 29 #include "clang/Basic/Specifiers.h" 30 #include "clang/Basic/TargetInfo.h" 31 #include "clang/Sema/Initialization.h" 32 #include "clang/Sema/Lookup.h" 33 #include "clang/Sema/ParsedAttr.h" 34 #include "clang/Sema/Sema.h" 35 #include "clang/Sema/Template.h" 36 #include "llvm/ADT/ArrayRef.h" 37 #include "llvm/ADT/STLExtras.h" 38 #include "llvm/ADT/SmallVector.h" 39 #include "llvm/ADT/StringExtras.h" 40 #include "llvm/ADT/StringRef.h" 41 #include "llvm/ADT/Twine.h" 42 #include "llvm/Frontend/HLSL/RootSignatureValidations.h" 43 #include "llvm/Support/Casting.h" 44 #include "llvm/Support/DXILABI.h" 45 #include "llvm/Support/ErrorHandling.h" 46 #include "llvm/Support/FormatVariadic.h" 47 #include "llvm/TargetParser/Triple.h" 48 #include <cmath> 49 #include <cstddef> 50 #include <iterator> 51 
#include <utility> 52 53 using namespace clang; 54 using RegisterType = HLSLResourceBindingAttr::RegisterType; 55 56 static CXXRecordDecl *createHostLayoutStruct(Sema &S, 57 CXXRecordDecl *StructDecl); 58 59 static RegisterType getRegisterType(ResourceClass RC) { 60 switch (RC) { 61 case ResourceClass::SRV: 62 return RegisterType::SRV; 63 case ResourceClass::UAV: 64 return RegisterType::UAV; 65 case ResourceClass::CBuffer: 66 return RegisterType::CBuffer; 67 case ResourceClass::Sampler: 68 return RegisterType::Sampler; 69 } 70 llvm_unreachable("unexpected ResourceClass value"); 71 } 72 73 // Converts the first letter of string Slot to RegisterType. 74 // Returns false if the letter does not correspond to a valid register type. 75 static bool convertToRegisterType(StringRef Slot, RegisterType *RT) { 76 assert(RT != nullptr); 77 switch (Slot[0]) { 78 case 't': 79 case 'T': 80 *RT = RegisterType::SRV; 81 return true; 82 case 'u': 83 case 'U': 84 *RT = RegisterType::UAV; 85 return true; 86 case 'b': 87 case 'B': 88 *RT = RegisterType::CBuffer; 89 return true; 90 case 's': 91 case 'S': 92 *RT = RegisterType::Sampler; 93 return true; 94 case 'c': 95 case 'C': 96 *RT = RegisterType::C; 97 return true; 98 case 'i': 99 case 'I': 100 *RT = RegisterType::I; 101 return true; 102 default: 103 return false; 104 } 105 } 106 107 static ResourceClass getResourceClass(RegisterType RT) { 108 switch (RT) { 109 case RegisterType::SRV: 110 return ResourceClass::SRV; 111 case RegisterType::UAV: 112 return ResourceClass::UAV; 113 case RegisterType::CBuffer: 114 return ResourceClass::CBuffer; 115 case RegisterType::Sampler: 116 return ResourceClass::Sampler; 117 case RegisterType::C: 118 case RegisterType::I: 119 // Deliberately falling through to the unreachable below. 
120 break; 121 } 122 llvm_unreachable("unexpected RegisterType value"); 123 } 124 125 static Builtin::ID getSpecConstBuiltinId(const Type *Type) { 126 const auto *BT = dyn_cast<BuiltinType>(Type); 127 if (!BT) { 128 if (!Type->isEnumeralType()) 129 return Builtin::NotBuiltin; 130 return Builtin::BI__builtin_get_spirv_spec_constant_int; 131 } 132 133 switch (BT->getKind()) { 134 case BuiltinType::Bool: 135 return Builtin::BI__builtin_get_spirv_spec_constant_bool; 136 case BuiltinType::Short: 137 return Builtin::BI__builtin_get_spirv_spec_constant_short; 138 case BuiltinType::Int: 139 return Builtin::BI__builtin_get_spirv_spec_constant_int; 140 case BuiltinType::LongLong: 141 return Builtin::BI__builtin_get_spirv_spec_constant_longlong; 142 case BuiltinType::UShort: 143 return Builtin::BI__builtin_get_spirv_spec_constant_ushort; 144 case BuiltinType::UInt: 145 return Builtin::BI__builtin_get_spirv_spec_constant_uint; 146 case BuiltinType::ULongLong: 147 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong; 148 case BuiltinType::Half: 149 return Builtin::BI__builtin_get_spirv_spec_constant_half; 150 case BuiltinType::Float: 151 return Builtin::BI__builtin_get_spirv_spec_constant_float; 152 case BuiltinType::Double: 153 return Builtin::BI__builtin_get_spirv_spec_constant_double; 154 default: 155 return Builtin::NotBuiltin; 156 } 157 } 158 159 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD, 160 ResourceClass ResClass) { 161 assert(getDeclBindingInfo(VD, ResClass) == nullptr && 162 "DeclBindingInfo already added"); 163 assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD); 164 // VarDecl may have multiple entries for different resource classes. 165 // DeclToBindingListIndex stores the index of the first binding we saw 166 // for this decl. If there are any additional ones then that index 167 // shouldn't be updated. 
168 DeclToBindingListIndex.try_emplace(VD, BindingsList.size()); 169 return &BindingsList.emplace_back(VD, ResClass); 170 } 171 172 DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD, 173 ResourceClass ResClass) { 174 auto Entry = DeclToBindingListIndex.find(VD); 175 if (Entry != DeclToBindingListIndex.end()) { 176 for (unsigned Index = Entry->getSecond(); 177 Index < BindingsList.size() && BindingsList[Index].Decl == VD; 178 ++Index) { 179 if (BindingsList[Index].ResClass == ResClass) 180 return &BindingsList[Index]; 181 } 182 } 183 return nullptr; 184 } 185 186 bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const { 187 return DeclToBindingListIndex.contains(VD); 188 } 189 190 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} 191 192 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, 193 SourceLocation KwLoc, IdentifierInfo *Ident, 194 SourceLocation IdentLoc, 195 SourceLocation LBrace) { 196 // For anonymous namespace, take the location of the left brace. 197 DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); 198 HLSLBufferDecl *Result = HLSLBufferDecl::Create( 199 getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace); 200 201 // if CBuffer is false, then it's a TBuffer 202 auto RC = CBuffer ? 
llvm::hlsl::ResourceClass::CBuffer 203 : llvm::hlsl::ResourceClass::SRV; 204 Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC)); 205 206 SemaRef.PushOnScopeChains(Result, BufferScope); 207 SemaRef.PushDeclContext(BufferScope, Result); 208 209 return Result; 210 } 211 212 static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, 213 QualType T) { 214 // Arrays and Structs are always aligned to new buffer rows 215 if (T->isArrayType() || T->isStructureType()) 216 return 16; 217 218 // Vectors are aligned to the type they contain 219 if (const VectorType *VT = T->getAs<VectorType>()) 220 return calculateLegacyCbufferFieldAlign(Context, VT->getElementType()); 221 222 assert(Context.getTypeSize(T) <= 64 && 223 "Scalar bit widths larger than 64 not supported"); 224 225 // Scalar types are aligned to their byte width 226 return Context.getTypeSize(T) / 8; 227 } 228 229 // Calculate the size of a legacy cbuffer type in bytes based on 230 // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules 231 static unsigned calculateLegacyCbufferSize(const ASTContext &Context, 232 QualType T) { 233 constexpr unsigned CBufferAlign = 16; 234 if (const RecordType *RT = T->getAs<RecordType>()) { 235 unsigned Size = 0; 236 const RecordDecl *RD = RT->getDecl(); 237 for (const FieldDecl *Field : RD->fields()) { 238 QualType Ty = Field->getType(); 239 unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty); 240 unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty); 241 242 // If the field crosses the row boundary after alignment it drops to the 243 // next row 244 unsigned AlignSize = llvm::alignTo(Size, FieldAlign); 245 if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) { 246 FieldAlign = CBufferAlign; 247 } 248 249 Size = llvm::alignTo(Size, FieldAlign); 250 Size += FieldSize; 251 } 252 return Size; 253 } 254 255 if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) { 
256 unsigned ElementCount = AT->getSize().getZExtValue(); 257 if (ElementCount == 0) 258 return 0; 259 260 unsigned ElementSize = 261 calculateLegacyCbufferSize(Context, AT->getElementType()); 262 unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign); 263 return AlignedElementSize * (ElementCount - 1) + ElementSize; 264 } 265 266 if (const VectorType *VT = T->getAs<VectorType>()) { 267 unsigned ElementCount = VT->getNumElements(); 268 unsigned ElementSize = 269 calculateLegacyCbufferSize(Context, VT->getElementType()); 270 return ElementSize * ElementCount; 271 } 272 273 return Context.getTypeSize(T) / 8; 274 } 275 276 // Validate packoffset: 277 // - if packoffset it used it must be set on all declarations inside the buffer 278 // - packoffset ranges must not overlap 279 static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) { 280 llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec; 281 282 // Make sure the packoffset annotations are either on all declarations 283 // or on none. 284 bool HasPackOffset = false; 285 bool HasNonPackOffset = false; 286 for (auto *Field : BufDecl->buffer_decls()) { 287 VarDecl *Var = dyn_cast<VarDecl>(Field); 288 if (!Var) 289 continue; 290 if (Field->hasAttr<HLSLPackOffsetAttr>()) { 291 PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>()); 292 HasPackOffset = true; 293 } else { 294 HasNonPackOffset = true; 295 } 296 } 297 298 if (!HasPackOffset) 299 return; 300 301 if (HasNonPackOffset) 302 S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix); 303 304 // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset 305 // and compare adjacent values. 
306 bool IsValid = true; 307 ASTContext &Context = S.getASTContext(); 308 std::sort(PackOffsetVec.begin(), PackOffsetVec.end(), 309 [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS, 310 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) { 311 return LHS.second->getOffsetInBytes() < 312 RHS.second->getOffsetInBytes(); 313 }); 314 for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) { 315 VarDecl *Var = PackOffsetVec[i].first; 316 HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second; 317 unsigned Size = calculateLegacyCbufferSize(Context, Var->getType()); 318 unsigned Begin = Attr->getOffsetInBytes(); 319 unsigned End = Begin + Size; 320 unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes(); 321 if (End > NextBegin) { 322 VarDecl *NextVar = PackOffsetVec[i + 1].first; 323 S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap) 324 << NextVar << Var; 325 IsValid = false; 326 } 327 } 328 BufDecl->setHasValidPackoffset(IsValid); 329 } 330 331 // Returns true if the array has a zero size = if any of the dimensions is 0 332 static bool isZeroSizedArray(const ConstantArrayType *CAT) { 333 while (CAT && !CAT->isZeroSize()) 334 CAT = dyn_cast<ConstantArrayType>( 335 CAT->getElementType()->getUnqualifiedDesugaredType()); 336 return CAT != nullptr; 337 } 338 339 // Returns true if the record type is an HLSL resource class or an array of 340 // resource classes 341 static bool isResourceRecordTypeOrArrayOf(const Type *Ty) { 342 while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty)) 343 Ty = CAT->getArrayElementTypeNoTypeQual(); 344 return HLSLAttributedResourceType::findHandleTypeOnResource(Ty) != nullptr; 345 } 346 347 static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) { 348 return isResourceRecordTypeOrArrayOf(VD->getType().getTypePtr()); 349 } 350 351 // Returns true if the type is a leaf element type that is not valid to be 352 // included in HLSL Buffer, such as a resource class, empty struct, zero-sized 353 // 
array, or a builtin intangible type. Returns false it is a valid leaf element 354 // type or if it is a record type that needs to be inspected further. 355 static bool isInvalidConstantBufferLeafElementType(const Type *Ty) { 356 Ty = Ty->getUnqualifiedDesugaredType(); 357 if (isResourceRecordTypeOrArrayOf(Ty)) 358 return true; 359 if (Ty->isRecordType()) 360 return Ty->getAsCXXRecordDecl()->isEmpty(); 361 if (Ty->isConstantArrayType() && 362 isZeroSizedArray(cast<ConstantArrayType>(Ty))) 363 return true; 364 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType()) 365 return true; 366 return false; 367 } 368 369 // Returns true if the struct contains at least one element that prevents it 370 // from being included inside HLSL Buffer as is, such as an intangible type, 371 // empty struct, or zero-sized array. If it does, a new implicit layout struct 372 // needs to be created for HLSL Buffer use that will exclude these unwanted 373 // declarations (see createHostLayoutStruct function). 
374 static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) { 375 if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty()) 376 return true; 377 // check fields 378 for (const FieldDecl *Field : RD->fields()) { 379 QualType Ty = Field->getType(); 380 if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr())) 381 return true; 382 if (Ty->isRecordType() && 383 requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl())) 384 return true; 385 } 386 // check bases 387 for (const CXXBaseSpecifier &Base : RD->bases()) 388 if (requiresImplicitBufferLayoutStructure( 389 Base.getType()->getAsCXXRecordDecl())) 390 return true; 391 return false; 392 } 393 394 static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II, 395 DeclContext *DC) { 396 CXXRecordDecl *RD = nullptr; 397 for (NamedDecl *Decl : 398 DC->getNonTransparentContext()->lookup(DeclarationName(II))) { 399 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Decl)) { 400 assert(RD == nullptr && 401 "there should be at most 1 record by a given name in a scope"); 402 RD = FoundRD; 403 } 404 } 405 return RD; 406 } 407 408 // Creates a name for buffer layout struct using the provide name base. 409 // If the name must be unique (not previously defined), a suffix is added 410 // until a unique name is found. 
411 static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl, 412 bool MustBeUnique) { 413 ASTContext &AST = S.getASTContext(); 414 415 IdentifierInfo *NameBaseII = BaseDecl->getIdentifier(); 416 llvm::SmallString<64> Name("__cblayout_"); 417 if (NameBaseII) { 418 Name.append(NameBaseII->getName()); 419 } else { 420 // anonymous struct 421 Name.append("anon"); 422 MustBeUnique = true; 423 } 424 425 size_t NameLength = Name.size(); 426 IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier); 427 if (!MustBeUnique) 428 return II; 429 430 unsigned suffix = 0; 431 while (true) { 432 if (suffix != 0) { 433 Name.append("_"); 434 Name.append(llvm::Twine(suffix).str()); 435 II = &AST.Idents.get(Name, tok::TokenKind::identifier); 436 } 437 if (!findRecordDeclInContext(II, BaseDecl->getDeclContext())) 438 return II; 439 // declaration with that name already exists - increment suffix and try 440 // again until unique name is found 441 suffix++; 442 Name.truncate(NameLength); 443 }; 444 } 445 446 // Creates a field declaration of given name and type for HLSL buffer layout 447 // struct. Returns nullptr if the type cannot be use in HLSL Buffer layout. 
448 static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty, 449 IdentifierInfo *II, 450 CXXRecordDecl *LayoutStruct) { 451 if (isInvalidConstantBufferLeafElementType(Ty)) 452 return nullptr; 453 454 if (Ty->isRecordType()) { 455 CXXRecordDecl *RD = Ty->getAsCXXRecordDecl(); 456 if (requiresImplicitBufferLayoutStructure(RD)) { 457 RD = createHostLayoutStruct(S, RD); 458 if (!RD) 459 return nullptr; 460 Ty = RD->getTypeForDecl(); 461 } 462 } 463 464 QualType QT = QualType(Ty, 0); 465 ASTContext &AST = S.getASTContext(); 466 TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation()); 467 auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(), 468 SourceLocation(), II, QT, TSI, nullptr, false, 469 InClassInitStyle::ICIS_NoInit); 470 Field->setAccess(AccessSpecifier::AS_public); 471 return Field; 472 } 473 474 // Creates host layout struct for a struct included in HLSL Buffer. 475 // The layout struct will include only fields that are allowed in HLSL buffer. 476 // These fields will be filtered out: 477 // - resource classes 478 // - empty structs 479 // - zero-sized arrays 480 // Returns nullptr if the resulting layout struct would be empty. 
481 static CXXRecordDecl *createHostLayoutStruct(Sema &S, 482 CXXRecordDecl *StructDecl) { 483 assert(requiresImplicitBufferLayoutStructure(StructDecl) && 484 "struct is already HLSL buffer compatible"); 485 486 ASTContext &AST = S.getASTContext(); 487 DeclContext *DC = StructDecl->getDeclContext(); 488 IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false); 489 490 // reuse existing if the layout struct if it already exists 491 if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC)) 492 return RD; 493 494 CXXRecordDecl *LS = 495 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(), 496 SourceLocation(), II); 497 LS->setImplicit(true); 498 LS->addAttr(PackedAttr::CreateImplicit(AST)); 499 LS->startDefinition(); 500 501 // copy base struct, create HLSL Buffer compatible version if needed 502 if (unsigned NumBases = StructDecl->getNumBases()) { 503 assert(NumBases == 1 && "HLSL supports only one base type"); 504 (void)NumBases; 505 CXXBaseSpecifier Base = *StructDecl->bases_begin(); 506 CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl(); 507 if (requiresImplicitBufferLayoutStructure(BaseDecl)) { 508 BaseDecl = createHostLayoutStruct(S, BaseDecl); 509 if (BaseDecl) { 510 TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo( 511 QualType(BaseDecl->getTypeForDecl(), 0)); 512 Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(), 513 AS_none, TSI, SourceLocation()); 514 } 515 } 516 if (BaseDecl) { 517 const CXXBaseSpecifier *BasesArray[1] = {&Base}; 518 LS->setBases(BasesArray, 1); 519 } 520 } 521 522 // filter struct fields 523 for (const FieldDecl *FD : StructDecl->fields()) { 524 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType(); 525 if (FieldDecl *NewFD = 526 createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS)) 527 LS->addDecl(NewFD); 528 } 529 LS->completeDefinition(); 530 531 if (LS->field_empty() && LS->getNumBases() == 0) 532 return nullptr; 533 534 DC->addDecl(LS); 535 return LS; 
536 } 537 538 // Creates host layout struct for HLSL Buffer. The struct will include only 539 // fields of types that are allowed in HLSL buffer and it will filter out: 540 // - static or groupshared variable declarations 541 // - resource classes 542 // - empty structs 543 // - zero-sized arrays 544 // - non-variable declarations 545 // The layout struct will be added to the HLSLBufferDecl declarations. 546 void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) { 547 ASTContext &AST = S.getASTContext(); 548 IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true); 549 550 CXXRecordDecl *LS = 551 CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl, 552 SourceLocation(), SourceLocation(), II); 553 LS->addAttr(PackedAttr::CreateImplicit(AST)); 554 LS->setImplicit(true); 555 LS->startDefinition(); 556 557 for (Decl *D : BufDecl->buffer_decls()) { 558 VarDecl *VD = dyn_cast<VarDecl>(D); 559 if (!VD || VD->getStorageClass() == SC_Static || 560 VD->getType().getAddressSpace() == LangAS::hlsl_groupshared) 561 continue; 562 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); 563 if (FieldDecl *FD = 564 createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) { 565 // add the field decl to the layout struct 566 LS->addDecl(FD); 567 // update address space of the original decl to hlsl_constant 568 QualType NewTy = 569 AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant); 570 VD->setType(NewTy); 571 } 572 } 573 LS->completeDefinition(); 574 BufDecl->addLayoutStruct(LS); 575 } 576 577 static void addImplicitBindingAttrToBuffer(Sema &S, HLSLBufferDecl *BufDecl, 578 uint32_t ImplicitBindingOrderID) { 579 RegisterType RT = 580 BufDecl->isCBuffer() ? 
RegisterType::CBuffer : RegisterType::SRV; 581 auto *Attr = 582 HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {}); 583 std::optional<unsigned> RegSlot; 584 Attr->setBinding(RT, RegSlot, 0); 585 Attr->setImplicitBindingOrderID(ImplicitBindingOrderID); 586 BufDecl->addAttr(Attr); 587 } 588 589 // Handle end of cbuffer/tbuffer declaration 590 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { 591 auto *BufDecl = cast<HLSLBufferDecl>(Dcl); 592 BufDecl->setRBraceLoc(RBrace); 593 594 validatePackoffset(SemaRef, BufDecl); 595 596 // create buffer layout struct 597 createHostLayoutStructForBuffer(SemaRef, BufDecl); 598 599 HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>(); 600 if (!RBA || !RBA->hasRegisterSlot()) { 601 SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding); 602 // Use HLSLResourceBindingAttr to transfer implicit binding order_ID 603 // to codegen. If it does not exist, create an implicit attribute. 604 uint32_t OrderID = getNextImplicitBindingOrderID(); 605 if (RBA) 606 RBA->setImplicitBindingOrderID(OrderID); 607 else 608 addImplicitBindingAttrToBuffer(SemaRef, BufDecl, OrderID); 609 } 610 611 SemaRef.PopDeclContext(); 612 } 613 614 HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, 615 const AttributeCommonInfo &AL, 616 int X, int Y, int Z) { 617 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { 618 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { 619 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 620 Diag(AL.getLoc(), diag::note_conflicting_attribute); 621 } 622 return nullptr; 623 } 624 return ::new (getASTContext()) 625 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); 626 } 627 628 HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D, 629 const AttributeCommonInfo &AL, 630 int Min, int Max, int Preferred, 631 int SpelledArgsCount) { 632 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) { 633 if (WS->getMin() 
!= Min || WS->getMax() != Max || 634 WS->getPreferred() != Preferred || 635 WS->getSpelledArgsCount() != SpelledArgsCount) { 636 Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 637 Diag(AL.getLoc(), diag::note_conflicting_attribute); 638 } 639 return nullptr; 640 } 641 HLSLWaveSizeAttr *Result = ::new (getASTContext()) 642 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred); 643 Result->setSpelledArgsCount(SpelledArgsCount); 644 return Result; 645 } 646 647 HLSLVkConstantIdAttr * 648 SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, 649 int Id) { 650 651 auto &TargetInfo = getASTContext().getTargetInfo(); 652 if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) { 653 Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL; 654 return nullptr; 655 } 656 657 auto *VD = cast<VarDecl>(D); 658 659 if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) == 660 Builtin::NotBuiltin) { 661 Diag(VD->getLocation(), diag::err_specialization_const); 662 return nullptr; 663 } 664 665 if (!VD->getType().isConstQualified()) { 666 Diag(VD->getLocation(), diag::err_specialization_const); 667 return nullptr; 668 } 669 670 if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) { 671 if (CI->getId() != Id) { 672 Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 673 Diag(AL.getLoc(), diag::note_conflicting_attribute); 674 } 675 return nullptr; 676 } 677 678 HLSLVkConstantIdAttr *Result = 679 ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id); 680 return Result; 681 } 682 683 HLSLShaderAttr * 684 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, 685 llvm::Triple::EnvironmentType ShaderType) { 686 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { 687 if (NT->getType() != ShaderType) { 688 Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; 689 Diag(AL.getLoc(), diag::note_conflicting_attribute); 690 } 691 return nullptr; 
692 } 693 return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL); 694 } 695 696 HLSLParamModifierAttr * 697 SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, 698 HLSLParamModifierAttr::Spelling Spelling) { 699 // We can only merge an `in` attribute with an `out` attribute. All other 700 // combinations of duplicated attributes are ill-formed. 701 if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { 702 if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || 703 (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { 704 D->dropAttr<HLSLParamModifierAttr>(); 705 SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; 706 return HLSLParamModifierAttr::Create( 707 getASTContext(), /*MergedSpelling=*/true, AdjustedRange, 708 HLSLParamModifierAttr::Keyword_inout); 709 } 710 Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; 711 Diag(PA->getLocation(), diag::note_conflicting_attribute); 712 return nullptr; 713 } 714 return HLSLParamModifierAttr::Create(getASTContext(), AL); 715 } 716 717 void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { 718 auto &TargetInfo = getASTContext().getTargetInfo(); 719 720 if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) 721 return; 722 723 llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment(); 724 if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) { 725 if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { 726 // The entry point is already annotated - check that it matches the 727 // triple. 728 if (Shader->getType() != Env) { 729 Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) 730 << Shader; 731 FD->setInvalidDecl(); 732 } 733 } else { 734 // Implicitly add the shader attribute if the entry function isn't 735 // explicitly annotated. 
736 FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env, 737 FD->getBeginLoc())); 738 } 739 } else { 740 switch (Env) { 741 case llvm::Triple::UnknownEnvironment: 742 case llvm::Triple::Library: 743 break; 744 default: 745 llvm_unreachable("Unhandled environment in triple"); 746 } 747 } 748 } 749 750 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { 751 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 752 assert(ShaderAttr && "Entry point has no shader attribute"); 753 llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); 754 auto &TargetInfo = getASTContext().getTargetInfo(); 755 VersionTuple Ver = TargetInfo.getTriple().getOSVersion(); 756 switch (ST) { 757 case llvm::Triple::Pixel: 758 case llvm::Triple::Vertex: 759 case llvm::Triple::Geometry: 760 case llvm::Triple::Hull: 761 case llvm::Triple::Domain: 762 case llvm::Triple::RayGeneration: 763 case llvm::Triple::Intersection: 764 case llvm::Triple::AnyHit: 765 case llvm::Triple::ClosestHit: 766 case llvm::Triple::Miss: 767 case llvm::Triple::Callable: 768 if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { 769 DiagnoseAttrStageMismatch(NT, ST, 770 {llvm::Triple::Compute, 771 llvm::Triple::Amplification, 772 llvm::Triple::Mesh}); 773 FD->setInvalidDecl(); 774 } 775 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) { 776 DiagnoseAttrStageMismatch(WS, ST, 777 {llvm::Triple::Compute, 778 llvm::Triple::Amplification, 779 llvm::Triple::Mesh}); 780 FD->setInvalidDecl(); 781 } 782 break; 783 784 case llvm::Triple::Compute: 785 case llvm::Triple::Amplification: 786 case llvm::Triple::Mesh: 787 if (!FD->hasAttr<HLSLNumThreadsAttr>()) { 788 Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) 789 << llvm::Triple::getEnvironmentTypeName(ST); 790 FD->setInvalidDecl(); 791 } 792 if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) { 793 if (Ver < VersionTuple(6, 6)) { 794 Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model) 795 << WS << "6.6"; 796 
FD->setInvalidDecl(); 797 } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) { 798 Diag( 799 WS->getLocation(), 800 diag::err_hlsl_attribute_number_arguments_insufficient_shader_model) 801 << WS << WS->getSpelledArgsCount() << "6.8"; 802 FD->setInvalidDecl(); 803 } 804 } 805 break; 806 default: 807 llvm_unreachable("Unhandled environment in triple"); 808 } 809 810 for (ParmVarDecl *Param : FD->parameters()) { 811 if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { 812 CheckSemanticAnnotation(FD, Param, AnnotationAttr); 813 } else { 814 // FIXME: Handle struct parameters where annotations are on struct fields. 815 // See: https://github.com/llvm/llvm-project/issues/57875 816 Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); 817 Diag(Param->getLocation(), diag::note_previous_decl) << Param; 818 FD->setInvalidDecl(); 819 } 820 } 821 // FIXME: Verify return type semantic annotation. 822 } 823 824 void SemaHLSL::CheckSemanticAnnotation( 825 FunctionDecl *EntryPoint, const Decl *Param, 826 const HLSLAnnotationAttr *AnnotationAttr) { 827 auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); 828 assert(ShaderAttr && "Entry point has no shader attribute"); 829 llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); 830 831 switch (AnnotationAttr->getKind()) { 832 case attr::HLSLSV_DispatchThreadID: 833 case attr::HLSLSV_GroupIndex: 834 case attr::HLSLSV_GroupThreadID: 835 case attr::HLSLSV_GroupID: 836 if (ST == llvm::Triple::Compute) 837 return; 838 DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute}); 839 break; 840 case attr::HLSLSV_Position: 841 // TODO(#143523): allow use on other shader types & output once the overall 842 // semantic logic is implemented. 
843 if (ST == llvm::Triple::Pixel) 844 return; 845 DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Pixel}); 846 break; 847 default: 848 llvm_unreachable("Unknown HLSLAnnotationAttr"); 849 } 850 } 851 852 void SemaHLSL::DiagnoseAttrStageMismatch( 853 const Attr *A, llvm::Triple::EnvironmentType Stage, 854 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) { 855 SmallVector<StringRef, 8> StageStrings; 856 llvm::transform(AllowedStages, std::back_inserter(StageStrings), 857 [](llvm::Triple::EnvironmentType ST) { 858 return StringRef( 859 HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST)); 860 }); 861 Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) 862 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Stage) 863 << (AllowedStages.size() != 1) << join(StageStrings, ", "); 864 } 865 866 template <CastKind Kind> 867 static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) { 868 if (const auto *VTy = Ty->getAs<VectorType>()) 869 Ty = VTy->getElementType(); 870 Ty = S.getASTContext().getExtVectorType(Ty, Sz); 871 E = S.ImpCastExprToType(E.get(), Ty, Kind); 872 } 873 874 template <CastKind Kind> 875 static QualType castElement(Sema &S, ExprResult &E, QualType Ty) { 876 E = S.ImpCastExprToType(E.get(), Ty, Kind); 877 return Ty; 878 } 879 880 static QualType handleFloatVectorBinOpConversion( 881 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, 882 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) { 883 bool LHSFloat = LElTy->isRealFloatingType(); 884 bool RHSFloat = RElTy->isRealFloatingType(); 885 886 if (LHSFloat && RHSFloat) { 887 if (IsCompAssign || 888 SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0) 889 return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType); 890 891 return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType); 892 } 893 894 if (LHSFloat) 895 return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType); 896 897 assert(RHSFloat); 898 
if (IsCompAssign) 899 return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType); 900 901 return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType); 902 } 903 904 static QualType handleIntegerVectorBinOpConversion( 905 Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType, 906 QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) { 907 908 int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy); 909 bool LHSSigned = LElTy->hasSignedIntegerRepresentation(); 910 bool RHSSigned = RElTy->hasSignedIntegerRepresentation(); 911 auto &Ctx = SemaRef.getASTContext(); 912 913 // If both types have the same signedness, use the higher ranked type. 914 if (LHSSigned == RHSSigned) { 915 if (IsCompAssign || IntOrder >= 0) 916 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 917 918 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType); 919 } 920 921 // If the unsigned type has greater than or equal rank of the signed type, use 922 // the unsigned type. 923 if (IntOrder != (LHSSigned ? 1 : -1)) { 924 if (IsCompAssign || RHSSigned) 925 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 926 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType); 927 } 928 929 // At this point the signed type has higher rank than the unsigned type, which 930 // means it will be the same size or bigger. If the signed type is bigger, it 931 // can represent all the values of the unsigned type, so select it. 932 if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) { 933 if (IsCompAssign || LHSSigned) 934 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 935 return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType); 936 } 937 938 // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due 939 // to C/C++ leaking through. The place this happens today is long vs long 940 // long. 
When arguments are vector<unsigned long, N> and vector<long long, N>, 941 // the long long has higher rank than long even though they are the same size. 942 943 // If this is a compound assignment cast the right hand side to the left hand 944 // side's type. 945 if (IsCompAssign) 946 return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType); 947 948 // If this isn't a compound assignment we convert to unsigned long long. 949 QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy); 950 QualType NewTy = Ctx.getExtVectorType( 951 ElTy, RHSType->castAs<VectorType>()->getNumElements()); 952 (void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy); 953 954 return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy); 955 } 956 957 static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy, 958 QualType SrcTy) { 959 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType()) 960 return CK_FloatingCast; 961 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx)) 962 return CK_IntegralCast; 963 if (DestTy->isRealFloatingType()) 964 return CK_IntegralToFloating; 965 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx)); 966 return CK_FloatingToIntegral; 967 } 968 969 QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, 970 QualType LHSType, 971 QualType RHSType, 972 bool IsCompAssign) { 973 const auto *LVecTy = LHSType->getAs<VectorType>(); 974 const auto *RVecTy = RHSType->getAs<VectorType>(); 975 auto &Ctx = getASTContext(); 976 977 // If the LHS is not a vector and this is a compound assignment, we truncate 978 // the argument to a scalar then convert it to the LHS's type. 
979 if (!LVecTy && IsCompAssign) { 980 QualType RElTy = RHSType->castAs<VectorType>()->getElementType(); 981 RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation); 982 RHSType = RHS.get()->getType(); 983 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType)) 984 return LHSType; 985 RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType, 986 getScalarCastKind(Ctx, LHSType, RHSType)); 987 return LHSType; 988 } 989 990 unsigned EndSz = std::numeric_limits<unsigned>::max(); 991 unsigned LSz = 0; 992 if (LVecTy) 993 LSz = EndSz = LVecTy->getNumElements(); 994 if (RVecTy) 995 EndSz = std::min(RVecTy->getNumElements(), EndSz); 996 assert(EndSz != std::numeric_limits<unsigned>::max() && 997 "one of the above should have had a value"); 998 999 // In a compound assignment, the left operand does not change type, the right 1000 // operand is converted to the type of the left operand. 1001 if (IsCompAssign && LSz != EndSz) { 1002 Diag(LHS.get()->getBeginLoc(), 1003 diag::err_hlsl_vector_compound_assignment_truncation) 1004 << LHSType << RHSType; 1005 return QualType(); 1006 } 1007 1008 if (RVecTy && RVecTy->getNumElements() > EndSz) 1009 castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz); 1010 if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz) 1011 castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz); 1012 1013 if (!RVecTy) 1014 castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz); 1015 if (!IsCompAssign && !LVecTy) 1016 castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz); 1017 1018 // If we're at the same type after resizing we can stop here. 1019 if (Ctx.hasSameUnqualifiedType(LHSType, RHSType)) 1020 return Ctx.getCommonSugaredType(LHSType, RHSType); 1021 1022 QualType LElTy = LHSType->castAs<VectorType>()->getElementType(); 1023 QualType RElTy = RHSType->castAs<VectorType>()->getElementType(); 1024 1025 // Handle conversion for floating point vectors. 
1026 if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType()) 1027 return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType, 1028 LElTy, RElTy, IsCompAssign); 1029 1030 assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) && 1031 "HLSL Vectors can only contain integer or floating point types"); 1032 return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType, 1033 LElTy, RElTy, IsCompAssign); 1034 } 1035 1036 void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, 1037 BinaryOperatorKind Opc) { 1038 assert((Opc == BO_LOr || Opc == BO_LAnd) && 1039 "Called with non-logical operator"); 1040 llvm::SmallVector<char, 256> Buff; 1041 llvm::raw_svector_ostream OS(Buff); 1042 PrintingPolicy PP(SemaRef.getLangOpts()); 1043 StringRef NewFnName = Opc == BO_LOr ? "or" : "and"; 1044 OS << NewFnName << "("; 1045 LHS->printPretty(OS, nullptr, PP); 1046 OS << ", "; 1047 RHS->printPretty(OS, nullptr, PP); 1048 OS << ")"; 1049 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc()); 1050 SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion) 1051 << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str()); 1052 } 1053 1054 std::pair<IdentifierInfo *, bool> 1055 SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) { 1056 llvm::hash_code Hash = llvm::hash_value(Signature); 1057 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(Hash); 1058 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(IdStr)); 1059 1060 // Check if we have already found a decl of the same name. 
1061 LookupResult R(SemaRef, DeclIdent, SourceLocation(), 1062 Sema::LookupOrdinaryName); 1063 bool Found = SemaRef.LookupQualifiedName(R, SemaRef.CurContext); 1064 return {DeclIdent, Found}; 1065 } 1066 1067 void SemaHLSL::ActOnFinishRootSignatureDecl( 1068 SourceLocation Loc, IdentifierInfo *DeclIdent, 1069 ArrayRef<hlsl::RootSignatureElement> RootElements) { 1070 1071 if (handleRootSignatureElements(RootElements)) 1072 return; 1073 1074 SmallVector<llvm::hlsl::rootsig::RootElement> Elements; 1075 for (auto &RootSigElement : RootElements) 1076 Elements.push_back(RootSigElement.getElement()); 1077 1078 auto *SignatureDecl = HLSLRootSignatureDecl::Create( 1079 SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc, 1080 DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements); 1081 1082 SignatureDecl->setImplicit(); 1083 SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope()); 1084 } 1085 1086 bool SemaHLSL::handleRootSignatureElements( 1087 ArrayRef<hlsl::RootSignatureElement> Elements) { 1088 // Define some common error handling functions 1089 bool HadError = false; 1090 auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound, 1091 uint32_t UpperBound) { 1092 HadError = true; 1093 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value) 1094 << LowerBound << UpperBound; 1095 }; 1096 1097 auto ReportFloatError = [this, &HadError](SourceLocation Loc, 1098 float LowerBound, 1099 float UpperBound) { 1100 HadError = true; 1101 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value) 1102 << llvm::formatv("{0:f}", LowerBound).sstr<6>() 1103 << llvm::formatv("{0:f}", UpperBound).sstr<6>(); 1104 }; 1105 1106 auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) { 1107 if (!llvm::hlsl::rootsig::verifyRegisterValue(Register)) 1108 ReportError(Loc, 0, 0xfffffffe); 1109 }; 1110 1111 auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) { 1112 if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space)) 1113 
ReportError(Loc, 0, 0xffffffef); 1114 }; 1115 1116 const uint32_t Version = 1117 llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer); 1118 const uint32_t VersionEnum = Version - 1; 1119 auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) { 1120 HadError = true; 1121 this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag) 1122 << /*version minor*/ VersionEnum; 1123 }; 1124 1125 // Iterate through the elements and do basic validations 1126 for (const hlsl::RootSignatureElement &RootSigElem : Elements) { 1127 SourceLocation Loc = RootSigElem.getLocation(); 1128 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement(); 1129 if (const auto *Descriptor = 1130 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) { 1131 VerifyRegister(Loc, Descriptor->Reg.Number); 1132 VerifySpace(Loc, Descriptor->Space); 1133 1134 if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag( 1135 Version, llvm::to_underlying(Descriptor->Flags))) 1136 ReportFlagError(Loc); 1137 } else if (const auto *Constants = 1138 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) { 1139 VerifyRegister(Loc, Constants->Reg.Number); 1140 VerifySpace(Loc, Constants->Space); 1141 } else if (const auto *Sampler = 1142 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) { 1143 VerifyRegister(Loc, Sampler->Reg.Number); 1144 VerifySpace(Loc, Sampler->Space); 1145 1146 assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) && 1147 "By construction, parseFloatParam can't produce a NaN from a " 1148 "float_literal token"); 1149 1150 if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy)) 1151 ReportError(Loc, 0, 16); 1152 if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias)) 1153 ReportFloatError(Loc, -16.f, 15.99); 1154 } else if (const auto *Clause = 1155 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>( 1156 &Elem)) { 1157 VerifyRegister(Loc, Clause->Reg.Number); 1158 VerifySpace(Loc, Clause->Space); 1159 1160 if 
(!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) { 1161 // NumDescriptor could techincally be ~0u but that is reserved for 1162 // unbounded, so the diagnostic will not report that as a valid int 1163 // value 1164 ReportError(Loc, 1, 0xfffffffe); 1165 } 1166 1167 if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( 1168 Version, llvm::to_underlying(Clause->Type), 1169 llvm::to_underlying(Clause->Flags))) 1170 ReportFlagError(Loc); 1171 } 1172 } 1173 1174 using RangeInfo = llvm::hlsl::rootsig::RangeInfo; 1175 using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges; 1176 using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>; 1177 1178 // 1. Collect RangeInfos 1179 llvm::SmallVector<InfoPairT> InfoPairs; 1180 for (const hlsl::RootSignatureElement &RootSigElem : Elements) { 1181 const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement(); 1182 if (const auto *Descriptor = 1183 std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) { 1184 RangeInfo Info; 1185 Info.LowerBound = Descriptor->Reg.Number; 1186 Info.UpperBound = Info.LowerBound; // use inclusive ranges [] 1187 1188 Info.Class = 1189 llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type)); 1190 Info.Space = Descriptor->Space; 1191 Info.Visibility = Descriptor->Visibility; 1192 1193 InfoPairs.push_back({Info, &RootSigElem}); 1194 } else if (const auto *Constants = 1195 std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) { 1196 RangeInfo Info; 1197 Info.LowerBound = Constants->Reg.Number; 1198 Info.UpperBound = Info.LowerBound; // use inclusive ranges [] 1199 1200 Info.Class = llvm::dxil::ResourceClass::CBuffer; 1201 Info.Space = Constants->Space; 1202 Info.Visibility = Constants->Visibility; 1203 1204 InfoPairs.push_back({Info, &RootSigElem}); 1205 } else if (const auto *Sampler = 1206 std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) { 1207 RangeInfo Info; 1208 Info.LowerBound = Sampler->Reg.Number; 1209 Info.UpperBound = 
Info.LowerBound; // use inclusive ranges [] 1210 1211 Info.Class = llvm::dxil::ResourceClass::Sampler; 1212 Info.Space = Sampler->Space; 1213 Info.Visibility = Sampler->Visibility; 1214 1215 InfoPairs.push_back({Info, &RootSigElem}); 1216 } else if (const auto *Clause = 1217 std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>( 1218 &Elem)) { 1219 RangeInfo Info; 1220 Info.LowerBound = Clause->Reg.Number; 1221 // Relevant error will have already been reported above and needs to be 1222 // fixed before we can conduct range analysis, so shortcut error return 1223 if (Clause->NumDescriptors == 0) 1224 return true; 1225 Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded 1226 ? RangeInfo::Unbounded 1227 : Info.LowerBound + Clause->NumDescriptors - 1228 1; // use inclusive ranges [] 1229 1230 Info.Class = Clause->Type; 1231 Info.Space = Clause->Space; 1232 1233 // Note: Clause does not hold the visibility this will need to 1234 InfoPairs.push_back({Info, &RootSigElem}); 1235 } else if (const auto *Table = 1236 std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) { 1237 // Table holds the Visibility of all owned Clauses in Table, so iterate 1238 // owned Clauses and update their corresponding RangeInfo 1239 assert(Table->NumClauses <= InfoPairs.size() && "RootElement"); 1240 // The last Table->NumClauses elements of Infos are the owned Clauses 1241 // generated RangeInfo 1242 auto TableInfos = 1243 MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses); 1244 for (InfoPairT &Pair : TableInfos) 1245 Pair.first.Visibility = Table->Visibility; 1246 } 1247 } 1248 1249 // 2. 
Sort with the RangeInfo <operator to prepare it for findOverlapping 1250 llvm::sort(InfoPairs, 1251 [](InfoPairT A, InfoPairT B) { return A.first < B.first; }); 1252 1253 llvm::SmallVector<RangeInfo> Infos; 1254 for (const InfoPairT &Pair : InfoPairs) 1255 Infos.push_back(Pair.first); 1256 1257 // Helpers to report diagnostics 1258 uint32_t DuplicateCounter = 0; 1259 using ElemPair = std::pair<const hlsl::RootSignatureElement *, 1260 const hlsl::RootSignatureElement *>; 1261 auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter]( 1262 OverlappingRanges Overlap) -> ElemPair { 1263 // Given we sorted the InfoPairs (and by implication) Infos, and, 1264 // that Overlap.B is the item retrieved from the ResourceRange. Then it is 1265 // guarenteed that Overlap.B <= Overlap.A. 1266 // 1267 // So we will find Overlap.B first and then continue to find Overlap.A 1268 // after 1269 auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B); 1270 auto DistB = std::distance(Infos.begin(), InfoB); 1271 auto PairB = InfoPairs.begin(); 1272 std::advance(PairB, DistB); 1273 1274 auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A); 1275 // Similarily, from the property that we have sorted the RangeInfos, 1276 // all duplicates will be processed one after the other. So 1277 // DuplicateCounter can be re-used for each set of duplicates we 1278 // encounter as we handle incoming errors 1279 DuplicateCounter = InfoA == InfoB ? 
DuplicateCounter + 1 : 0; 1280 auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter; 1281 auto PairA = PairB; 1282 std::advance(PairA, DistA); 1283 1284 return {PairA->second, PairB->second}; 1285 }; 1286 1287 auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) { 1288 auto Pair = GetElemPair(Overlap); 1289 const RangeInfo *Info = Overlap.A; 1290 const hlsl::RootSignatureElement *Elem = Pair.first; 1291 const RangeInfo *OInfo = Overlap.B; 1292 1293 auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All 1294 ? OInfo->Visibility 1295 : Info->Visibility; 1296 this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap) 1297 << llvm::to_underlying(Info->Class) << Info->LowerBound 1298 << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded) 1299 << Info->UpperBound << llvm::to_underlying(OInfo->Class) 1300 << OInfo->LowerBound 1301 << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded) 1302 << OInfo->UpperBound << Info->Space << CommonVis; 1303 1304 const hlsl::RootSignatureElement *OElem = Pair.second; 1305 this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here); 1306 }; 1307 1308 // 3. 
Invoke find overlapping ranges 1309 llvm::SmallVector<OverlappingRanges> Overlaps = 1310 llvm::hlsl::rootsig::findOverlappingRanges(Infos); 1311 for (OverlappingRanges Overlap : Overlaps) 1312 ReportOverlap(Overlap); 1313 1314 return Overlaps.size() != 0; 1315 } 1316 1317 void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) { 1318 if (AL.getNumArgs() != 1) { 1319 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1; 1320 return; 1321 } 1322 1323 IdentifierInfo *Ident = AL.getArgAsIdent(0)->getIdentifierInfo(); 1324 if (auto *RS = D->getAttr<RootSignatureAttr>()) { 1325 if (RS->getSignatureIdent() != Ident) { 1326 Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS; 1327 return; 1328 } 1329 1330 Diag(AL.getLoc(), diag::warn_duplicate_attribute_exact) << RS; 1331 return; 1332 } 1333 1334 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName); 1335 if (SemaRef.LookupQualifiedName(R, D->getDeclContext())) 1336 if (auto *SignatureDecl = 1337 dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) { 1338 D->addAttr(::new (getASTContext()) RootSignatureAttr( 1339 getASTContext(), AL, Ident, SignatureDecl)); 1340 } 1341 } 1342 1343 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) { 1344 llvm::VersionTuple SMVersion = 1345 getASTContext().getTargetInfo().getTriple().getOSVersion(); 1346 bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() == 1347 llvm::Triple::dxil; 1348 1349 uint32_t ZMax = 1024; 1350 uint32_t ThreadMax = 1024; 1351 if (IsDXIL && SMVersion.getMajor() <= 4) { 1352 ZMax = 1; 1353 ThreadMax = 768; 1354 } else if (IsDXIL && SMVersion.getMajor() == 5) { 1355 ZMax = 64; 1356 ThreadMax = 1024; 1357 } 1358 1359 uint32_t X; 1360 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X)) 1361 return; 1362 if (X > 1024) { 1363 Diag(AL.getArgAsExpr(0)->getExprLoc(), 1364 diag::err_hlsl_numthreads_argument_oor) 1365 << 0 << 1024; 1366 return; 1367 } 1368 uint32_t Y; 
1369 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y)) 1370 return; 1371 if (Y > 1024) { 1372 Diag(AL.getArgAsExpr(1)->getExprLoc(), 1373 diag::err_hlsl_numthreads_argument_oor) 1374 << 1 << 1024; 1375 return; 1376 } 1377 uint32_t Z; 1378 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z)) 1379 return; 1380 if (Z > ZMax) { 1381 SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(), 1382 diag::err_hlsl_numthreads_argument_oor) 1383 << 2 << ZMax; 1384 return; 1385 } 1386 1387 if (X * Y * Z > ThreadMax) { 1388 Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax; 1389 return; 1390 } 1391 1392 HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z); 1393 if (NewAttr) 1394 D->addAttr(NewAttr); 1395 } 1396 1397 static bool isValidWaveSizeValue(unsigned Value) { 1398 return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128; 1399 } 1400 1401 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) { 1402 // validate that the wavesize argument is a power of 2 between 4 and 128 1403 // inclusive 1404 unsigned SpelledArgsCount = AL.getNumArgs(); 1405 if (SpelledArgsCount == 0 || SpelledArgsCount > 3) 1406 return; 1407 1408 uint32_t Min; 1409 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min)) 1410 return; 1411 1412 uint32_t Max = 0; 1413 if (SpelledArgsCount > 1 && 1414 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max)) 1415 return; 1416 1417 uint32_t Preferred = 0; 1418 if (SpelledArgsCount > 2 && 1419 !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred)) 1420 return; 1421 1422 if (SpelledArgsCount > 2) { 1423 if (!isValidWaveSizeValue(Preferred)) { 1424 Diag(AL.getArgAsExpr(2)->getExprLoc(), 1425 diag::err_attribute_power_of_two_in_range) 1426 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize 1427 << Preferred; 1428 return; 1429 } 1430 // Preferred not in range. 
1431 if (Preferred < Min || Preferred > Max) { 1432 Diag(AL.getArgAsExpr(2)->getExprLoc(), 1433 diag::err_attribute_power_of_two_in_range) 1434 << AL << Min << Max << Preferred; 1435 return; 1436 } 1437 } else if (SpelledArgsCount > 1) { 1438 if (!isValidWaveSizeValue(Max)) { 1439 Diag(AL.getArgAsExpr(1)->getExprLoc(), 1440 diag::err_attribute_power_of_two_in_range) 1441 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max; 1442 return; 1443 } 1444 if (Max < Min) { 1445 Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1; 1446 return; 1447 } else if (Max == Min) { 1448 Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL; 1449 } 1450 } else { 1451 if (!isValidWaveSizeValue(Min)) { 1452 Diag(AL.getArgAsExpr(0)->getExprLoc(), 1453 diag::err_attribute_power_of_two_in_range) 1454 << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min; 1455 return; 1456 } 1457 } 1458 1459 HLSLWaveSizeAttr *NewAttr = 1460 mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount); 1461 if (NewAttr) 1462 D->addAttr(NewAttr); 1463 } 1464 1465 void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) { 1466 uint32_t ID; 1467 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID)) 1468 return; 1469 D->addAttr(::new (getASTContext()) 1470 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID)); 1471 } 1472 1473 void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) { 1474 uint32_t Id; 1475 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id)) 1476 return; 1477 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id); 1478 if (NewAttr) 1479 D->addAttr(NewAttr); 1480 } 1481 1482 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) { 1483 const auto *VT = T->getAs<VectorType>(); 1484 1485 if (!T->hasUnsignedIntegerRepresentation() || 1486 (VT && VT->getNumElements() > 3)) { 1487 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) 1488 << AL << "uint/uint2/uint3"; 1489 return false; 1490 } 
1491 1492 return true; 1493 } 1494 1495 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { 1496 auto *VD = cast<ValueDecl>(D); 1497 if (!diagnoseInputIDType(VD->getType(), AL)) 1498 return; 1499 1500 D->addAttr(::new (getASTContext()) 1501 HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); 1502 } 1503 1504 bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) { 1505 const auto *VT = T->getAs<VectorType>(); 1506 1507 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) { 1508 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) 1509 << AL << "float/float1/float2/float3/float4"; 1510 return false; 1511 } 1512 1513 return true; 1514 } 1515 1516 void SemaHLSL::handleSV_PositionAttr(Decl *D, const ParsedAttr &AL) { 1517 auto *VD = cast<ValueDecl>(D); 1518 if (!diagnosePositionType(VD->getType(), AL)) 1519 return; 1520 1521 D->addAttr(::new (getASTContext()) HLSLSV_PositionAttr(getASTContext(), AL)); 1522 } 1523 1524 void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) { 1525 auto *VD = cast<ValueDecl>(D); 1526 if (!diagnoseInputIDType(VD->getType(), AL)) 1527 return; 1528 1529 D->addAttr(::new (getASTContext()) 1530 HLSLSV_GroupThreadIDAttr(getASTContext(), AL)); 1531 } 1532 1533 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { 1534 auto *VD = cast<ValueDecl>(D); 1535 if (!diagnoseInputIDType(VD->getType(), AL)) 1536 return; 1537 1538 D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL)); 1539 } 1540 1541 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) { 1542 if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) { 1543 Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node) 1544 << AL << "shader constant in a constant buffer"; 1545 return; 1546 } 1547 1548 uint32_t SubComponent; 1549 if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent)) 1550 return; 1551 uint32_t Component; 1552 if 
(!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component)) 1553 return; 1554 1555 QualType T = cast<VarDecl>(D)->getType().getCanonicalType(); 1556 // Check if T is an array or struct type. 1557 // TODO: mark matrix type as aggregate type. 1558 bool IsAggregateTy = (T->isArrayType() || T->isStructureType()); 1559 1560 // Check Component is valid for T. 1561 if (Component) { 1562 unsigned Size = getASTContext().getTypeSize(T); 1563 if (IsAggregateTy || Size > 128) { 1564 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary); 1565 return; 1566 } else { 1567 // Make sure Component + sizeof(T) <= 4. 1568 if ((Component * 32 + Size) > 128) { 1569 Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary); 1570 return; 1571 } 1572 QualType EltTy = T; 1573 if (const auto *VT = T->getAs<VectorType>()) 1574 EltTy = VT->getElementType(); 1575 unsigned Align = getASTContext().getTypeAlign(EltTy); 1576 if (Align > 32 && Component == 1) { 1577 // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary. 1578 // So we only need to check Component 1 here. 1579 Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch) 1580 << Align << EltTy; 1581 return; 1582 } 1583 } 1584 } 1585 1586 D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr( 1587 getASTContext(), AL, SubComponent, Component)); 1588 } 1589 1590 void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) { 1591 StringRef Str; 1592 SourceLocation ArgLoc; 1593 if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc)) 1594 return; 1595 1596 llvm::Triple::EnvironmentType ShaderType; 1597 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) { 1598 Diag(AL.getLoc(), diag::warn_attribute_type_not_supported) 1599 << AL << Str << ArgLoc; 1600 return; 1601 } 1602 1603 // FIXME: check function match the shader stage. 
1604 1605 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType); 1606 if (NewAttr) 1607 D->addAttr(NewAttr); 1608 } 1609 1610 bool clang::CreateHLSLAttributedResourceType( 1611 Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList, 1612 QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) { 1613 assert(AttrList.size() && "expected list of resource attributes"); 1614 1615 QualType ContainedTy = QualType(); 1616 TypeSourceInfo *ContainedTyInfo = nullptr; 1617 SourceLocation LocBegin = AttrList[0]->getRange().getBegin(); 1618 SourceLocation LocEnd = AttrList[0]->getRange().getEnd(); 1619 1620 HLSLAttributedResourceType::Attributes ResAttrs; 1621 1622 bool HasResourceClass = false; 1623 for (const Attr *A : AttrList) { 1624 if (!A) 1625 continue; 1626 LocEnd = A->getRange().getEnd(); 1627 switch (A->getKind()) { 1628 case attr::HLSLResourceClass: { 1629 ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass(); 1630 if (HasResourceClass) { 1631 S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC 1632 ? diag::warn_duplicate_attribute_exact 1633 : diag::warn_duplicate_attribute) 1634 << A; 1635 return false; 1636 } 1637 ResAttrs.ResourceClass = RC; 1638 HasResourceClass = true; 1639 break; 1640 } 1641 case attr::HLSLROV: 1642 if (ResAttrs.IsROV) { 1643 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A; 1644 return false; 1645 } 1646 ResAttrs.IsROV = true; 1647 break; 1648 case attr::HLSLRawBuffer: 1649 if (ResAttrs.RawBuffer) { 1650 S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A; 1651 return false; 1652 } 1653 ResAttrs.RawBuffer = true; 1654 break; 1655 case attr::HLSLContainedType: { 1656 const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A); 1657 QualType Ty = CTAttr->getType(); 1658 if (!ContainedTy.isNull()) { 1659 S.Diag(A->getLocation(), ContainedTy == Ty 1660 ? 
diag::warn_duplicate_attribute_exact 1661 : diag::warn_duplicate_attribute) 1662 << A; 1663 return false; 1664 } 1665 ContainedTy = Ty; 1666 ContainedTyInfo = CTAttr->getTypeLoc(); 1667 break; 1668 } 1669 default: 1670 llvm_unreachable("unhandled resource attribute type"); 1671 } 1672 } 1673 1674 if (!HasResourceClass) { 1675 S.Diag(AttrList.back()->getRange().getEnd(), 1676 diag::err_hlsl_missing_resource_class); 1677 return false; 1678 } 1679 1680 ResType = S.getASTContext().getHLSLAttributedResourceType( 1681 Wrapped, ContainedTy, ResAttrs); 1682 1683 if (LocInfo && ContainedTyInfo) { 1684 LocInfo->Range = SourceRange(LocBegin, LocEnd); 1685 LocInfo->ContainedTyInfo = ContainedTyInfo; 1686 } 1687 return true; 1688 } 1689 1690 // Validates and creates an HLSL attribute that is applied as type attribute on 1691 // HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at 1692 // the end of the declaration they are applied to the declaration type by 1693 // wrapping it in HLSLAttributedResourceType. 
1694 bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) { 1695 // only allow resource type attributes on intangible types 1696 if (!T->isHLSLResourceType()) { 1697 Diag(AL.getLoc(), diag::err_hlsl_attribute_needs_intangible_type) 1698 << AL << getASTContext().HLSLResourceTy; 1699 return false; 1700 } 1701 1702 // validate number of arguments 1703 if (!AL.checkExactlyNumArgs(SemaRef, AL.getMinArgs())) 1704 return false; 1705 1706 Attr *A = nullptr; 1707 1708 AttributeCommonInfo ACI( 1709 AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()), 1710 AttributeCommonInfo::NoSemaHandlerAttribute, 1711 { 1712 AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/, 1713 false /*IsRegularKeywordAttribute*/ 1714 }); 1715 1716 switch (AL.getKind()) { 1717 case ParsedAttr::AT_HLSLResourceClass: { 1718 if (!AL.isArgIdent(0)) { 1719 Diag(AL.getLoc(), diag::err_attribute_argument_type) 1720 << AL << AANT_ArgumentIdentifier; 1721 return false; 1722 } 1723 1724 IdentifierLoc *Loc = AL.getArgAsIdent(0); 1725 StringRef Identifier = Loc->getIdentifierInfo()->getName(); 1726 SourceLocation ArgLoc = Loc->getLoc(); 1727 1728 // Validate resource class value 1729 ResourceClass RC; 1730 if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) { 1731 Diag(ArgLoc, diag::warn_attribute_type_not_supported) 1732 << "ResourceClass" << Identifier; 1733 return false; 1734 } 1735 A = HLSLResourceClassAttr::Create(getASTContext(), RC, ACI); 1736 break; 1737 } 1738 1739 case ParsedAttr::AT_HLSLROV: 1740 A = HLSLROVAttr::Create(getASTContext(), ACI); 1741 break; 1742 1743 case ParsedAttr::AT_HLSLRawBuffer: 1744 A = HLSLRawBufferAttr::Create(getASTContext(), ACI); 1745 break; 1746 1747 case ParsedAttr::AT_HLSLContainedType: { 1748 if (AL.getNumArgs() != 1 && !AL.hasParsedType()) { 1749 Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1; 1750 return false; 1751 } 1752 1753 TypeSourceInfo *TSI = nullptr; 1754 QualType QT = 
SemaRef.GetTypeFromParser(AL.getTypeArg(), &TSI); 1755 assert(TSI && "no type source info for attribute argument"); 1756 if (SemaRef.RequireCompleteType(TSI->getTypeLoc().getBeginLoc(), QT, 1757 diag::err_incomplete_type)) 1758 return false; 1759 A = HLSLContainedTypeAttr::Create(getASTContext(), TSI, ACI); 1760 break; 1761 } 1762 1763 default: 1764 llvm_unreachable("unhandled HLSL attribute"); 1765 } 1766 1767 HLSLResourcesTypeAttrs.emplace_back(A); 1768 return true; 1769 } 1770 1771 // Combines all resource type attributes and creates HLSLAttributedResourceType. 1772 QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) { 1773 if (!HLSLResourcesTypeAttrs.size()) 1774 return CurrentType; 1775 1776 QualType QT = CurrentType; 1777 HLSLAttributedResourceLocInfo LocInfo; 1778 if (CreateHLSLAttributedResourceType(SemaRef, CurrentType, 1779 HLSLResourcesTypeAttrs, QT, &LocInfo)) { 1780 const HLSLAttributedResourceType *RT = 1781 cast<HLSLAttributedResourceType>(QT.getTypePtr()); 1782 1783 // Temporarily store TypeLoc information for the new type. 1784 // It will be transferred to HLSLAttributesResourceTypeLoc 1785 // shortly after the type is created by TypeSpecLocFiller which 1786 // will call the TakeLocForHLSLAttribute method below. 
1787 LocsForHLSLAttributedResources.insert(std::pair(RT, LocInfo)); 1788 } 1789 HLSLResourcesTypeAttrs.clear(); 1790 return QT; 1791 } 1792 1793 // Returns source location for the HLSLAttributedResourceType 1794 HLSLAttributedResourceLocInfo 1795 SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { 1796 HLSLAttributedResourceLocInfo LocInfo = {}; 1797 auto I = LocsForHLSLAttributedResources.find(RT); 1798 if (I != LocsForHLSLAttributedResources.end()) { 1799 LocInfo = I->second; 1800 LocsForHLSLAttributedResources.erase(I); 1801 return LocInfo; 1802 } 1803 LocInfo.Range = SourceRange(); 1804 return LocInfo; 1805 } 1806 1807 // Walks though the global variable declaration, collects all resource binding 1808 // requirements and adds them to Bindings 1809 void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD, 1810 const RecordType *RT) { 1811 const RecordDecl *RD = RT->getDecl(); 1812 for (FieldDecl *FD : RD->fields()) { 1813 const Type *Ty = FD->getType()->getUnqualifiedDesugaredType(); 1814 1815 // Unwrap arrays 1816 // FIXME: Calculate array size while unwrapping 1817 assert(!Ty->isIncompleteArrayType() && 1818 "incomplete arrays inside user defined types are not supported"); 1819 while (Ty->isConstantArrayType()) { 1820 const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); 1821 Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); 1822 } 1823 1824 if (!Ty->isRecordType()) 1825 continue; 1826 1827 if (const HLSLAttributedResourceType *AttrResType = 1828 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) { 1829 // Add a new DeclBindingInfo to Bindings if it does not already exist 1830 ResourceClass RC = AttrResType->getAttrs().ResourceClass; 1831 DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC); 1832 if (!DBI) 1833 Bindings.addDeclBindingInfo(VD, RC); 1834 } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) { 1835 // Recursively scan embedded struct or class; it would be nice to do this 1836 
// without recursion, but tricky to correctly calculate the size of the 1837 // binding, which is something we are probably going to need to do later 1838 // on. Hopefully nesting of structs in structs too many levels is 1839 // unlikely. 1840 collectResourceBindingsOnUserRecordDecl(VD, RT); 1841 } 1842 } 1843 } 1844 1845 // Diagnose localized register binding errors for a single binding; does not 1846 // diagnose resource binding on user record types, that will be done later 1847 // in processResourceBindingOnDecl based on the information collected in 1848 // collectResourceBindingsOnVarDecl. 1849 // Returns false if the register binding is not valid. 1850 static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, 1851 Decl *D, RegisterType RegType, 1852 bool SpecifiedSpace) { 1853 int RegTypeNum = static_cast<int>(RegType); 1854 1855 // check if the decl type is groupshared 1856 if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { 1857 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1858 return false; 1859 } 1860 1861 // Cbuffers and Tbuffers are HLSLBufferDecl types 1862 if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { 1863 ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? 
ResourceClass::CBuffer 1864 : ResourceClass::SRV; 1865 if (RegType == getRegisterType(RC)) 1866 return true; 1867 1868 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) 1869 << RegTypeNum; 1870 return false; 1871 } 1872 1873 // Samplers, UAVs, and SRVs are VarDecl types 1874 assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); 1875 VarDecl *VD = cast<VarDecl>(D); 1876 1877 // Resource 1878 if (const HLSLAttributedResourceType *AttrResType = 1879 HLSLAttributedResourceType::findHandleTypeOnResource( 1880 VD->getType().getTypePtr())) { 1881 if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass)) 1882 return true; 1883 1884 S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) 1885 << RegTypeNum; 1886 return false; 1887 } 1888 1889 const clang::Type *Ty = VD->getType().getTypePtr(); 1890 while (Ty->isArrayType()) 1891 Ty = Ty->getArrayElementTypeNoTypeQual(); 1892 1893 // Basic types 1894 if (Ty->isArithmeticType() || Ty->isVectorType()) { 1895 bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); 1896 if (SpecifiedSpace && !DeclaredInCOrTBuffer) 1897 S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); 1898 1899 if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || 1900 Ty->isFloatingType() || Ty->isVectorType())) { 1901 // Register annotation on default constant buffer declaration ($Globals) 1902 if (RegType == RegisterType::CBuffer) 1903 S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); 1904 else if (RegType != RegisterType::C) 1905 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1906 else 1907 return true; 1908 } else { 1909 if (RegType == RegisterType::C) 1910 S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); 1911 else 1912 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1913 } 1914 return false; 1915 } 1916 if (Ty->isRecordType()) 1917 // RecordTypes will be diagnosed in processResourceBindingOnDecl 1918 // 
that is called from ActOnVariableDeclarator 1919 return true; 1920 1921 // Anything else is an error 1922 S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; 1923 return false; 1924 } 1925 1926 static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, 1927 RegisterType regType) { 1928 // make sure that there are no two register annotations 1929 // applied to the decl with the same register type 1930 bool RegisterTypesDetected[5] = {false}; 1931 RegisterTypesDetected[static_cast<int>(regType)] = true; 1932 1933 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) { 1934 if (HLSLResourceBindingAttr *attr = 1935 dyn_cast<HLSLResourceBindingAttr>(*it)) { 1936 1937 RegisterType otherRegType = attr->getRegisterType(); 1938 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) { 1939 int otherRegTypeNum = static_cast<int>(otherRegType); 1940 S.Diag(TheDecl->getLocation(), 1941 diag::err_hlsl_duplicate_register_annotation) 1942 << otherRegTypeNum; 1943 return false; 1944 } 1945 RegisterTypesDetected[static_cast<int>(otherRegType)] = true; 1946 } 1947 } 1948 return true; 1949 } 1950 1951 static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, 1952 Decl *D, RegisterType RegType, 1953 bool SpecifiedSpace) { 1954 1955 // exactly one of these two types should be set 1956 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) || 1957 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) && 1958 "expecting VarDecl or HLSLBufferDecl"); 1959 1960 // check if the declaration contains resource matching the register type 1961 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace)) 1962 return false; 1963 1964 // next, if multiple register annotations exist, check that none conflict. 
return ValidateMultipleRegisterAnnotations(S, D, RegType);
}

// Handles the HLSL register(...) annotation on TheDecl.
//
// The attribute takes either one identifier argument or two:
//  - one argument: a space ("space<N>") or a register slot such as "t3"; a
//    slot given alone gets the default space "space0";
//  - two arguments: a register slot followed by a space.
// The slot's first letter is mapped to a RegisterType by
// convertToRegisterType ('i' is rejected with a deprecation warning) and must
// be followed by a decimal register number. The space must be "space"
// followed by a decimal number. When a slot is present,
// DiagnoseHLSLRegisterAttribute checks that the register type fits the
// declaration. On the error paths below the function returns early without
// attaching anything; otherwise a new HLSLResourceBindingAttr carrying the
// parsed register type, optional slot number and space number is added to
// TheDecl.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  if (isa<VarDecl>(TheDecl)) {
    if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
                                    cast<ValueDecl>(TheDecl)->getType(),
                                    diag::err_incomplete_type))
      return;
  }

  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  if (!AL.isArgIdent(0)) {
    Diag(AL.getLoc(), diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(0);

  if (AL.getNumArgs() == 2) {
    // register(<slot>, <space>)
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(1)) {
      Diag(AL.getLoc(), diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    // register(<space>) or register(<slot>); SpaceLoc stays invalid in the
    // slot-only case, which is how "no explicit space" is detected below.
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with("space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  // RegType keeps this default when no slot was given (space-only binding).
  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    if (!convertToRegisterType(Slot, &RegType)) {
      Diag(SlotLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1);
      return;
    }
    if (RegType == RegisterType::I) {
      Diag(SlotLoc, diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    StringRef SlotNumStr = Slot.substr(1);
    unsigned N;
    // getAsInteger returns true on failure (including an empty string).
    if (SlotNumStr.getAsInteger(10, N)) {
      Diag(SlotLoc, diag::err_hlsl_unsupported_register_number);
      return;
    }
    SlotNum = N;
  }

  // Validate space
  if (!Space.starts_with("space")) {
    Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
    return;
  }
  // Skip the 5-character "space" prefix.
  StringRef SpaceNumStr = Space.substr(5);
  if (SpaceNumStr.getAsInteger(10, SpaceNum)) {
    Diag(SpaceLoc, diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(SemaRef, SlotLoc, TheDecl, RegType,
                                       !SpaceLoc.isInvalid()))
      return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
  if (NewAttr) {
    NewAttr->setBinding(RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(NewAttr);
  }
}

// Merges the HLSL parameter modifier (in/out/inout style spelling taken from
// AL's semantic spelling) into D via mergeParamModifierAttr and attaches the
// resulting attribute, if any.
void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
  HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
      D, AL,
      static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
  if (NewAttr)
    D->addAttr(NewAttr);
}

namespace {

/// This class implements HLSL availability diagnostics for default
/// and relaxed mode
///
/// The goal of this diagnostic is to emit an error or warning when an
/// unavailable API is found in code that is reachable from the shader
/// entry function or from an exported function (when compiling a shader
/// library).
///
/// This is done by traversing the AST of all shader entry point functions
/// and of all exported functions, and any functions that are referenced
/// from this AST. In other words, any functions that are reachable from
/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack of functions to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
  // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
  // (verified by static_asserts in Triple.cpp), we can use it to index
  // individual bits in the set, as long as we shift the values to start with 0
  // by subtracting the value of llvm::Triple::Pixel first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
  llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;

  // Do not access these directly, use the get/set methods below to make
  // sure the values are in sync
  llvm::Triple::EnvironmentType CurrentShaderEnvironment;
  unsigned CurrentShaderStageBit;

  // True if scanning a function that was already scanned in a different
  // shader stage context, and therefore we should not report issues that
  // depend only on shader model version because they would be duplicate.
2118 bool ReportOnlyShaderStageIssues; 2119 2120 // Helper methods for dealing with current stage context / environment 2121 void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) { 2122 static_assert(sizeof(unsigned) >= 4); 2123 assert(HLSLShaderAttr::isValidShaderType(ShaderType)); 2124 assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 && 2125 "ShaderType is too big for this bitmap"); // 31 is reserved for 2126 // "unknown" 2127 2128 unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel; 2129 CurrentShaderEnvironment = ShaderType; 2130 CurrentShaderStageBit = (1 << bitmapIndex); 2131 } 2132 2133 void SetUnknownShaderStageContext() { 2134 CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment; 2135 CurrentShaderStageBit = (1 << 31); 2136 } 2137 2138 llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const { 2139 return CurrentShaderEnvironment; 2140 } 2141 2142 bool InUnknownShaderStageContext() const { 2143 return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment; 2144 } 2145 2146 // Helper methods for dealing with shader stage bitmap 2147 void AddToScannedFunctions(const FunctionDecl *FD) { 2148 unsigned &ScannedStages = ScannedDecls[FD]; 2149 ScannedStages |= CurrentShaderStageBit; 2150 } 2151 2152 unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; } 2153 2154 bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) { 2155 return WasAlreadyScannedInCurrentStage(GetScannedStages(FD)); 2156 } 2157 2158 bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) { 2159 return ScannerStages & CurrentShaderStageBit; 2160 } 2161 2162 static bool NeverBeenScanned(unsigned ScannedStages) { 2163 return ScannedStages == 0; 2164 } 2165 2166 // Scanning methods 2167 void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr); 2168 void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA, 2169 SourceRange Range); 2170 const AvailabilityAttr *FindAvailabilityAttr(const Decl *D); 
2171 bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA); 2172 2173 public: 2174 DiagnoseHLSLAvailability(Sema &SemaRef) 2175 : SemaRef(SemaRef), 2176 CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment), 2177 CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {} 2178 2179 // AST traversal methods 2180 void RunOnTranslationUnit(const TranslationUnitDecl *TU); 2181 void RunOnFunction(const FunctionDecl *FD); 2182 2183 bool VisitDeclRefExpr(DeclRefExpr *DRE) override { 2184 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl()); 2185 if (FD) 2186 HandleFunctionOrMethodRef(FD, DRE); 2187 return true; 2188 } 2189 2190 bool VisitMemberExpr(MemberExpr *ME) override { 2191 FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl()); 2192 if (FD) 2193 HandleFunctionOrMethodRef(FD, ME); 2194 return true; 2195 } 2196 }; 2197 2198 void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD, 2199 Expr *RefExpr) { 2200 assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) && 2201 "expected DeclRefExpr or MemberExpr"); 2202 2203 // has a definition -> add to stack to be scanned 2204 const FunctionDecl *FDWithBody = nullptr; 2205 if (FD->hasBody(FDWithBody)) { 2206 if (!WasAlreadyScannedInCurrentStage(FDWithBody)) 2207 DeclsToScan.push_back(FDWithBody); 2208 return; 2209 } 2210 2211 // no body -> diagnose availability 2212 const AvailabilityAttr *AA = FindAvailabilityAttr(FD); 2213 if (AA) 2214 CheckDeclAvailability( 2215 FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc())); 2216 } 2217 2218 void DiagnoseHLSLAvailability::RunOnTranslationUnit( 2219 const TranslationUnitDecl *TU) { 2220 2221 // Iterate over all shader entry functions and library exports, and for those 2222 // that have a body (definiton), run diag scan on each, setting appropriate 2223 // shader environment context based on whether it is a shader entry function 2224 // or an exported function. 
Exported functions can be in namespaces and in 2225 // export declarations so we need to scan those declaration contexts as well. 2226 llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan; 2227 DeclContextsToScan.push_back(TU); 2228 2229 while (!DeclContextsToScan.empty()) { 2230 const DeclContext *DC = DeclContextsToScan.pop_back_val(); 2231 for (auto &D : DC->decls()) { 2232 // do not scan implicit declaration generated by the implementation 2233 if (D->isImplicit()) 2234 continue; 2235 2236 // for namespace or export declaration add the context to the list to be 2237 // scanned later 2238 if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) { 2239 DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D)); 2240 continue; 2241 } 2242 2243 // skip over other decls or function decls without body 2244 const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D); 2245 if (!FD || !FD->isThisDeclarationADefinition()) 2246 continue; 2247 2248 // shader entry point 2249 if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) { 2250 SetShaderStageContext(ShaderAttr->getType()); 2251 RunOnFunction(FD); 2252 continue; 2253 } 2254 // exported library function 2255 // FIXME: replace this loop with external linkage check once issue #92071 2256 // is resolved 2257 bool isExport = FD->isInExportDeclContext(); 2258 if (!isExport) { 2259 for (const auto *Redecl : FD->redecls()) { 2260 if (Redecl->isInExportDeclContext()) { 2261 isExport = true; 2262 break; 2263 } 2264 } 2265 } 2266 if (isExport) { 2267 SetUnknownShaderStageContext(); 2268 RunOnFunction(FD); 2269 continue; 2270 } 2271 } 2272 } 2273 } 2274 2275 void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) { 2276 assert(DeclsToScan.empty() && "DeclsToScan should be empty"); 2277 DeclsToScan.push_back(FD); 2278 2279 while (!DeclsToScan.empty()) { 2280 // Take one decl from the stack and check it by traversing its AST. 
2281 // For any CallExpr found during the traversal add it's callee to the top of 2282 // the stack to be processed next. Functions already processed are stored in 2283 // ScannedDecls. 2284 const FunctionDecl *FD = DeclsToScan.pop_back_val(); 2285 2286 // Decl was already scanned 2287 const unsigned ScannedStages = GetScannedStages(FD); 2288 if (WasAlreadyScannedInCurrentStage(ScannedStages)) 2289 continue; 2290 2291 ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages); 2292 2293 AddToScannedFunctions(FD); 2294 TraverseStmt(FD->getBody()); 2295 } 2296 } 2297 2298 bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone( 2299 const AvailabilityAttr *AA) { 2300 IdentifierInfo *IIEnvironment = AA->getEnvironment(); 2301 if (!IIEnvironment) 2302 return true; 2303 2304 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment(); 2305 if (CurrentEnv == llvm::Triple::UnknownEnvironment) 2306 return false; 2307 2308 llvm::Triple::EnvironmentType AttrEnv = 2309 AvailabilityAttr::getEnvironmentType(IIEnvironment->getName()); 2310 2311 return CurrentEnv == AttrEnv; 2312 } 2313 2314 const AvailabilityAttr * 2315 DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) { 2316 AvailabilityAttr const *PartialMatch = nullptr; 2317 // Check each AvailabilityAttr to find the one for this platform. 2318 // For multiple attributes with the same platform try to find one for this 2319 // environment. 2320 for (const auto *A : D->attrs()) { 2321 if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) { 2322 StringRef AttrPlatform = Avail->getPlatform()->getName(); 2323 StringRef TargetPlatform = 2324 SemaRef.getASTContext().getTargetInfo().getPlatformName(); 2325 2326 // Match the platform name. 
2327 if (AttrPlatform == TargetPlatform) { 2328 // Find the best matching attribute for this environment 2329 if (HasMatchingEnvironmentOrNone(Avail)) 2330 return Avail; 2331 PartialMatch = Avail; 2332 } 2333 } 2334 } 2335 return PartialMatch; 2336 } 2337 2338 // Check availability against target shader model version and current shader 2339 // stage and emit diagnostic 2340 void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D, 2341 const AvailabilityAttr *AA, 2342 SourceRange Range) { 2343 2344 IdentifierInfo *IIEnv = AA->getEnvironment(); 2345 2346 if (!IIEnv) { 2347 // The availability attribute does not have environment -> it depends only 2348 // on shader model version and not on specific the shader stage. 2349 2350 // Skip emitting the diagnostics if the diagnostic mode is set to 2351 // strict (-fhlsl-strict-availability) because all relevant diagnostics 2352 // were already emitted in the DiagnoseUnguardedAvailability scan 2353 // (SemaAvailability.cpp). 2354 if (SemaRef.getLangOpts().HLSLStrictAvailability) 2355 return; 2356 2357 // Do not report shader-stage-independent issues if scanning a function 2358 // that was already scanned in a different shader stage context (they would 2359 // be duplicate) 2360 if (ReportOnlyShaderStageIssues) 2361 return; 2362 2363 } else { 2364 // The availability attribute has environment -> we need to know 2365 // the current stage context to property diagnose it. 
2366 if (InUnknownShaderStageContext()) 2367 return; 2368 } 2369 2370 // Check introduced version and if environment matches 2371 bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA); 2372 VersionTuple Introduced = AA->getIntroduced(); 2373 VersionTuple TargetVersion = 2374 SemaRef.Context.getTargetInfo().getPlatformMinVersion(); 2375 2376 if (TargetVersion >= Introduced && EnvironmentMatches) 2377 return; 2378 2379 // Emit diagnostic message 2380 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo(); 2381 llvm::StringRef PlatformName( 2382 AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName())); 2383 2384 llvm::StringRef CurrentEnvStr = 2385 llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment()); 2386 2387 llvm::StringRef AttrEnvStr = 2388 AA->getEnvironment() ? AA->getEnvironment()->getName() : ""; 2389 bool UseEnvironment = !AttrEnvStr.empty(); 2390 2391 if (EnvironmentMatches) { 2392 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability) 2393 << Range << D << PlatformName << Introduced.getAsString() 2394 << UseEnvironment << CurrentEnvStr; 2395 } else { 2396 SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable) 2397 << Range << D; 2398 } 2399 2400 SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here) 2401 << D << PlatformName << Introduced.getAsString() 2402 << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString() 2403 << UseEnvironment << AttrEnvStr << CurrentEnvStr; 2404 } 2405 2406 } // namespace 2407 2408 void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) { 2409 // process default CBuffer - create buffer layout struct and invoke codegenCGH 2410 if (!DefaultCBufferDecls.empty()) { 2411 HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer( 2412 SemaRef.getASTContext(), SemaRef.getCurLexicalContext(), 2413 DefaultCBufferDecls); 2414 addImplicitBindingAttrToBuffer(SemaRef, DefaultCBuffer, 2415 getNextImplicitBindingOrderID()); 
2416 SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer); 2417 createHostLayoutStructForBuffer(SemaRef, DefaultCBuffer); 2418 2419 // Set HasValidPackoffset if any of the decls has a register(c#) annotation; 2420 for (const Decl *VD : DefaultCBufferDecls) { 2421 const HLSLResourceBindingAttr *RBA = 2422 VD->getAttr<HLSLResourceBindingAttr>(); 2423 if (RBA && RBA->hasRegisterSlot() && 2424 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) { 2425 DefaultCBuffer->setHasValidPackoffset(true); 2426 break; 2427 } 2428 } 2429 2430 DeclGroupRef DG(DefaultCBuffer); 2431 SemaRef.Consumer.HandleTopLevelDecl(DG); 2432 } 2433 diagnoseAvailabilityViolations(TU); 2434 } 2435 2436 void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) { 2437 // Skip running the diagnostics scan if the diagnostic mode is 2438 // strict (-fhlsl-strict-availability) and the target shader stage is known 2439 // because all relevant diagnostics were already emitted in the 2440 // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp). 
2441 const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo(); 2442 if (SemaRef.getLangOpts().HLSLStrictAvailability && 2443 TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library) 2444 return; 2445 2446 DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU); 2447 } 2448 2449 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) { 2450 assert(TheCall->getNumArgs() > 1); 2451 QualType ArgTy0 = TheCall->getArg(0)->getType(); 2452 2453 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) { 2454 if (!S->getASTContext().hasSameUnqualifiedType( 2455 ArgTy0, TheCall->getArg(I)->getType())) { 2456 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector) 2457 << TheCall->getDirectCallee() << /*useAllTerminology*/ true 2458 << SourceRange(TheCall->getArg(0)->getBeginLoc(), 2459 TheCall->getArg(N - 1)->getEndLoc()); 2460 return true; 2461 } 2462 } 2463 return false; 2464 } 2465 2466 static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) { 2467 QualType ArgType = Arg->getType(); 2468 if (!S->getASTContext().hasSameUnqualifiedType(ArgType, ExpectedType)) { 2469 S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) 2470 << ArgType << ExpectedType << 1 << 0 << 0; 2471 return true; 2472 } 2473 return false; 2474 } 2475 2476 static bool CheckAllArgTypesAreCorrect( 2477 Sema *S, CallExpr *TheCall, 2478 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal, 2479 clang::QualType PassedType)> 2480 Check) { 2481 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) { 2482 Expr *Arg = TheCall->getArg(I); 2483 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType())) 2484 return true; 2485 } 2486 return false; 2487 } 2488 2489 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, 2490 int ArgOrdinal, 2491 clang::QualType PassedType) { 2492 clang::QualType BaseType = 2493 PassedType->isVectorType() 2494 ? 
PassedType->castAs<clang::VectorType>()->getElementType() 2495 : PassedType; 2496 if (!BaseType->isHalfType() && !BaseType->isFloat32Type()) 2497 return S->Diag(Loc, diag::err_builtin_invalid_arg_type) 2498 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0 2499 << /* half or float */ 2 << PassedType; 2500 return false; 2501 } 2502 2503 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, 2504 unsigned ArgIndex) { 2505 auto *Arg = TheCall->getArg(ArgIndex); 2506 SourceLocation OrigLoc = Arg->getExprLoc(); 2507 if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) == 2508 Expr::MLV_Valid) 2509 return false; 2510 S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0; 2511 return true; 2512 } 2513 2514 static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal, 2515 clang::QualType PassedType) { 2516 const auto *VecTy = PassedType->getAs<VectorType>(); 2517 if (!VecTy) 2518 return false; 2519 2520 if (VecTy->getElementType()->isDoubleType()) 2521 return S->Diag(Loc, diag::err_builtin_invalid_arg_type) 2522 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1 2523 << PassedType; 2524 return false; 2525 } 2526 2527 static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc, 2528 int ArgOrdinal, 2529 clang::QualType PassedType) { 2530 if (!PassedType->hasIntegerRepresentation() && 2531 !PassedType->hasFloatingRepresentation()) 2532 return S->Diag(Loc, diag::err_builtin_invalid_arg_type) 2533 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1 2534 << /* fp */ 1 << PassedType; 2535 return false; 2536 } 2537 2538 static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc, 2539 int ArgOrdinal, 2540 clang::QualType PassedType) { 2541 if (auto *VecTy = PassedType->getAs<VectorType>()) 2542 if (VecTy->getElementType()->isUnsignedIntegerType()) 2543 return false; 2544 2545 return S->Diag(Loc, diag::err_builtin_invalid_arg_type) 2546 << ArgOrdinal << /* vector of */ 4 << 
/* uint */ 3 << /* no fp */ 0 2547 << PassedType; 2548 } 2549 2550 // checks for unsigned ints of all sizes 2551 static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc, 2552 int ArgOrdinal, 2553 clang::QualType PassedType) { 2554 if (!PassedType->hasUnsignedIntegerRepresentation()) 2555 return S->Diag(Loc, diag::err_builtin_invalid_arg_type) 2556 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3 2557 << /* no fp */ 0 << PassedType; 2558 return false; 2559 } 2560 2561 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, 2562 QualType ReturnType) { 2563 auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>(); 2564 if (VecTyA) 2565 ReturnType = 2566 S->Context.getExtVectorType(ReturnType, VecTyA->getNumElements()); 2567 2568 TheCall->setType(ReturnType); 2569 } 2570 2571 static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, 2572 unsigned ArgIndex) { 2573 assert(TheCall->getNumArgs() >= ArgIndex); 2574 QualType ArgType = TheCall->getArg(ArgIndex)->getType(); 2575 auto *VTy = ArgType->getAs<VectorType>(); 2576 // not the scalar or vector<scalar> 2577 if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) || 2578 (VTy && 2579 S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) { 2580 S->Diag(TheCall->getArg(0)->getBeginLoc(), 2581 diag::err_typecheck_expect_scalar_or_vector) 2582 << ArgType << Scalar; 2583 return true; 2584 } 2585 return false; 2586 } 2587 2588 static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall, 2589 unsigned ArgIndex) { 2590 assert(TheCall->getNumArgs() >= ArgIndex); 2591 QualType ArgType = TheCall->getArg(ArgIndex)->getType(); 2592 auto *VTy = ArgType->getAs<VectorType>(); 2593 // not the scalar or vector<scalar> 2594 if (!(ArgType->isScalarType() || 2595 (VTy && VTy->getElementType()->isScalarType()))) { 2596 S->Diag(TheCall->getArg(0)->getBeginLoc(), 2597 diag::err_typecheck_expect_any_scalar_or_vector) 2598 << ArgType << 1; 2599 return 
true; 2600 } 2601 return false; 2602 } 2603 2604 static bool CheckWaveActive(Sema *S, CallExpr *TheCall) { 2605 QualType BoolType = S->getASTContext().BoolTy; 2606 assert(TheCall->getNumArgs() >= 1); 2607 QualType ArgType = TheCall->getArg(0)->getType(); 2608 auto *VTy = ArgType->getAs<VectorType>(); 2609 // is the bool or vector<bool> 2610 if (S->Context.hasSameUnqualifiedType(ArgType, BoolType) || 2611 (VTy && 2612 S->Context.hasSameUnqualifiedType(VTy->getElementType(), BoolType))) { 2613 S->Diag(TheCall->getArg(0)->getBeginLoc(), 2614 diag::err_typecheck_expect_any_scalar_or_vector) 2615 << ArgType << 0; 2616 return true; 2617 } 2618 return false; 2619 } 2620 2621 static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { 2622 assert(TheCall->getNumArgs() == 3); 2623 Expr *Arg1 = TheCall->getArg(1); 2624 Expr *Arg2 = TheCall->getArg(2); 2625 if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) { 2626 S->Diag(TheCall->getBeginLoc(), 2627 diag::err_typecheck_call_different_arg_types) 2628 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange() 2629 << Arg2->getSourceRange(); 2630 return true; 2631 } 2632 2633 TheCall->setType(Arg1->getType()); 2634 return false; 2635 } 2636 2637 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) { 2638 assert(TheCall->getNumArgs() == 3); 2639 Expr *Arg1 = TheCall->getArg(1); 2640 QualType Arg1Ty = Arg1->getType(); 2641 Expr *Arg2 = TheCall->getArg(2); 2642 QualType Arg2Ty = Arg2->getType(); 2643 2644 QualType Arg1ScalarTy = Arg1Ty; 2645 if (auto VTy = Arg1ScalarTy->getAs<VectorType>()) 2646 Arg1ScalarTy = VTy->getElementType(); 2647 2648 QualType Arg2ScalarTy = Arg2Ty; 2649 if (auto VTy = Arg2ScalarTy->getAs<VectorType>()) 2650 Arg2ScalarTy = VTy->getElementType(); 2651 2652 if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy)) 2653 S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch) 2654 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << 
Arg2Ty; 2655 2656 QualType Arg0Ty = TheCall->getArg(0)->getType(); 2657 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements(); 2658 unsigned Arg1Length = Arg1Ty->isVectorType() 2659 ? Arg1Ty->getAs<VectorType>()->getNumElements() 2660 : 0; 2661 unsigned Arg2Length = Arg2Ty->isVectorType() 2662 ? Arg2Ty->getAs<VectorType>()->getNumElements() 2663 : 0; 2664 if (Arg1Length > 0 && Arg0Length != Arg1Length) { 2665 S->Diag(TheCall->getBeginLoc(), 2666 diag::err_typecheck_vector_lengths_not_equal) 2667 << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange() 2668 << Arg1->getSourceRange(); 2669 return true; 2670 } 2671 2672 if (Arg2Length > 0 && Arg0Length != Arg2Length) { 2673 S->Diag(TheCall->getBeginLoc(), 2674 diag::err_typecheck_vector_lengths_not_equal) 2675 << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange() 2676 << Arg2->getSourceRange(); 2677 return true; 2678 } 2679 2680 TheCall->setType( 2681 S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length)); 2682 return false; 2683 } 2684 2685 static bool CheckResourceHandle( 2686 Sema *S, CallExpr *TheCall, unsigned ArgIndex, 2687 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check = 2688 nullptr) { 2689 assert(TheCall->getNumArgs() >= ArgIndex); 2690 QualType ArgType = TheCall->getArg(ArgIndex)->getType(); 2691 const HLSLAttributedResourceType *ResTy = 2692 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>(); 2693 if (!ResTy) { 2694 S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(), 2695 diag::err_typecheck_expect_hlsl_resource) 2696 << ArgType; 2697 return true; 2698 } 2699 if (Check && Check(ResTy)) { 2700 S->Diag(TheCall->getArg(ArgIndex)->getExprLoc(), 2701 diag::err_invalid_hlsl_resource_type) 2702 << ArgType; 2703 return true; 2704 } 2705 return false; 2706 } 2707 2708 // Note: returning true in this case results in CheckBuiltinFunctionCall 2709 // returning an ExprError 2710 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { 
2711 switch (BuiltinID) { 2712 case Builtin::BI__builtin_hlsl_adduint64: { 2713 if (SemaRef.checkArgCount(TheCall, 2)) 2714 return true; 2715 2716 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2717 CheckUnsignedIntVecRepresentation)) 2718 return true; 2719 2720 auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>(); 2721 // ensure arg integers are 32-bits 2722 uint64_t ElementBitCount = getASTContext() 2723 .getTypeSizeInChars(VTy->getElementType()) 2724 .getQuantity() * 2725 8; 2726 if (ElementBitCount != 32) { 2727 SemaRef.Diag(TheCall->getBeginLoc(), 2728 diag::err_integer_incorrect_bit_count) 2729 << 32 << ElementBitCount; 2730 return true; 2731 } 2732 2733 // ensure both args are vectors of total bit size of a multiple of 64 2734 int NumElementsArg = VTy->getNumElements(); 2735 if (NumElementsArg != 2 && NumElementsArg != 4) { 2736 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vector_incorrect_bit_count) 2737 << 1 /*a multiple of*/ << 64 << NumElementsArg * ElementBitCount; 2738 return true; 2739 } 2740 2741 // ensure first arg and second arg have the same type 2742 if (CheckAllArgsHaveSameType(&SemaRef, TheCall)) 2743 return true; 2744 2745 ExprResult A = TheCall->getArg(0); 2746 QualType ArgTyA = A.get()->getType(); 2747 // return type is the same as the input type 2748 TheCall->setType(ArgTyA); 2749 break; 2750 } 2751 case Builtin::BI__builtin_hlsl_resource_getpointer: { 2752 if (SemaRef.checkArgCount(TheCall, 2) || 2753 CheckResourceHandle(&SemaRef, TheCall, 0) || 2754 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), 2755 SemaRef.getASTContext().UnsignedIntTy)) 2756 return true; 2757 2758 auto *ResourceTy = 2759 TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>(); 2760 QualType ContainedTy = ResourceTy->getContainedType(); 2761 auto ReturnType = 2762 SemaRef.Context.getAddrSpaceQualType(ContainedTy, LangAS::hlsl_device); 2763 ReturnType = SemaRef.Context.getPointerType(ReturnType); 2764 TheCall->setType(ReturnType); 2765 
TheCall->setValueKind(VK_LValue); 2766 2767 break; 2768 } 2769 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: { 2770 if (SemaRef.checkArgCount(TheCall, 1) || 2771 CheckResourceHandle(&SemaRef, TheCall, 0)) 2772 return true; 2773 // use the type of the handle (arg0) as a return type 2774 QualType ResourceTy = TheCall->getArg(0)->getType(); 2775 TheCall->setType(ResourceTy); 2776 break; 2777 } 2778 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: { 2779 ASTContext &AST = SemaRef.getASTContext(); 2780 if (SemaRef.checkArgCount(TheCall, 6) || 2781 CheckResourceHandle(&SemaRef, TheCall, 0) || 2782 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) || 2783 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.UnsignedIntTy) || 2784 CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.IntTy) || 2785 CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) || 2786 CheckArgTypeMatches(&SemaRef, TheCall->getArg(5), 2787 AST.getPointerType(AST.CharTy.withConst()))) 2788 return true; 2789 // use the type of the handle (arg0) as a return type 2790 QualType ResourceTy = TheCall->getArg(0)->getType(); 2791 TheCall->setType(ResourceTy); 2792 break; 2793 } 2794 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: { 2795 ASTContext &AST = SemaRef.getASTContext(); 2796 if (SemaRef.checkArgCount(TheCall, 6) || 2797 CheckResourceHandle(&SemaRef, TheCall, 0) || 2798 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) || 2799 CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.IntTy) || 2800 CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.UnsignedIntTy) || 2801 CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) || 2802 CheckArgTypeMatches(&SemaRef, TheCall->getArg(5), 2803 AST.getPointerType(AST.CharTy.withConst()))) 2804 return true; 2805 // use the type of the handle (arg0) as a return type 2806 QualType ResourceTy = TheCall->getArg(0)->getType(); 2807 
TheCall->setType(ResourceTy); 2808 break; 2809 } 2810 case Builtin::BI__builtin_hlsl_and: 2811 case Builtin::BI__builtin_hlsl_or: { 2812 if (SemaRef.checkArgCount(TheCall, 2)) 2813 return true; 2814 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0)) 2815 return true; 2816 if (CheckAllArgsHaveSameType(&SemaRef, TheCall)) 2817 return true; 2818 2819 ExprResult A = TheCall->getArg(0); 2820 QualType ArgTyA = A.get()->getType(); 2821 // return type is the same as the input type 2822 TheCall->setType(ArgTyA); 2823 break; 2824 } 2825 case Builtin::BI__builtin_hlsl_all: 2826 case Builtin::BI__builtin_hlsl_any: { 2827 if (SemaRef.checkArgCount(TheCall, 1)) 2828 return true; 2829 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) 2830 return true; 2831 break; 2832 } 2833 case Builtin::BI__builtin_hlsl_asdouble: { 2834 if (SemaRef.checkArgCount(TheCall, 2)) 2835 return true; 2836 if (CheckScalarOrVector( 2837 &SemaRef, TheCall, 2838 /*only check for uint*/ SemaRef.Context.UnsignedIntTy, 2839 /* arg index */ 0)) 2840 return true; 2841 if (CheckScalarOrVector( 2842 &SemaRef, TheCall, 2843 /*only check for uint*/ SemaRef.Context.UnsignedIntTy, 2844 /* arg index */ 1)) 2845 return true; 2846 if (CheckAllArgsHaveSameType(&SemaRef, TheCall)) 2847 return true; 2848 2849 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy); 2850 break; 2851 } 2852 case Builtin::BI__builtin_hlsl_elementwise_clamp: { 2853 if (SemaRef.BuiltinElementwiseTernaryMath( 2854 TheCall, /*ArgTyRestr=*/ 2855 Sema::EltwiseBuiltinArgTyRestriction::None)) 2856 return true; 2857 break; 2858 } 2859 case Builtin::BI__builtin_hlsl_dot: { 2860 // arg count is checked by BuiltinVectorToScalarMath 2861 if (SemaRef.BuiltinVectorToScalarMath(TheCall)) 2862 return true; 2863 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors)) 2864 return true; 2865 break; 2866 } 2867 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: 2868 case 
Builtin::BI__builtin_hlsl_elementwise_firstbitlow: { 2869 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2870 return true; 2871 2872 const Expr *Arg = TheCall->getArg(0); 2873 QualType ArgTy = Arg->getType(); 2874 QualType EltTy = ArgTy; 2875 2876 QualType ResTy = SemaRef.Context.UnsignedIntTy; 2877 2878 if (auto *VecTy = EltTy->getAs<VectorType>()) { 2879 EltTy = VecTy->getElementType(); 2880 ResTy = SemaRef.Context.getExtVectorType(ResTy, VecTy->getNumElements()); 2881 } 2882 2883 if (!EltTy->isIntegerType()) { 2884 Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type) 2885 << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1 2886 << /* no fp */ 0 << ArgTy; 2887 return true; 2888 } 2889 2890 TheCall->setType(ResTy); 2891 break; 2892 } 2893 case Builtin::BI__builtin_hlsl_select: { 2894 if (SemaRef.checkArgCount(TheCall, 3)) 2895 return true; 2896 if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0)) 2897 return true; 2898 QualType ArgTy = TheCall->getArg(0)->getType(); 2899 if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall)) 2900 return true; 2901 auto *VTy = ArgTy->getAs<VectorType>(); 2902 if (VTy && VTy->getElementType()->isBooleanType() && 2903 CheckVectorSelect(&SemaRef, TheCall)) 2904 return true; 2905 break; 2906 } 2907 case Builtin::BI__builtin_hlsl_elementwise_saturate: 2908 case Builtin::BI__builtin_hlsl_elementwise_rcp: { 2909 if (SemaRef.checkArgCount(TheCall, 1)) 2910 return true; 2911 if (!TheCall->getArg(0) 2912 ->getType() 2913 ->hasFloatingRepresentation()) // half or float or double 2914 return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), 2915 diag::err_builtin_invalid_arg_type) 2916 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0 2917 << /* fp */ 1 << TheCall->getArg(0)->getType(); 2918 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2919 return true; 2920 break; 2921 } 2922 case Builtin::BI__builtin_hlsl_elementwise_degrees: 2923 case 
Builtin::BI__builtin_hlsl_elementwise_radians: 2924 case Builtin::BI__builtin_hlsl_elementwise_rsqrt: 2925 case Builtin::BI__builtin_hlsl_elementwise_frac: { 2926 if (SemaRef.checkArgCount(TheCall, 1)) 2927 return true; 2928 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2929 CheckFloatOrHalfRepresentation)) 2930 return true; 2931 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2932 return true; 2933 break; 2934 } 2935 case Builtin::BI__builtin_hlsl_elementwise_isinf: { 2936 if (SemaRef.checkArgCount(TheCall, 1)) 2937 return true; 2938 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2939 CheckFloatOrHalfRepresentation)) 2940 return true; 2941 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2942 return true; 2943 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy); 2944 break; 2945 } 2946 case Builtin::BI__builtin_hlsl_lerp: { 2947 if (SemaRef.checkArgCount(TheCall, 3)) 2948 return true; 2949 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2950 CheckFloatOrHalfRepresentation)) 2951 return true; 2952 if (CheckAllArgsHaveSameType(&SemaRef, TheCall)) 2953 return true; 2954 if (SemaRef.BuiltinElementwiseTernaryMath(TheCall)) 2955 return true; 2956 break; 2957 } 2958 case Builtin::BI__builtin_hlsl_mad: { 2959 if (SemaRef.BuiltinElementwiseTernaryMath( 2960 TheCall, /*ArgTyRestr=*/ 2961 Sema::EltwiseBuiltinArgTyRestriction::None)) 2962 return true; 2963 break; 2964 } 2965 case Builtin::BI__builtin_hlsl_normalize: { 2966 if (SemaRef.checkArgCount(TheCall, 1)) 2967 return true; 2968 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2969 CheckFloatOrHalfRepresentation)) 2970 return true; 2971 ExprResult A = TheCall->getArg(0); 2972 QualType ArgTyA = A.get()->getType(); 2973 // return type is the same as the input type 2974 TheCall->setType(ArgTyA); 2975 break; 2976 } 2977 case Builtin::BI__builtin_hlsl_elementwise_sign: { 2978 if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) 2979 return true; 2980 if 
(CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2981 CheckFloatingOrIntRepresentation)) 2982 return true; 2983 SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy); 2984 break; 2985 } 2986 case Builtin::BI__builtin_hlsl_step: { 2987 if (SemaRef.checkArgCount(TheCall, 2)) 2988 return true; 2989 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2990 CheckFloatOrHalfRepresentation)) 2991 return true; 2992 2993 ExprResult A = TheCall->getArg(0); 2994 QualType ArgTyA = A.get()->getType(); 2995 // return type is the same as the input type 2996 TheCall->setType(ArgTyA); 2997 break; 2998 } 2999 case Builtin::BI__builtin_hlsl_wave_active_max: 3000 case Builtin::BI__builtin_hlsl_wave_active_sum: { 3001 if (SemaRef.checkArgCount(TheCall, 1)) 3002 return true; 3003 3004 // Ensure input expr type is a scalar/vector and the same as the return type 3005 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) 3006 return true; 3007 if (CheckWaveActive(&SemaRef, TheCall)) 3008 return true; 3009 ExprResult Expr = TheCall->getArg(0); 3010 QualType ArgTyExpr = Expr.get()->getType(); 3011 TheCall->setType(ArgTyExpr); 3012 break; 3013 } 3014 // Note these are llvm builtins that we want to catch invalid intrinsic 3015 // generation. Normal handling of these builitns will occur elsewhere. 
3016 case Builtin::BI__builtin_elementwise_bitreverse: { 3017 // does not include a check for number of arguments 3018 // because that is done previously 3019 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 3020 CheckUnsignedIntRepresentation)) 3021 return true; 3022 break; 3023 } 3024 case Builtin::BI__builtin_hlsl_wave_read_lane_at: { 3025 if (SemaRef.checkArgCount(TheCall, 2)) 3026 return true; 3027 3028 // Ensure index parameter type can be interpreted as a uint 3029 ExprResult Index = TheCall->getArg(1); 3030 QualType ArgTyIndex = Index.get()->getType(); 3031 if (!ArgTyIndex->isIntegerType()) { 3032 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(), 3033 diag::err_typecheck_convert_incompatible) 3034 << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0; 3035 return true; 3036 } 3037 3038 // Ensure input expr type is a scalar/vector and the same as the return type 3039 if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) 3040 return true; 3041 3042 ExprResult Expr = TheCall->getArg(0); 3043 QualType ArgTyExpr = Expr.get()->getType(); 3044 TheCall->setType(ArgTyExpr); 3045 break; 3046 } 3047 case Builtin::BI__builtin_hlsl_wave_get_lane_index: { 3048 if (SemaRef.checkArgCount(TheCall, 0)) 3049 return true; 3050 break; 3051 } 3052 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: { 3053 if (SemaRef.checkArgCount(TheCall, 3)) 3054 return true; 3055 3056 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) || 3057 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy, 3058 1) || 3059 CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy, 3060 2)) 3061 return true; 3062 3063 if (CheckModifiableLValue(&SemaRef, TheCall, 1) || 3064 CheckModifiableLValue(&SemaRef, TheCall, 2)) 3065 return true; 3066 break; 3067 } 3068 case Builtin::BI__builtin_hlsl_elementwise_clip: { 3069 if (SemaRef.checkArgCount(TheCall, 1)) 3070 return true; 3071 3072 if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.FloatTy, 0)) 
3073 return true; 3074 break; 3075 } 3076 case Builtin::BI__builtin_elementwise_acos: 3077 case Builtin::BI__builtin_elementwise_asin: 3078 case Builtin::BI__builtin_elementwise_atan: 3079 case Builtin::BI__builtin_elementwise_atan2: 3080 case Builtin::BI__builtin_elementwise_ceil: 3081 case Builtin::BI__builtin_elementwise_cos: 3082 case Builtin::BI__builtin_elementwise_cosh: 3083 case Builtin::BI__builtin_elementwise_exp: 3084 case Builtin::BI__builtin_elementwise_exp2: 3085 case Builtin::BI__builtin_elementwise_exp10: 3086 case Builtin::BI__builtin_elementwise_floor: 3087 case Builtin::BI__builtin_elementwise_fmod: 3088 case Builtin::BI__builtin_elementwise_log: 3089 case Builtin::BI__builtin_elementwise_log2: 3090 case Builtin::BI__builtin_elementwise_log10: 3091 case Builtin::BI__builtin_elementwise_pow: 3092 case Builtin::BI__builtin_elementwise_roundeven: 3093 case Builtin::BI__builtin_elementwise_sin: 3094 case Builtin::BI__builtin_elementwise_sinh: 3095 case Builtin::BI__builtin_elementwise_sqrt: 3096 case Builtin::BI__builtin_elementwise_tan: 3097 case Builtin::BI__builtin_elementwise_tanh: 3098 case Builtin::BI__builtin_elementwise_trunc: { 3099 if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 3100 CheckFloatOrHalfRepresentation)) 3101 return true; 3102 break; 3103 } 3104 case Builtin::BI__builtin_hlsl_buffer_update_counter: { 3105 auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool { 3106 return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV && 3107 ResTy->getAttrs().RawBuffer && ResTy->hasContainedType()); 3108 }; 3109 if (SemaRef.checkArgCount(TheCall, 2) || 3110 CheckResourceHandle(&SemaRef, TheCall, 0, checkResTy) || 3111 CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), 3112 SemaRef.getASTContext().IntTy)) 3113 return true; 3114 Expr *OffsetExpr = TheCall->getArg(1); 3115 std::optional<llvm::APSInt> Offset = 3116 OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext()); 3117 if (!Offset.has_value() || 
std::abs(Offset->getExtValue()) != 1) { 3118 SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(), 3119 diag::err_hlsl_expect_arg_const_int_one_or_neg_one) 3120 << 1; 3121 return true; 3122 } 3123 break; 3124 } 3125 } 3126 return false; 3127 } 3128 3129 static void BuildFlattenedTypeList(QualType BaseTy, 3130 llvm::SmallVectorImpl<QualType> &List) { 3131 llvm::SmallVector<QualType, 16> WorkList; 3132 WorkList.push_back(BaseTy); 3133 while (!WorkList.empty()) { 3134 QualType T = WorkList.pop_back_val(); 3135 T = T.getCanonicalType().getUnqualifiedType(); 3136 assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL"); 3137 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) { 3138 llvm::SmallVector<QualType, 16> ElementFields; 3139 // Generally I've avoided recursion in this algorithm, but arrays of 3140 // structs could be time-consuming to flatten and churn through on the 3141 // work list. Hopefully nesting arrays of structs containing arrays 3142 // of structs too many levels deep is unlikely. 3143 BuildFlattenedTypeList(AT->getElementType(), ElementFields); 3144 // Repeat the element's field list n times. 3145 for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct) 3146 llvm::append_range(List, ElementFields); 3147 continue; 3148 } 3149 // Vectors can only have element types that are builtin types, so this can 3150 // add directly to the list instead of to the WorkList. 3151 if (const auto *VT = dyn_cast<VectorType>(T)) { 3152 List.insert(List.end(), VT->getNumElements(), VT->getElementType()); 3153 continue; 3154 } 3155 if (const auto *RT = dyn_cast<RecordType>(T)) { 3156 const CXXRecordDecl *RD = RT->getAsCXXRecordDecl(); 3157 assert(RD && "HLSL record types should all be CXXRecordDecls!"); 3158 3159 if (RD->isStandardLayout()) 3160 RD = RD->getStandardLayoutBaseWithFields(); 3161 3162 // For types that we shouldn't decompose (unions and non-aggregates), just 3163 // add the type itself to the list. 
3164 if (RD->isUnion() || !RD->isAggregate()) { 3165 List.push_back(T); 3166 continue; 3167 } 3168 3169 llvm::SmallVector<QualType, 16> FieldTypes; 3170 for (const auto *FD : RD->fields()) 3171 FieldTypes.push_back(FD->getType()); 3172 // Reverse the newly added sub-range. 3173 std::reverse(FieldTypes.begin(), FieldTypes.end()); 3174 llvm::append_range(WorkList, FieldTypes); 3175 3176 // If this wasn't a standard layout type we may also have some base 3177 // classes to deal with. 3178 if (!RD->isStandardLayout()) { 3179 FieldTypes.clear(); 3180 for (const auto &Base : RD->bases()) 3181 FieldTypes.push_back(Base.getType()); 3182 std::reverse(FieldTypes.begin(), FieldTypes.end()); 3183 llvm::append_range(WorkList, FieldTypes); 3184 } 3185 continue; 3186 } 3187 List.push_back(T); 3188 } 3189 } 3190 3191 bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) { 3192 // null and array types are not allowed. 3193 if (QT.isNull() || QT->isArrayType()) 3194 return false; 3195 3196 // UDT types are not allowed 3197 if (QT->isRecordType()) 3198 return false; 3199 3200 if (QT->isBooleanType() || QT->isEnumeralType()) 3201 return false; 3202 3203 // the only other valid builtin types are scalars or vectors 3204 if (QT->isArithmeticType()) { 3205 if (SemaRef.Context.getTypeSize(QT) / 8 > 16) 3206 return false; 3207 return true; 3208 } 3209 3210 if (const VectorType *VT = QT->getAs<VectorType>()) { 3211 int ArraySize = VT->getNumElements(); 3212 3213 if (ArraySize > 4) 3214 return false; 3215 3216 QualType ElTy = VT->getElementType(); 3217 if (ElTy->isBooleanType()) 3218 return false; 3219 3220 if (SemaRef.Context.getTypeSize(QT) / 8 > 16) 3221 return false; 3222 return true; 3223 } 3224 3225 return false; 3226 } 3227 3228 bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const { 3229 if (T1.isNull() || T2.isNull()) 3230 return false; 3231 3232 T1 = T1.getCanonicalType().getUnqualifiedType(); 3233 T2 = T2.getCanonicalType().getUnqualifiedType(); 
3234 3235 // If both types are the same canonical type, they're obviously compatible. 3236 if (SemaRef.getASTContext().hasSameType(T1, T2)) 3237 return true; 3238 3239 llvm::SmallVector<QualType, 16> T1Types; 3240 BuildFlattenedTypeList(T1, T1Types); 3241 llvm::SmallVector<QualType, 16> T2Types; 3242 BuildFlattenedTypeList(T2, T2Types); 3243 3244 // Check the flattened type list 3245 return llvm::equal(T1Types, T2Types, 3246 [this](QualType LHS, QualType RHS) -> bool { 3247 return SemaRef.IsLayoutCompatible(LHS, RHS); 3248 }); 3249 } 3250 3251 bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New, 3252 FunctionDecl *Old) { 3253 if (New->getNumParams() != Old->getNumParams()) 3254 return true; 3255 3256 bool HadError = false; 3257 3258 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) { 3259 ParmVarDecl *NewParam = New->getParamDecl(i); 3260 ParmVarDecl *OldParam = Old->getParamDecl(i); 3261 3262 // HLSL parameter declarations for inout and out must match between 3263 // declarations. In HLSL inout and out are ambiguous at the call site, 3264 // but have different calling behavior, so you cannot overload a 3265 // method based on a difference between inout and out annotations. 3266 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>(); 3267 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0); 3268 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>(); 3269 unsigned OSpellingIdx = (ODAttr ? 
ODAttr->getSpellingListIndex() : 0); 3270 3271 if (NSpellingIdx != OSpellingIdx) { 3272 SemaRef.Diag(NewParam->getLocation(), 3273 diag::err_hlsl_param_qualifier_mismatch) 3274 << NDAttr << NewParam; 3275 SemaRef.Diag(OldParam->getLocation(), diag::note_previous_declaration_as) 3276 << ODAttr; 3277 HadError = true; 3278 } 3279 } 3280 return HadError; 3281 } 3282 3283 // Generally follows PerformScalarCast, with cases reordered for 3284 // clarity of what types are supported 3285 bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) { 3286 3287 if (!SrcTy->isScalarType() || !DestTy->isScalarType()) 3288 return false; 3289 3290 if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy)) 3291 return true; 3292 3293 switch (SrcTy->getScalarTypeKind()) { 3294 case Type::STK_Bool: // casting from bool is like casting from an integer 3295 case Type::STK_Integral: 3296 switch (DestTy->getScalarTypeKind()) { 3297 case Type::STK_Bool: 3298 case Type::STK_Integral: 3299 case Type::STK_Floating: 3300 return true; 3301 case Type::STK_CPointer: 3302 case Type::STK_ObjCObjectPointer: 3303 case Type::STK_BlockPointer: 3304 case Type::STK_MemberPointer: 3305 llvm_unreachable("HLSL doesn't support pointers."); 3306 case Type::STK_IntegralComplex: 3307 case Type::STK_FloatingComplex: 3308 llvm_unreachable("HLSL doesn't support complex types."); 3309 case Type::STK_FixedPoint: 3310 llvm_unreachable("HLSL doesn't support fixed point types."); 3311 } 3312 llvm_unreachable("Should have returned before this"); 3313 3314 case Type::STK_Floating: 3315 switch (DestTy->getScalarTypeKind()) { 3316 case Type::STK_Floating: 3317 case Type::STK_Bool: 3318 case Type::STK_Integral: 3319 return true; 3320 case Type::STK_FloatingComplex: 3321 case Type::STK_IntegralComplex: 3322 llvm_unreachable("HLSL doesn't support complex types."); 3323 case Type::STK_FixedPoint: 3324 llvm_unreachable("HLSL doesn't support fixed point types."); 3325 case Type::STK_CPointer: 3326 case 
Type::STK_ObjCObjectPointer: 3327 case Type::STK_BlockPointer: 3328 case Type::STK_MemberPointer: 3329 llvm_unreachable("HLSL doesn't support pointers."); 3330 } 3331 llvm_unreachable("Should have returned before this"); 3332 3333 case Type::STK_MemberPointer: 3334 case Type::STK_CPointer: 3335 case Type::STK_BlockPointer: 3336 case Type::STK_ObjCObjectPointer: 3337 llvm_unreachable("HLSL doesn't support pointers."); 3338 3339 case Type::STK_FixedPoint: 3340 llvm_unreachable("HLSL doesn't support fixed point types."); 3341 3342 case Type::STK_FloatingComplex: 3343 case Type::STK_IntegralComplex: 3344 llvm_unreachable("HLSL doesn't support complex types."); 3345 } 3346 3347 llvm_unreachable("Unhandled scalar cast"); 3348 } 3349 3350 // Detect if a type contains a bitfield. Will be removed when 3351 // bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast 3352 bool SemaHLSL::ContainsBitField(QualType BaseTy) { 3353 llvm::SmallVector<QualType, 16> WorkList; 3354 WorkList.push_back(BaseTy); 3355 while (!WorkList.empty()) { 3356 QualType T = WorkList.pop_back_val(); 3357 T = T.getCanonicalType().getUnqualifiedType(); 3358 // only check aggregate types 3359 if (const auto *AT = dyn_cast<ConstantArrayType>(T)) { 3360 WorkList.push_back(AT->getElementType()); 3361 continue; 3362 } 3363 if (const auto *RT = dyn_cast<RecordType>(T)) { 3364 const RecordDecl *RD = RT->getDecl(); 3365 if (RD->isUnion()) 3366 continue; 3367 3368 const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(RD); 3369 3370 if (CXXD && CXXD->isStandardLayout()) 3371 RD = CXXD->getStandardLayoutBaseWithFields(); 3372 3373 for (const auto *FD : RD->fields()) { 3374 if (FD->isBitField()) 3375 return true; 3376 WorkList.push_back(FD->getType()); 3377 } 3378 continue; 3379 } 3380 } 3381 return false; 3382 } 3383 3384 // Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the 3385 // Src is a scalar or a vector of length 1 3386 // Or if Dest is a vector and Src is a 
vector of length 1 3387 bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) { 3388 3389 QualType SrcTy = Src->getType(); 3390 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is 3391 // going to be a vector splat from a scalar. 3392 if ((SrcTy->isScalarType() && DestTy->isVectorType()) || 3393 DestTy->isScalarType()) 3394 return false; 3395 3396 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>(); 3397 3398 // Src isn't a scalar or a vector of length 1 3399 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1)) 3400 return false; 3401 3402 if (SrcVecTy) 3403 SrcTy = SrcVecTy->getElementType(); 3404 3405 if (ContainsBitField(DestTy)) 3406 return false; 3407 3408 llvm::SmallVector<QualType> DestTypes; 3409 BuildFlattenedTypeList(DestTy, DestTypes); 3410 3411 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) { 3412 if (DestTypes[I]->isUnionType()) 3413 return false; 3414 if (!CanPerformScalarCast(SrcTy, DestTypes[I])) 3415 return false; 3416 } 3417 return true; 3418 } 3419 3420 // Can we perform an HLSL Elementwise cast? 
3421 // TODO: update this code when matrices are added; see issue #88060 3422 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) { 3423 3424 // Don't handle casts where LHS and RHS are any combination of scalar/vector 3425 // There must be an aggregate somewhere 3426 QualType SrcTy = Src->getType(); 3427 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that 3428 return false; 3429 3430 if (SrcTy->isVectorType() && 3431 (DestTy->isScalarType() || DestTy->isVectorType())) 3432 return false; 3433 3434 if (ContainsBitField(DestTy) || ContainsBitField(SrcTy)) 3435 return false; 3436 3437 llvm::SmallVector<QualType> DestTypes; 3438 BuildFlattenedTypeList(DestTy, DestTypes); 3439 llvm::SmallVector<QualType> SrcTypes; 3440 BuildFlattenedTypeList(SrcTy, SrcTypes); 3441 3442 // Usually the size of SrcTypes must be greater than or equal to the size of 3443 // DestTypes. 3444 if (SrcTypes.size() < DestTypes.size()) 3445 return false; 3446 3447 unsigned SrcSize = SrcTypes.size(); 3448 unsigned DstSize = DestTypes.size(); 3449 unsigned I; 3450 for (I = 0; I < DstSize && I < SrcSize; I++) { 3451 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType()) 3452 return false; 3453 if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) { 3454 return false; 3455 } 3456 } 3457 3458 // check the rest of the source type for unions. 
3459 for (; I < SrcSize; I++) { 3460 if (SrcTypes[I]->isUnionType()) 3461 return false; 3462 } 3463 return true; 3464 } 3465 3466 ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) { 3467 assert(Param->hasAttr<HLSLParamModifierAttr>() && 3468 "We should not get here without a parameter modifier expression"); 3469 const auto *Attr = Param->getAttr<HLSLParamModifierAttr>(); 3470 if (Attr->getABI() == ParameterABI::Ordinary) 3471 return ExprResult(Arg); 3472 3473 bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut; 3474 if (!Arg->isLValue()) { 3475 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue) 3476 << Arg << (IsInOut ? 1 : 0); 3477 return ExprError(); 3478 } 3479 3480 ASTContext &Ctx = SemaRef.getASTContext(); 3481 3482 QualType Ty = Param->getType().getNonLValueExprType(Ctx); 3483 3484 // HLSL allows implicit conversions from scalars to vectors, but not the 3485 // inverse, so we need to disallow `inout` with scalar->vector or 3486 // scalar->matrix conversions. 3487 if (Arg->getType()->isScalarType() != Ty->isScalarType()) { 3488 SemaRef.Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_scalar_extension) 3489 << Arg << (IsInOut ? 1 : 0); 3490 return ExprError(); 3491 } 3492 3493 auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(), 3494 VK_LValue, OK_Ordinary, Arg); 3495 3496 // Parameters are initialized via copy initialization. This allows for 3497 // overload resolution of argument constructors. 3498 InitializedEntity Entity = 3499 InitializedEntity::InitializeParameter(Ctx, Ty, false); 3500 ExprResult Res = 3501 SemaRef.PerformCopyInitialization(Entity, Param->getBeginLoc(), ArgOpV); 3502 if (Res.isInvalid()) 3503 return ExprError(); 3504 Expr *Base = Res.get(); 3505 // After the cast, drop the reference type when creating the exprs. 
3506 Ty = Ty.getNonLValueExprType(Ctx); 3507 auto *OpV = new (Ctx) 3508 OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base); 3509 3510 // Writebacks are performed with `=` binary operator, which allows for 3511 // overload resolution on writeback result expressions. 3512 Res = SemaRef.ActOnBinOp(SemaRef.getCurScope(), Param->getBeginLoc(), 3513 tok::equal, ArgOpV, OpV); 3514 3515 if (Res.isInvalid()) 3516 return ExprError(); 3517 Expr *Writeback = Res.get(); 3518 auto *OutExpr = 3519 HLSLOutArgExpr::Create(Ctx, Ty, ArgOpV, OpV, Writeback, IsInOut); 3520 3521 return ExprResult(OutExpr); 3522 } 3523 3524 QualType SemaHLSL::getInoutParameterType(QualType Ty) { 3525 // If HLSL gains support for references, all the cites that use this will need 3526 // to be updated with semantic checking to produce errors for 3527 // pointers/references. 3528 assert(!Ty->isReferenceType() && 3529 "Pointer and reference types cannot be inout or out parameters"); 3530 Ty = SemaRef.getASTContext().getLValueReferenceType(Ty); 3531 Ty.addRestrict(); 3532 return Ty; 3533 } 3534 3535 static bool IsDefaultBufferConstantDecl(VarDecl *VD) { 3536 QualType QT = VD->getType(); 3537 return VD->getDeclContext()->isTranslationUnit() && 3538 QT.getAddressSpace() == LangAS::Default && 3539 VD->getStorageClass() != SC_Static && 3540 !VD->hasAttr<HLSLVkConstantIdAttr>() && 3541 !isInvalidConstantBufferLeafElementType(QT.getTypePtr()); 3542 } 3543 3544 void SemaHLSL::deduceAddressSpace(VarDecl *Decl) { 3545 // The variable already has an address space (groupshared for ex). 
3546 if (Decl->getType().hasAddressSpace()) 3547 return; 3548 3549 if (Decl->getType()->isDependentType()) 3550 return; 3551 3552 QualType Type = Decl->getType(); 3553 3554 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) { 3555 LangAS ImplAS = LangAS::hlsl_input; 3556 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS); 3557 Decl->setType(Type); 3558 return; 3559 } 3560 3561 if (Type->isSamplerT() || Type->isVoidType()) 3562 return; 3563 3564 // Resource handles. 3565 if (isResourceRecordTypeOrArrayOf(Type->getUnqualifiedDesugaredType())) 3566 return; 3567 3568 // Only static globals belong to the Private address space. 3569 // Non-static globals belongs to the cbuffer. 3570 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember()) 3571 return; 3572 3573 LangAS ImplAS = LangAS::hlsl_private; 3574 Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS); 3575 Decl->setType(Type); 3576 } 3577 3578 void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) { 3579 if (VD->hasGlobalStorage()) { 3580 // make sure the declaration has a complete type 3581 if (SemaRef.RequireCompleteType( 3582 VD->getLocation(), 3583 SemaRef.getASTContext().getBaseElementType(VD->getType()), 3584 diag::err_typecheck_decl_incomplete_type)) { 3585 VD->setInvalidDecl(); 3586 deduceAddressSpace(VD); 3587 return; 3588 } 3589 3590 // Global variables outside a cbuffer block that are not a resource, static, 3591 // groupshared, or an empty array or struct belong to the default constant 3592 // buffer $Globals (to be created at the end of the translation unit). 
3593 if (IsDefaultBufferConstantDecl(VD)) { 3594 // update address space to hlsl_constant 3595 QualType NewTy = getASTContext().getAddrSpaceQualType( 3596 VD->getType(), LangAS::hlsl_constant); 3597 VD->setType(NewTy); 3598 DefaultCBufferDecls.push_back(VD); 3599 } 3600 3601 // find all resources bindings on decl 3602 if (VD->getType()->isHLSLIntangibleType()) 3603 collectResourceBindingsOnVarDecl(VD); 3604 3605 const Type *VarType = VD->getType().getTypePtr(); 3606 while (VarType->isArrayType()) 3607 VarType = VarType->getArrayElementTypeNoTypeQual(); 3608 if (VarType->isHLSLResourceRecord() || 3609 VD->hasAttr<HLSLVkConstantIdAttr>()) { 3610 // Make the variable for resources static. The global externally visible 3611 // storage is accessed through the handle, which is a member. The variable 3612 // itself is not externally visible. 3613 VD->setStorageClass(StorageClass::SC_Static); 3614 } 3615 3616 // process explicit bindings 3617 processExplicitBindingsOnDecl(VD); 3618 } 3619 3620 deduceAddressSpace(VD); 3621 } 3622 3623 static bool initVarDeclWithCtor(Sema &S, VarDecl *VD, 3624 MutableArrayRef<Expr *> Args) { 3625 InitializedEntity Entity = InitializedEntity::InitializeVariable(VD); 3626 InitializationKind Kind = InitializationKind::CreateDirect( 3627 VD->getLocation(), SourceLocation(), SourceLocation()); 3628 3629 InitializationSequence InitSeq(S, Entity, Kind, Args); 3630 if (InitSeq.Failed()) 3631 return false; 3632 3633 ExprResult Init = InitSeq.Perform(S, Entity, Kind, Args); 3634 if (!Init.get()) 3635 return false; 3636 3637 VD->setInit(S.MaybeCreateExprWithCleanups(Init.get())); 3638 VD->setInitStyle(VarDecl::CallInit); 3639 S.CheckCompleteVariableDeclaration(VD); 3640 return true; 3641 } 3642 3643 bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) { 3644 std::optional<uint32_t> RegisterSlot; 3645 uint32_t SpaceNo = 0; 3646 HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>(); 3647 if (RBA) { 3648 if (RBA->hasRegisterSlot()) 3649 
RegisterSlot = RBA->getSlotNumber(); 3650 SpaceNo = RBA->getSpaceNumber(); 3651 } 3652 3653 ASTContext &AST = SemaRef.getASTContext(); 3654 uint64_t UIntTySize = AST.getTypeSize(AST.UnsignedIntTy); 3655 uint64_t IntTySize = AST.getTypeSize(AST.IntTy); 3656 IntegerLiteral *RangeSize = IntegerLiteral::Create( 3657 AST, llvm::APInt(IntTySize, 1), AST.IntTy, SourceLocation()); 3658 IntegerLiteral *Index = IntegerLiteral::Create( 3659 AST, llvm::APInt(UIntTySize, 0), AST.UnsignedIntTy, SourceLocation()); 3660 IntegerLiteral *Space = 3661 IntegerLiteral::Create(AST, llvm::APInt(UIntTySize, SpaceNo), 3662 AST.UnsignedIntTy, SourceLocation()); 3663 StringRef VarName = VD->getName(); 3664 StringLiteral *Name = StringLiteral::Create( 3665 AST, VarName, StringLiteralKind::Ordinary, false, 3666 AST.getStringLiteralArrayType(AST.CharTy.withConst(), VarName.size()), 3667 SourceLocation()); 3668 3669 // resource with explicit binding 3670 if (RegisterSlot.has_value()) { 3671 IntegerLiteral *RegSlot = IntegerLiteral::Create( 3672 AST, llvm::APInt(UIntTySize, RegisterSlot.value()), AST.UnsignedIntTy, 3673 SourceLocation()); 3674 Expr *Args[] = {RegSlot, Space, RangeSize, Index, Name}; 3675 return initVarDeclWithCtor(SemaRef, VD, Args); 3676 } 3677 3678 // resource with implicit binding 3679 IntegerLiteral *OrderId = IntegerLiteral::Create( 3680 AST, llvm::APInt(UIntTySize, getNextImplicitBindingOrderID()), 3681 AST.UnsignedIntTy, SourceLocation()); 3682 Expr *Args[] = {Space, RangeSize, Index, OrderId, Name}; 3683 return initVarDeclWithCtor(SemaRef, VD, Args); 3684 } 3685 3686 // Returns true if the initialization has been handled. 3687 // Returns false to use default initialization. 3688 bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) { 3689 // Objects in the hlsl_constant address space are initialized 3690 // externally, so don't synthesize an implicit initializer. 
3691 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant) 3692 return true; 3693 3694 // Initialize resources 3695 if (!isResourceRecordTypeOrArrayOf(VD)) 3696 return false; 3697 3698 // FIXME: We currectly support only simple resources - no arrays of resources 3699 // or resources in user defined structs. 3700 // (llvm/llvm-project#133835, llvm/llvm-project#133837) 3701 // Initialize resources at the global scope 3702 if (VD->hasGlobalStorage() && VD->getType()->isHLSLResourceRecord()) 3703 return initGlobalResourceDecl(VD); 3704 3705 return false; 3706 } 3707 3708 // Walks though the global variable declaration, collects all resource binding 3709 // requirements and adds them to Bindings 3710 void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) { 3711 assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() && 3712 "expected global variable that contains HLSL resource"); 3713 3714 // Cbuffers and Tbuffers are HLSLBufferDecl types 3715 if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) { 3716 Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer() 3717 ? 
ResourceClass::CBuffer 3718 : ResourceClass::SRV); 3719 return; 3720 } 3721 3722 // Unwrap arrays 3723 // FIXME: Calculate array size while unwrapping 3724 const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); 3725 while (Ty->isConstantArrayType()) { 3726 const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); 3727 Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); 3728 } 3729 3730 // Resource (or array of resources) 3731 if (const HLSLAttributedResourceType *AttrResType = 3732 HLSLAttributedResourceType::findHandleTypeOnResource(Ty)) { 3733 Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass); 3734 return; 3735 } 3736 3737 // User defined record type 3738 if (const RecordType *RT = dyn_cast<RecordType>(Ty)) 3739 collectResourceBindingsOnUserRecordDecl(VD, RT); 3740 } 3741 3742 // Walks though the explicit resource binding attributes on the declaration, 3743 // and makes sure there is a resource that matched the binding and updates 3744 // DeclBindingInfoLists 3745 void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) { 3746 assert(VD->hasGlobalStorage() && "expected global variable"); 3747 3748 bool HasBinding = false; 3749 for (Attr *A : VD->attrs()) { 3750 HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); 3751 if (!RBA || !RBA->hasRegisterSlot()) 3752 continue; 3753 HasBinding = true; 3754 3755 RegisterType RT = RBA->getRegisterType(); 3756 assert(RT != RegisterType::I && "invalid or obsolete register type should " 3757 "never have an attribute created"); 3758 3759 if (RT == RegisterType::C) { 3760 if (Bindings.hasBindingInfoForDecl(VD)) 3761 SemaRef.Diag(VD->getLocation(), 3762 diag::warn_hlsl_user_defined_type_missing_member) 3763 << static_cast<int>(RT); 3764 continue; 3765 } 3766 3767 // Find DeclBindingInfo for this binding and update it, or report error 3768 // if it does not exist (user type does to contain resources with the 3769 // expected resource class). 
3770 ResourceClass RC = getResourceClass(RT); 3771 if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) { 3772 // update binding info 3773 BI->setBindingAttribute(RBA, BindingType::Explicit); 3774 } else { 3775 SemaRef.Diag(VD->getLocation(), 3776 diag::warn_hlsl_user_defined_type_missing_member) 3777 << static_cast<int>(RT); 3778 } 3779 } 3780 3781 if (!HasBinding && isResourceRecordTypeOrArrayOf(VD)) 3782 SemaRef.Diag(VD->getLocation(), diag::warn_hlsl_implicit_binding); 3783 } 3784 namespace { 3785 class InitListTransformer { 3786 Sema &S; 3787 ASTContext &Ctx; 3788 QualType InitTy; 3789 QualType *DstIt = nullptr; 3790 Expr **ArgIt = nullptr; 3791 // Is wrapping the destination type iterator required? This is only used for 3792 // incomplete array types where we loop over the destination type since we 3793 // don't know the full number of elements from the declaration. 3794 bool Wrap; 3795 3796 bool castInitializer(Expr *E) { 3797 assert(DstIt && "This should always be something!"); 3798 if (DstIt == DestTypes.end()) { 3799 if (!Wrap) { 3800 ArgExprs.push_back(E); 3801 // This is odd, but it isn't technically a failure due to conversion, we 3802 // handle mismatched counts of arguments differently. 3803 return true; 3804 } 3805 DstIt = DestTypes.begin(); 3806 } 3807 InitializedEntity Entity = InitializedEntity::InitializeParameter( 3808 Ctx, *DstIt, /* Consumed (ObjC) */ false); 3809 ExprResult Res = S.PerformCopyInitialization(Entity, E->getBeginLoc(), E); 3810 if (Res.isInvalid()) 3811 return false; 3812 Expr *Init = Res.get(); 3813 ArgExprs.push_back(Init); 3814 DstIt++; 3815 return true; 3816 } 3817 3818 bool buildInitializerListImpl(Expr *E) { 3819 // If this is an initialization list, traverse the sub initializers. 
3820 if (auto *Init = dyn_cast<InitListExpr>(E)) { 3821 for (auto *SubInit : Init->inits()) 3822 if (!buildInitializerListImpl(SubInit)) 3823 return false; 3824 return true; 3825 } 3826 3827 // If this is a scalar type, just enqueue the expression. 3828 QualType Ty = E->getType(); 3829 3830 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType())) 3831 return castInitializer(E); 3832 3833 if (auto *VecTy = Ty->getAs<VectorType>()) { 3834 uint64_t Size = VecTy->getNumElements(); 3835 3836 QualType SizeTy = Ctx.getSizeType(); 3837 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy); 3838 for (uint64_t I = 0; I < Size; ++I) { 3839 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I), 3840 SizeTy, SourceLocation()); 3841 3842 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr( 3843 E, E->getBeginLoc(), Idx, E->getEndLoc()); 3844 if (ElExpr.isInvalid()) 3845 return false; 3846 if (!castInitializer(ElExpr.get())) 3847 return false; 3848 } 3849 return true; 3850 } 3851 3852 if (auto *ArrTy = dyn_cast<ConstantArrayType>(Ty.getTypePtr())) { 3853 uint64_t Size = ArrTy->getZExtSize(); 3854 QualType SizeTy = Ctx.getSizeType(); 3855 uint64_t SizeTySize = Ctx.getTypeSize(SizeTy); 3856 for (uint64_t I = 0; I < Size; ++I) { 3857 auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(SizeTySize, I), 3858 SizeTy, SourceLocation()); 3859 ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr( 3860 E, E->getBeginLoc(), Idx, E->getEndLoc()); 3861 if (ElExpr.isInvalid()) 3862 return false; 3863 if (!buildInitializerListImpl(ElExpr.get())) 3864 return false; 3865 } 3866 return true; 3867 } 3868 3869 if (auto *RTy = Ty->getAs<RecordType>()) { 3870 llvm::SmallVector<const RecordType *> RecordTypes; 3871 RecordTypes.push_back(RTy); 3872 while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) { 3873 CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl(); 3874 assert(D->getNumBases() == 1 && 3875 "HLSL doesn't support multiple inheritance"); 3876 
RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>()); 3877 } 3878 while (!RecordTypes.empty()) { 3879 const RecordType *RT = RecordTypes.pop_back_val(); 3880 for (auto *FD : RT->getDecl()->fields()) { 3881 DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess()); 3882 DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc()); 3883 ExprResult Res = S.BuildFieldReferenceExpr( 3884 E, false, E->getBeginLoc(), CXXScopeSpec(), FD, Found, NameInfo); 3885 if (Res.isInvalid()) 3886 return false; 3887 if (!buildInitializerListImpl(Res.get())) 3888 return false; 3889 } 3890 } 3891 } 3892 return true; 3893 } 3894 3895 Expr *generateInitListsImpl(QualType Ty) { 3896 assert(ArgIt != ArgExprs.end() && "Something is off in iteration!"); 3897 if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType())) 3898 return *(ArgIt++); 3899 3900 llvm::SmallVector<Expr *> Inits; 3901 assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL"); 3902 Ty = Ty.getDesugaredType(Ctx); 3903 if (Ty->isVectorType() || Ty->isConstantArrayType()) { 3904 QualType ElTy; 3905 uint64_t Size = 0; 3906 if (auto *ATy = Ty->getAs<VectorType>()) { 3907 ElTy = ATy->getElementType(); 3908 Size = ATy->getNumElements(); 3909 } else { 3910 auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr()); 3911 ElTy = VTy->getElementType(); 3912 Size = VTy->getZExtSize(); 3913 } 3914 for (uint64_t I = 0; I < Size; ++I) 3915 Inits.push_back(generateInitListsImpl(ElTy)); 3916 } 3917 if (auto *RTy = Ty->getAs<RecordType>()) { 3918 llvm::SmallVector<const RecordType *> RecordTypes; 3919 RecordTypes.push_back(RTy); 3920 while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) { 3921 CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl(); 3922 assert(D->getNumBases() == 1 && 3923 "HLSL doesn't support multiple inheritance"); 3924 RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>()); 3925 } 3926 while (!RecordTypes.empty()) { 3927 const 
RecordType *RT = RecordTypes.pop_back_val(); 3928 for (auto *FD : RT->getDecl()->fields()) { 3929 Inits.push_back(generateInitListsImpl(FD->getType())); 3930 } 3931 } 3932 } 3933 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(), 3934 Inits, Inits.back()->getEndLoc()); 3935 NewInit->setType(Ty); 3936 return NewInit; 3937 } 3938 3939 public: 3940 llvm::SmallVector<QualType, 16> DestTypes; 3941 llvm::SmallVector<Expr *, 16> ArgExprs; 3942 InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity) 3943 : S(SemaRef), Ctx(SemaRef.getASTContext()), 3944 Wrap(Entity.getType()->isIncompleteArrayType()) { 3945 InitTy = Entity.getType().getNonReferenceType(); 3946 // When we're generating initializer lists for incomplete array types we 3947 // need to wrap around both when building the initializers and when 3948 // generating the final initializer lists. 3949 if (Wrap) { 3950 assert(InitTy->isIncompleteArrayType()); 3951 const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(InitTy); 3952 InitTy = IAT->getElementType(); 3953 } 3954 BuildFlattenedTypeList(InitTy, DestTypes); 3955 DstIt = DestTypes.begin(); 3956 } 3957 3958 bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); } 3959 3960 Expr *generateInitLists() { 3961 assert(!ArgExprs.empty() && 3962 "Call buildInitializerList to generate argument expressions."); 3963 ArgIt = ArgExprs.begin(); 3964 if (!Wrap) 3965 return generateInitListsImpl(InitTy); 3966 llvm::SmallVector<Expr *> Inits; 3967 while (ArgIt != ArgExprs.end()) 3968 Inits.push_back(generateInitListsImpl(InitTy)); 3969 3970 auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(), 3971 Inits, Inits.back()->getEndLoc()); 3972 llvm::APInt ArySize(64, Inits.size()); 3973 NewInit->setType(Ctx.getConstantArrayType(InitTy, ArySize, nullptr, 3974 ArraySizeModifier::Normal, 0)); 3975 return NewInit; 3976 } 3977 }; 3978 } // namespace 3979 3980 bool SemaHLSL::transformInitList(const 
InitializedEntity &Entity, 3981 InitListExpr *Init) { 3982 // If the initializer is a scalar, just return it. 3983 if (Init->getType()->isScalarType()) 3984 return true; 3985 ASTContext &Ctx = SemaRef.getASTContext(); 3986 InitListTransformer ILT(SemaRef, Entity); 3987 3988 for (unsigned I = 0; I < Init->getNumInits(); ++I) { 3989 Expr *E = Init->getInit(I); 3990 if (E->HasSideEffects(Ctx)) { 3991 QualType Ty = E->getType(); 3992 if (Ty->isRecordType()) 3993 E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue()); 3994 E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(), 3995 E->getObjectKind(), E); 3996 Init->setInit(I, E); 3997 } 3998 if (!ILT.buildInitializerList(E)) 3999 return false; 4000 } 4001 size_t ExpectedSize = ILT.DestTypes.size(); 4002 size_t ActualSize = ILT.ArgExprs.size(); 4003 // For incomplete arrays it is completely arbitrary to choose whether we think 4004 // the user intended fewer or more elements. This implementation assumes that 4005 // the user intended more, and errors that there are too few initializers to 4006 // complete the final element. 4007 if (Entity.getType()->isIncompleteArrayType()) 4008 ExpectedSize = 4009 ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize; 4010 4011 // An initializer list might be attempting to initialize a reference or 4012 // rvalue-reference. When checking the initializer we should look through 4013 // the reference. 4014 QualType InitTy = Entity.getType().getNonReferenceType(); 4015 if (InitTy.hasAddressSpace()) 4016 InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(InitTy); 4017 if (ExpectedSize != ActualSize) { 4018 int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0; 4019 SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers) 4020 << TooManyOrFew << InitTy << ExpectedSize << ActualSize; 4021 return false; 4022 } 4023 4024 // generateInitListsImpl will always return an InitListExpr here, because the 4025 // scalar case is handled above. 
4026 auto *NewInit = cast<InitListExpr>(ILT.generateInitLists()); 4027 Init->resizeInits(Ctx, NewInit->getNumInits()); 4028 for (unsigned I = 0; I < NewInit->getNumInits(); ++I) 4029 Init->updateInit(Ctx, I, NewInit->getInit(I)); 4030 return true; 4031 } 4032 4033 bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) { 4034 const HLSLVkConstantIdAttr *ConstIdAttr = 4035 VDecl->getAttr<HLSLVkConstantIdAttr>(); 4036 if (!ConstIdAttr) 4037 return true; 4038 4039 ASTContext &Context = SemaRef.getASTContext(); 4040 4041 APValue InitValue; 4042 if (!Init->isCXX11ConstantExpr(Context, &InitValue)) { 4043 Diag(VDecl->getLocation(), diag::err_specialization_const); 4044 VDecl->setInvalidDecl(); 4045 return false; 4046 } 4047 4048 Builtin::ID BID = 4049 getSpecConstBuiltinId(VDecl->getType()->getUnqualifiedDesugaredType()); 4050 4051 // Argument 1: The ID from the attribute 4052 int ConstantID = ConstIdAttr->getId(); 4053 llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID); 4054 Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy, 4055 ConstIdAttr->getLocation()); 4056 4057 SmallVector<Expr *, 2> Args = {IdExpr, Init}; 4058 Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args); 4059 if (C->getType()->getCanonicalTypeUnqualified() != 4060 VDecl->getType()->getCanonicalTypeUnqualified()) { 4061 C = SemaRef 4062 .BuildCStyleCastExpr(SourceLocation(), 4063 Context.getTrivialTypeSourceInfo( 4064 Init->getType(), Init->getExprLoc()), 4065 SourceLocation(), C) 4066 .get(); 4067 } 4068 Init = C; 4069 return true; 4070 } 4071