xref: /freebsd/contrib/llvm-project/clang/lib/Sema/SemaHLSL.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for HLSL constructs.
9 //===----------------------------------------------------------------------===//
10 
11 #include "clang/Sema/SemaHLSL.h"
12 #include "clang/AST/Decl.h"
13 #include "clang/AST/Expr.h"
14 #include "clang/AST/RecursiveASTVisitor.h"
15 #include "clang/Basic/DiagnosticSema.h"
16 #include "clang/Basic/LLVM.h"
17 #include "clang/Basic/TargetInfo.h"
18 #include "clang/Sema/ParsedAttr.h"
19 #include "clang/Sema/Sema.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include "llvm/TargetParser/Triple.h"
26 #include <iterator>
27 
28 using namespace clang;
29 
SemaHLSL(Sema & S)30 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
31 
ActOnStartBuffer(Scope * BufferScope,bool CBuffer,SourceLocation KwLoc,IdentifierInfo * Ident,SourceLocation IdentLoc,SourceLocation LBrace)32 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
33                                  SourceLocation KwLoc, IdentifierInfo *Ident,
34                                  SourceLocation IdentLoc,
35                                  SourceLocation LBrace) {
36   // For anonymous namespace, take the location of the left brace.
37   DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
38   HLSLBufferDecl *Result = HLSLBufferDecl::Create(
39       getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);
40 
41   SemaRef.PushOnScopeChains(Result, BufferScope);
42   SemaRef.PushDeclContext(BufferScope, Result);
43 
44   return Result;
45 }
46 
47 // Calculate the size of a legacy cbuffer type based on
48 // https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
calculateLegacyCbufferSize(const ASTContext & Context,QualType T)49 static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
50                                            QualType T) {
51   unsigned Size = 0;
52   constexpr unsigned CBufferAlign = 128;
53   if (const RecordType *RT = T->getAs<RecordType>()) {
54     const RecordDecl *RD = RT->getDecl();
55     for (const FieldDecl *Field : RD->fields()) {
56       QualType Ty = Field->getType();
57       unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
58       unsigned FieldAlign = 32;
59       if (Ty->isAggregateType())
60         FieldAlign = CBufferAlign;
61       Size = llvm::alignTo(Size, FieldAlign);
62       Size += FieldSize;
63     }
64   } else if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
65     if (unsigned ElementCount = AT->getSize().getZExtValue()) {
66       unsigned ElementSize =
67           calculateLegacyCbufferSize(Context, AT->getElementType());
68       unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
69       Size = AlignedElementSize * (ElementCount - 1) + ElementSize;
70     }
71   } else if (const VectorType *VT = T->getAs<VectorType>()) {
72     unsigned ElementCount = VT->getNumElements();
73     unsigned ElementSize =
74         calculateLegacyCbufferSize(Context, VT->getElementType());
75     Size = ElementSize * ElementCount;
76   } else {
77     Size = Context.getTypeSize(T);
78   }
79   return Size;
80 }
81 
ActOnFinishBuffer(Decl * Dcl,SourceLocation RBrace)82 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
83   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
84   BufDecl->setRBraceLoc(RBrace);
85 
86   // Validate packoffset.
87   llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
88   bool HasPackOffset = false;
89   bool HasNonPackOffset = false;
90   for (auto *Field : BufDecl->decls()) {
91     VarDecl *Var = dyn_cast<VarDecl>(Field);
92     if (!Var)
93       continue;
94     if (Field->hasAttr<HLSLPackOffsetAttr>()) {
95       PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
96       HasPackOffset = true;
97     } else {
98       HasNonPackOffset = true;
99     }
100   }
101 
102   if (HasPackOffset && HasNonPackOffset)
103     Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);
104 
105   if (HasPackOffset) {
106     ASTContext &Context = getASTContext();
107     // Make sure no overlap in packoffset.
108     // Sort PackOffsetVec by offset.
109     std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
110               [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
111                  const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
112                 return LHS.second->getOffset() < RHS.second->getOffset();
113               });
114 
115     for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
116       VarDecl *Var = PackOffsetVec[i].first;
117       HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
118       unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
119       unsigned Begin = Attr->getOffset() * 32;
120       unsigned End = Begin + Size;
121       unsigned NextBegin = PackOffsetVec[i + 1].second->getOffset() * 32;
122       if (End > NextBegin) {
123         VarDecl *NextVar = PackOffsetVec[i + 1].first;
124         Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
125             << NextVar << Var;
126       }
127     }
128   }
129 
130   SemaRef.PopDeclContext();
131 }
132 
mergeNumThreadsAttr(Decl * D,const AttributeCommonInfo & AL,int X,int Y,int Z)133 HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
134                                                   const AttributeCommonInfo &AL,
135                                                   int X, int Y, int Z) {
136   if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
137     if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
138       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
139       Diag(AL.getLoc(), diag::note_conflicting_attribute);
140     }
141     return nullptr;
142   }
143   return ::new (getASTContext())
144       HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
145 }
146 
147 HLSLShaderAttr *
mergeShaderAttr(Decl * D,const AttributeCommonInfo & AL,llvm::Triple::EnvironmentType ShaderType)148 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
149                           llvm::Triple::EnvironmentType ShaderType) {
150   if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
151     if (NT->getType() != ShaderType) {
152       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
153       Diag(AL.getLoc(), diag::note_conflicting_attribute);
154     }
155     return nullptr;
156   }
157   return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
158 }
159 
160 HLSLParamModifierAttr *
mergeParamModifierAttr(Decl * D,const AttributeCommonInfo & AL,HLSLParamModifierAttr::Spelling Spelling)161 SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
162                                  HLSLParamModifierAttr::Spelling Spelling) {
163   // We can only merge an `in` attribute with an `out` attribute. All other
164   // combinations of duplicated attributes are ill-formed.
165   if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
166     if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
167         (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
168       D->dropAttr<HLSLParamModifierAttr>();
169       SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
170       return HLSLParamModifierAttr::Create(
171           getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
172           HLSLParamModifierAttr::Keyword_inout);
173     }
174     Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
175     Diag(PA->getLocation(), diag::note_conflicting_attribute);
176     return nullptr;
177   }
178   return HLSLParamModifierAttr::Create(getASTContext(), AL);
179 }
180 
ActOnTopLevelFunction(FunctionDecl * FD)181 void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
182   auto &TargetInfo = getASTContext().getTargetInfo();
183 
184   if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
185     return;
186 
187   llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
188   if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
189     if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
190       // The entry point is already annotated - check that it matches the
191       // triple.
192       if (Shader->getType() != Env) {
193         Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
194             << Shader;
195         FD->setInvalidDecl();
196       }
197     } else {
198       // Implicitly add the shader attribute if the entry function isn't
199       // explicitly annotated.
200       FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
201                                                  FD->getBeginLoc()));
202     }
203   } else {
204     switch (Env) {
205     case llvm::Triple::UnknownEnvironment:
206     case llvm::Triple::Library:
207       break;
208     default:
209       llvm_unreachable("Unhandled environment in triple");
210     }
211   }
212 }
213 
CheckEntryPoint(FunctionDecl * FD)214 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
215   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
216   assert(ShaderAttr && "Entry point has no shader attribute");
217   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
218 
219   switch (ST) {
220   case llvm::Triple::Pixel:
221   case llvm::Triple::Vertex:
222   case llvm::Triple::Geometry:
223   case llvm::Triple::Hull:
224   case llvm::Triple::Domain:
225   case llvm::Triple::RayGeneration:
226   case llvm::Triple::Intersection:
227   case llvm::Triple::AnyHit:
228   case llvm::Triple::ClosestHit:
229   case llvm::Triple::Miss:
230   case llvm::Triple::Callable:
231     if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
232       DiagnoseAttrStageMismatch(NT, ST,
233                                 {llvm::Triple::Compute,
234                                  llvm::Triple::Amplification,
235                                  llvm::Triple::Mesh});
236       FD->setInvalidDecl();
237     }
238     break;
239 
240   case llvm::Triple::Compute:
241   case llvm::Triple::Amplification:
242   case llvm::Triple::Mesh:
243     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
244       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
245           << llvm::Triple::getEnvironmentTypeName(ST);
246       FD->setInvalidDecl();
247     }
248     break;
249   default:
250     llvm_unreachable("Unhandled environment in triple");
251   }
252 
253   for (ParmVarDecl *Param : FD->parameters()) {
254     if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
255       CheckSemanticAnnotation(FD, Param, AnnotationAttr);
256     } else {
257       // FIXME: Handle struct parameters where annotations are on struct fields.
258       // See: https://github.com/llvm/llvm-project/issues/57875
259       Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
260       Diag(Param->getLocation(), diag::note_previous_decl) << Param;
261       FD->setInvalidDecl();
262     }
263   }
264   // FIXME: Verify return type semantic annotation.
265 }
266 
CheckSemanticAnnotation(FunctionDecl * EntryPoint,const Decl * Param,const HLSLAnnotationAttr * AnnotationAttr)267 void SemaHLSL::CheckSemanticAnnotation(
268     FunctionDecl *EntryPoint, const Decl *Param,
269     const HLSLAnnotationAttr *AnnotationAttr) {
270   auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
271   assert(ShaderAttr && "Entry point has no shader attribute");
272   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
273 
274   switch (AnnotationAttr->getKind()) {
275   case attr::HLSLSV_DispatchThreadID:
276   case attr::HLSLSV_GroupIndex:
277     if (ST == llvm::Triple::Compute)
278       return;
279     DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
280     break;
281   default:
282     llvm_unreachable("Unknown HLSLAnnotationAttr");
283   }
284 }
285 
DiagnoseAttrStageMismatch(const Attr * A,llvm::Triple::EnvironmentType Stage,std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages)286 void SemaHLSL::DiagnoseAttrStageMismatch(
287     const Attr *A, llvm::Triple::EnvironmentType Stage,
288     std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
289   SmallVector<StringRef, 8> StageStrings;
290   llvm::transform(AllowedStages, std::back_inserter(StageStrings),
291                   [](llvm::Triple::EnvironmentType ST) {
292                     return StringRef(
293                         HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
294                   });
295   Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
296       << A << llvm::Triple::getEnvironmentTypeName(Stage)
297       << (AllowedStages.size() != 1) << join(StageStrings, ", ");
298 }
299 
handleNumThreadsAttr(Decl * D,const ParsedAttr & AL)300 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
301   llvm::VersionTuple SMVersion =
302       getASTContext().getTargetInfo().getTriple().getOSVersion();
303   uint32_t ZMax = 1024;
304   uint32_t ThreadMax = 1024;
305   if (SMVersion.getMajor() <= 4) {
306     ZMax = 1;
307     ThreadMax = 768;
308   } else if (SMVersion.getMajor() == 5) {
309     ZMax = 64;
310     ThreadMax = 1024;
311   }
312 
313   uint32_t X;
314   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))
315     return;
316   if (X > 1024) {
317     Diag(AL.getArgAsExpr(0)->getExprLoc(),
318          diag::err_hlsl_numthreads_argument_oor)
319         << 0 << 1024;
320     return;
321   }
322   uint32_t Y;
323   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))
324     return;
325   if (Y > 1024) {
326     Diag(AL.getArgAsExpr(1)->getExprLoc(),
327          diag::err_hlsl_numthreads_argument_oor)
328         << 1 << 1024;
329     return;
330   }
331   uint32_t Z;
332   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))
333     return;
334   if (Z > ZMax) {
335     SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),
336                  diag::err_hlsl_numthreads_argument_oor)
337         << 2 << ZMax;
338     return;
339   }
340 
341   if (X * Y * Z > ThreadMax) {
342     Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
343     return;
344   }
345 
346   HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
347   if (NewAttr)
348     D->addAttr(NewAttr);
349 }
350 
isLegalTypeForHLSLSV_DispatchThreadID(QualType T)351 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
352   if (!T->hasUnsignedIntegerRepresentation())
353     return false;
354   if (const auto *VT = T->getAs<VectorType>())
355     return VT->getNumElements() <= 3;
356   return true;
357 }
358 
handleSV_DispatchThreadIDAttr(Decl * D,const ParsedAttr & AL)359 void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
360   auto *VD = cast<ValueDecl>(D);
361   if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
362     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
363         << AL << "uint/uint2/uint3";
364     return;
365   }
366 
367   D->addAttr(::new (getASTContext())
368                  HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
369 }
370 
handlePackOffsetAttr(Decl * D,const ParsedAttr & AL)371 void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
372   if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
373     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
374         << AL << "shader constant in a constant buffer";
375     return;
376   }
377 
378   uint32_t SubComponent;
379   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))
380     return;
381   uint32_t Component;
382   if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))
383     return;
384 
385   QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
386   // Check if T is an array or struct type.
387   // TODO: mark matrix type as aggregate type.
388   bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
389 
390   // Check Component is valid for T.
391   if (Component) {
392     unsigned Size = getASTContext().getTypeSize(T);
393     if (IsAggregateTy || Size > 128) {
394       Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
395       return;
396     } else {
397       // Make sure Component + sizeof(T) <= 4.
398       if ((Component * 32 + Size) > 128) {
399         Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
400         return;
401       }
402       QualType EltTy = T;
403       if (const auto *VT = T->getAs<VectorType>())
404         EltTy = VT->getElementType();
405       unsigned Align = getASTContext().getTypeAlign(EltTy);
406       if (Align > 32 && Component == 1) {
407         // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
408         // So we only need to check Component 1 here.
409         Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)
410             << Align << EltTy;
411         return;
412       }
413     }
414   }
415 
416   D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(
417       getASTContext(), AL, SubComponent, Component));
418 }
419 
handleShaderAttr(Decl * D,const ParsedAttr & AL)420 void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
421   StringRef Str;
422   SourceLocation ArgLoc;
423   if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
424     return;
425 
426   llvm::Triple::EnvironmentType ShaderType;
427   if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
428     Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
429         << AL << Str << ArgLoc;
430     return;
431   }
432 
433   // FIXME: check function match the shader stage.
434 
435   HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
436   if (NewAttr)
437     D->addAttr(NewAttr);
438 }
439 
handleResourceClassAttr(Decl * D,const ParsedAttr & AL)440 void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {
441   if (!AL.isArgIdent(0)) {
442     Diag(AL.getLoc(), diag::err_attribute_argument_type)
443         << AL << AANT_ArgumentIdentifier;
444     return;
445   }
446 
447   IdentifierLoc *Loc = AL.getArgAsIdent(0);
448   StringRef Identifier = Loc->Ident->getName();
449   SourceLocation ArgLoc = Loc->Loc;
450 
451   // Validate.
452   llvm::dxil::ResourceClass RC;
453   if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
454     Diag(ArgLoc, diag::warn_attribute_type_not_supported)
455         << "ResourceClass" << Identifier;
456     return;
457   }
458 
459   D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));
460 }
461 
handleResourceBindingAttr(Decl * D,const ParsedAttr & AL)462 void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {
463   StringRef Space = "space0";
464   StringRef Slot = "";
465 
466   if (!AL.isArgIdent(0)) {
467     Diag(AL.getLoc(), diag::err_attribute_argument_type)
468         << AL << AANT_ArgumentIdentifier;
469     return;
470   }
471 
472   IdentifierLoc *Loc = AL.getArgAsIdent(0);
473   StringRef Str = Loc->Ident->getName();
474   SourceLocation ArgLoc = Loc->Loc;
475 
476   SourceLocation SpaceArgLoc;
477   if (AL.getNumArgs() == 2) {
478     Slot = Str;
479     if (!AL.isArgIdent(1)) {
480       Diag(AL.getLoc(), diag::err_attribute_argument_type)
481           << AL << AANT_ArgumentIdentifier;
482       return;
483     }
484 
485     IdentifierLoc *Loc = AL.getArgAsIdent(1);
486     Space = Loc->Ident->getName();
487     SpaceArgLoc = Loc->Loc;
488   } else {
489     Slot = Str;
490   }
491 
492   // Validate.
493   if (!Slot.empty()) {
494     switch (Slot[0]) {
495     case 'u':
496     case 'b':
497     case 's':
498     case 't':
499       break;
500     default:
501       Diag(ArgLoc, diag::err_hlsl_unsupported_register_type)
502           << Slot.substr(0, 1);
503       return;
504     }
505 
506     StringRef SlotNum = Slot.substr(1);
507     unsigned Num = 0;
508     if (SlotNum.getAsInteger(10, Num)) {
509       Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);
510       return;
511     }
512   }
513 
514   if (!Space.starts_with("space")) {
515     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
516     return;
517   }
518   StringRef SpaceNum = Space.substr(5);
519   unsigned Num = 0;
520   if (SpaceNum.getAsInteger(10, Num)) {
521     Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;
522     return;
523   }
524 
525   // FIXME: check reg type match decl. Issue
526   // https://github.com/llvm/llvm-project/issues/57886.
527   HLSLResourceBindingAttr *NewAttr =
528       HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);
529   if (NewAttr)
530     D->addAttr(NewAttr);
531 }
532 
handleParamModifierAttr(Decl * D,const ParsedAttr & AL)533 void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
534   HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
535       D, AL,
536       static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
537   if (NewAttr)
538     D->addAttr(NewAttr);
539 }
540 
541 namespace {
542 
543 /// This class implements HLSL availability diagnostics for default
544 /// and relaxed mode
545 ///
546 /// The goal of this diagnostic is to emit an error or warning when an
547 /// unavailable API is found in code that is reachable from the shader
548 /// entry function or from an exported function (when compiling a shader
549 /// library).
550 ///
551 /// This is done by traversing the AST of all shader entry point functions
552 /// and of all exported functions, and any functions that are referenced
553 /// from this AST. In other words, any functions that are reachable from
554 /// the entry points.
555 class DiagnoseHLSLAvailability
556     : public RecursiveASTVisitor<DiagnoseHLSLAvailability> {
557 
558   Sema &SemaRef;
559 
560   // Stack of functions to be scaned
561   llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;
562 
563   // Tracks which environments functions have been scanned in.
564   //
565   // Maps FunctionDecl to an unsigned number that represents the set of shader
566   // environments the function has been scanned for.
567   // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
568   // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
569   // (verified by static_asserts in Triple.cpp), we can use it to index
570   // individual bits in the set, as long as we shift the values to start with 0
571   // by subtracting the value of llvm::Triple::Pixel first.
572   //
573   // The N'th bit in the set will be set if the function has been scanned
574   // in shader environment whose llvm::Triple::EnvironmentType integer value
575   // equals (llvm::Triple::Pixel + N).
576   //
577   // For example, if a function has been scanned in compute and pixel stage
578   // environment, the value will be 0x21 (100001 binary) because:
579   //
580   //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
581   //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
582   //
583   // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
584   // been scanned in any environment.
585   llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
586 
587   // Do not access these directly, use the get/set methods below to make
588   // sure the values are in sync
589   llvm::Triple::EnvironmentType CurrentShaderEnvironment;
590   unsigned CurrentShaderStageBit;
591 
592   // True if scanning a function that was already scanned in a different
593   // shader stage context, and therefore we should not report issues that
594   // depend only on shader model version because they would be duplicate.
595   bool ReportOnlyShaderStageIssues;
596 
597   // Helper methods for dealing with current stage context / environment
SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType)598   void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
599     static_assert(sizeof(unsigned) >= 4);
600     assert(HLSLShaderAttr::isValidShaderType(ShaderType));
601     assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
602            "ShaderType is too big for this bitmap"); // 31 is reserved for
603                                                      // "unknown"
604 
605     unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
606     CurrentShaderEnvironment = ShaderType;
607     CurrentShaderStageBit = (1 << bitmapIndex);
608   }
609 
SetUnknownShaderStageContext()610   void SetUnknownShaderStageContext() {
611     CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
612     CurrentShaderStageBit = (1 << 31);
613   }
614 
GetCurrentShaderEnvironment() const615   llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
616     return CurrentShaderEnvironment;
617   }
618 
InUnknownShaderStageContext() const619   bool InUnknownShaderStageContext() const {
620     return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
621   }
622 
623   // Helper methods for dealing with shader stage bitmap
AddToScannedFunctions(const FunctionDecl * FD)624   void AddToScannedFunctions(const FunctionDecl *FD) {
625     unsigned &ScannedStages = ScannedDecls.getOrInsertDefault(FD);
626     ScannedStages |= CurrentShaderStageBit;
627   }
628 
GetScannedStages(const FunctionDecl * FD)629   unsigned GetScannedStages(const FunctionDecl *FD) {
630     return ScannedDecls.getOrInsertDefault(FD);
631   }
632 
WasAlreadyScannedInCurrentStage(const FunctionDecl * FD)633   bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
634     return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));
635   }
636 
WasAlreadyScannedInCurrentStage(unsigned ScannerStages)637   bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
638     return ScannerStages & CurrentShaderStageBit;
639   }
640 
NeverBeenScanned(unsigned ScannedStages)641   static bool NeverBeenScanned(unsigned ScannedStages) {
642     return ScannedStages == 0;
643   }
644 
645   // Scanning methods
646   void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
647   void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
648                              SourceRange Range);
649   const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
650   bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);
651 
652 public:
DiagnoseHLSLAvailability(Sema & SemaRef)653   DiagnoseHLSLAvailability(Sema &SemaRef) : SemaRef(SemaRef) {}
654 
655   // AST traversal methods
656   void RunOnTranslationUnit(const TranslationUnitDecl *TU);
657   void RunOnFunction(const FunctionDecl *FD);
658 
VisitDeclRefExpr(DeclRefExpr * DRE)659   bool VisitDeclRefExpr(DeclRefExpr *DRE) {
660     FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());
661     if (FD)
662       HandleFunctionOrMethodRef(FD, DRE);
663     return true;
664   }
665 
VisitMemberExpr(MemberExpr * ME)666   bool VisitMemberExpr(MemberExpr *ME) {
667     FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());
668     if (FD)
669       HandleFunctionOrMethodRef(FD, ME);
670     return true;
671   }
672 };
673 
HandleFunctionOrMethodRef(FunctionDecl * FD,Expr * RefExpr)674 void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
675                                                          Expr *RefExpr) {
676   assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
677          "expected DeclRefExpr or MemberExpr");
678 
679   // has a definition -> add to stack to be scanned
680   const FunctionDecl *FDWithBody = nullptr;
681   if (FD->hasBody(FDWithBody)) {
682     if (!WasAlreadyScannedInCurrentStage(FDWithBody))
683       DeclsToScan.push_back(FDWithBody);
684     return;
685   }
686 
687   // no body -> diagnose availability
688   const AvailabilityAttr *AA = FindAvailabilityAttr(FD);
689   if (AA)
690     CheckDeclAvailability(
691         FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
692 }
693 
RunOnTranslationUnit(const TranslationUnitDecl * TU)694 void DiagnoseHLSLAvailability::RunOnTranslationUnit(
695     const TranslationUnitDecl *TU) {
696 
697   // Iterate over all shader entry functions and library exports, and for those
698   // that have a body (definiton), run diag scan on each, setting appropriate
699   // shader environment context based on whether it is a shader entry function
700   // or an exported function. Exported functions can be in namespaces and in
701   // export declarations so we need to scan those declaration contexts as well.
702   llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
703   DeclContextsToScan.push_back(TU);
704 
705   while (!DeclContextsToScan.empty()) {
706     const DeclContext *DC = DeclContextsToScan.pop_back_val();
707     for (auto &D : DC->decls()) {
708       // do not scan implicit declaration generated by the implementation
709       if (D->isImplicit())
710         continue;
711 
712       // for namespace or export declaration add the context to the list to be
713       // scanned later
714       if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {
715         DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));
716         continue;
717       }
718 
719       // skip over other decls or function decls without body
720       const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);
721       if (!FD || !FD->isThisDeclarationADefinition())
722         continue;
723 
724       // shader entry point
725       if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
726         SetShaderStageContext(ShaderAttr->getType());
727         RunOnFunction(FD);
728         continue;
729       }
730       // exported library function
731       // FIXME: replace this loop with external linkage check once issue #92071
732       // is resolved
733       bool isExport = FD->isInExportDeclContext();
734       if (!isExport) {
735         for (const auto *Redecl : FD->redecls()) {
736           if (Redecl->isInExportDeclContext()) {
737             isExport = true;
738             break;
739           }
740         }
741       }
742       if (isExport) {
743         SetUnknownShaderStageContext();
744         RunOnFunction(FD);
745         continue;
746       }
747     }
748   }
749 }
750 
RunOnFunction(const FunctionDecl * FD)751 void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
752   assert(DeclsToScan.empty() && "DeclsToScan should be empty");
753   DeclsToScan.push_back(FD);
754 
755   while (!DeclsToScan.empty()) {
756     // Take one decl from the stack and check it by traversing its AST.
757     // For any CallExpr found during the traversal add it's callee to the top of
758     // the stack to be processed next. Functions already processed are stored in
759     // ScannedDecls.
760     const FunctionDecl *FD = DeclsToScan.pop_back_val();
761 
762     // Decl was already scanned
763     const unsigned ScannedStages = GetScannedStages(FD);
764     if (WasAlreadyScannedInCurrentStage(ScannedStages))
765       continue;
766 
767     ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);
768 
769     AddToScannedFunctions(FD);
770     TraverseStmt(FD->getBody());
771   }
772 }
773 
HasMatchingEnvironmentOrNone(const AvailabilityAttr * AA)774 bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
775     const AvailabilityAttr *AA) {
776   IdentifierInfo *IIEnvironment = AA->getEnvironment();
777   if (!IIEnvironment)
778     return true;
779 
780   llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
781   if (CurrentEnv == llvm::Triple::UnknownEnvironment)
782     return false;
783 
784   llvm::Triple::EnvironmentType AttrEnv =
785       AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());
786 
787   return CurrentEnv == AttrEnv;
788 }
789 
790 const AvailabilityAttr *
FindAvailabilityAttr(const Decl * D)791 DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
792   AvailabilityAttr const *PartialMatch = nullptr;
793   // Check each AvailabilityAttr to find the one for this platform.
794   // For multiple attributes with the same platform try to find one for this
795   // environment.
796   for (const auto *A : D->attrs()) {
797     if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {
798       StringRef AttrPlatform = Avail->getPlatform()->getName();
799       StringRef TargetPlatform =
800           SemaRef.getASTContext().getTargetInfo().getPlatformName();
801 
802       // Match the platform name.
803       if (AttrPlatform == TargetPlatform) {
804         // Find the best matching attribute for this environment
805         if (HasMatchingEnvironmentOrNone(Avail))
806           return Avail;
807         PartialMatch = Avail;
808       }
809     }
810   }
811   return PartialMatch;
812 }
813 
814 // Check availability against target shader model version and current shader
815 // stage and emit diagnostic
CheckDeclAvailability(NamedDecl * D,const AvailabilityAttr * AA,SourceRange Range)816 void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
817                                                      const AvailabilityAttr *AA,
818                                                      SourceRange Range) {
819 
820   IdentifierInfo *IIEnv = AA->getEnvironment();
821 
822   if (!IIEnv) {
823     // The availability attribute does not have environment -> it depends only
824     // on shader model version and not on specific the shader stage.
825 
826     // Skip emitting the diagnostics if the diagnostic mode is set to
827     // strict (-fhlsl-strict-availability) because all relevant diagnostics
828     // were already emitted in the DiagnoseUnguardedAvailability scan
829     // (SemaAvailability.cpp).
830     if (SemaRef.getLangOpts().HLSLStrictAvailability)
831       return;
832 
833     // Do not report shader-stage-independent issues if scanning a function
834     // that was already scanned in a different shader stage context (they would
835     // be duplicate)
836     if (ReportOnlyShaderStageIssues)
837       return;
838 
839   } else {
840     // The availability attribute has environment -> we need to know
841     // the current stage context to property diagnose it.
842     if (InUnknownShaderStageContext())
843       return;
844   }
845 
846   // Check introduced version and if environment matches
847   bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
848   VersionTuple Introduced = AA->getIntroduced();
849   VersionTuple TargetVersion =
850       SemaRef.Context.getTargetInfo().getPlatformMinVersion();
851 
852   if (TargetVersion >= Introduced && EnvironmentMatches)
853     return;
854 
855   // Emit diagnostic message
856   const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
857   llvm::StringRef PlatformName(
858       AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));
859 
860   llvm::StringRef CurrentEnvStr =
861       llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());
862 
863   llvm::StringRef AttrEnvStr =
864       AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
865   bool UseEnvironment = !AttrEnvStr.empty();
866 
867   if (EnvironmentMatches) {
868     SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)
869         << Range << D << PlatformName << Introduced.getAsString()
870         << UseEnvironment << CurrentEnvStr;
871   } else {
872     SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)
873         << Range << D;
874   }
875 
876   SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)
877       << D << PlatformName << Introduced.getAsString()
878       << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
879       << UseEnvironment << AttrEnvStr << CurrentEnvStr;
880 }
881 
882 } // namespace
883 
DiagnoseAvailabilityViolations(TranslationUnitDecl * TU)884 void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
885   // Skip running the diagnostics scan if the diagnostic mode is
886   // strict (-fhlsl-strict-availability) and the target shader stage is known
887   // because all relevant diagnostics were already emitted in the
888   // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
889   const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
890   if (SemaRef.getLangOpts().HLSLStrictAvailability &&
891       TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
892     return;
893 
894   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
895 }
896 
897 // Helper function for CheckHLSLBuiltinFunctionCall
CheckVectorElementCallArgs(Sema * S,CallExpr * TheCall)898 bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
899   assert(TheCall->getNumArgs() > 1);
900   ExprResult A = TheCall->getArg(0);
901 
902   QualType ArgTyA = A.get()->getType();
903 
904   auto *VecTyA = ArgTyA->getAs<VectorType>();
905   SourceLocation BuiltinLoc = TheCall->getBeginLoc();
906 
907   for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
908     ExprResult B = TheCall->getArg(i);
909     QualType ArgTyB = B.get()->getType();
910     auto *VecTyB = ArgTyB->getAs<VectorType>();
911     if (VecTyA == nullptr && VecTyB == nullptr)
912       return false;
913 
914     if (VecTyA && VecTyB) {
915       bool retValue = false;
916       if (VecTyA->getElementType() != VecTyB->getElementType()) {
917         // Note: type promotion is intended to be handeled via the intrinsics
918         //  and not the builtin itself.
919         S->Diag(TheCall->getBeginLoc(),
920                 diag::err_vec_builtin_incompatible_vector)
921             << TheCall->getDirectCallee() << /*useAllTerminology*/ true
922             << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
923         retValue = true;
924       }
925       if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
926         // You should only be hitting this case if you are calling the builtin
927         // directly. HLSL intrinsics should avoid this case via a
928         // HLSLVectorTruncation.
929         S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
930             << TheCall->getDirectCallee() << /*useAllTerminology*/ true
931             << SourceRange(TheCall->getArg(0)->getBeginLoc(),
932                            TheCall->getArg(1)->getEndLoc());
933         retValue = true;
934       }
935       return retValue;
936     }
937   }
938 
939   // Note: if we get here one of the args is a scalar which
940   // requires a VectorSplat on Arg0 or Arg1
941   S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
942       << TheCall->getDirectCallee() << /*useAllTerminology*/ true
943       << SourceRange(TheCall->getArg(0)->getBeginLoc(),
944                      TheCall->getArg(1)->getEndLoc());
945   return true;
946 }
947 
CheckArgsTypesAreCorrect(Sema * S,CallExpr * TheCall,QualType ExpectedType,llvm::function_ref<bool (clang::QualType PassedType)> Check)948 bool CheckArgsTypesAreCorrect(
949     Sema *S, CallExpr *TheCall, QualType ExpectedType,
950     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
951   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
952     QualType PassedType = TheCall->getArg(i)->getType();
953     if (Check(PassedType)) {
954       if (auto *VecTyA = PassedType->getAs<VectorType>())
955         ExpectedType = S->Context.getVectorType(
956             ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
957       S->Diag(TheCall->getArg(0)->getBeginLoc(),
958               diag::err_typecheck_convert_incompatible)
959           << PassedType << ExpectedType << 1 << 0 << 0;
960       return true;
961     }
962   }
963   return false;
964 }
965 
CheckAllArgsHaveFloatRepresentation(Sema * S,CallExpr * TheCall)966 bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
967   auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
968     return !PassedType->hasFloatingRepresentation();
969   };
970   return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
971                                   checkAllFloatTypes);
972 }
973 
CheckFloatOrHalfRepresentations(Sema * S,CallExpr * TheCall)974 bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
975   auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
976     clang::QualType BaseType =
977         PassedType->isVectorType()
978             ? PassedType->getAs<clang::VectorType>()->getElementType()
979             : PassedType;
980     return !BaseType->isHalfType() && !BaseType->isFloat32Type();
981   };
982   return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
983                                   checkFloatorHalf);
984 }
985 
CheckNoDoubleVectors(Sema * S,CallExpr * TheCall)986 bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
987   auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
988     if (const auto *VecTy = PassedType->getAs<VectorType>())
989       return VecTy->getElementType()->isDoubleType();
990     return false;
991   };
992   return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
993                                   checkDoubleVector);
994 }
995 
CheckUnsignedIntRepresentation(Sema * S,CallExpr * TheCall)996 bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
997   auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
998     return !PassedType->hasUnsignedIntegerRepresentation();
999   };
1000   return CheckArgsTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
1001                                   checkAllUnsignedTypes);
1002 }
1003 
SetElementTypeAsReturnType(Sema * S,CallExpr * TheCall,QualType ReturnType)1004 void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
1005                                 QualType ReturnType) {
1006   auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
1007   if (VecTyA)
1008     ReturnType = S->Context.getVectorType(ReturnType, VecTyA->getNumElements(),
1009                                           VectorKind::Generic);
1010   TheCall->setType(ReturnType);
1011 }
1012 
1013 // Note: returning true in this case results in CheckBuiltinFunctionCall
1014 // returning an ExprError
CheckBuiltinFunctionCall(unsigned BuiltinID,CallExpr * TheCall)1015 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
1016   switch (BuiltinID) {
1017   case Builtin::BI__builtin_hlsl_elementwise_all:
1018   case Builtin::BI__builtin_hlsl_elementwise_any: {
1019     if (SemaRef.checkArgCount(TheCall, 1))
1020       return true;
1021     break;
1022   }
1023   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
1024     if (SemaRef.checkArgCount(TheCall, 3))
1025       return true;
1026     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
1027       return true;
1028     if (SemaRef.BuiltinElementwiseTernaryMath(
1029             TheCall, /*CheckForFloatArgs*/
1030             TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
1031       return true;
1032     break;
1033   }
1034   case Builtin::BI__builtin_hlsl_dot: {
1035     if (SemaRef.checkArgCount(TheCall, 2))
1036       return true;
1037     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
1038       return true;
1039     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
1040       return true;
1041     if (CheckNoDoubleVectors(&SemaRef, TheCall))
1042       return true;
1043     break;
1044   }
1045   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
1046     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
1047       return true;
1048     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
1049       return true;
1050     break;
1051   }
1052   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
1053   case Builtin::BI__builtin_hlsl_elementwise_frac: {
1054     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1055       return true;
1056     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
1057       return true;
1058     break;
1059   }
1060   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1061     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1062       return true;
1063     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
1064       return true;
1065     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy);
1066     break;
1067   }
1068   case Builtin::BI__builtin_hlsl_lerp: {
1069     if (SemaRef.checkArgCount(TheCall, 3))
1070       return true;
1071     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
1072       return true;
1073     if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
1074       return true;
1075     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1076       return true;
1077     break;
1078   }
1079   case Builtin::BI__builtin_hlsl_mad: {
1080     if (SemaRef.checkArgCount(TheCall, 3))
1081       return true;
1082     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
1083       return true;
1084     if (SemaRef.BuiltinElementwiseTernaryMath(
1085             TheCall, /*CheckForFloatArgs*/
1086             TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
1087       return true;
1088     break;
1089   }
1090   // Note these are llvm builtins that we want to catch invalid intrinsic
1091   // generation. Normal handling of these builitns will occur elsewhere.
1092   case Builtin::BI__builtin_elementwise_bitreverse: {
1093     if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
1094       return true;
1095     break;
1096   }
1097   case Builtin::BI__builtin_elementwise_acos:
1098   case Builtin::BI__builtin_elementwise_asin:
1099   case Builtin::BI__builtin_elementwise_atan:
1100   case Builtin::BI__builtin_elementwise_ceil:
1101   case Builtin::BI__builtin_elementwise_cos:
1102   case Builtin::BI__builtin_elementwise_cosh:
1103   case Builtin::BI__builtin_elementwise_exp:
1104   case Builtin::BI__builtin_elementwise_exp2:
1105   case Builtin::BI__builtin_elementwise_floor:
1106   case Builtin::BI__builtin_elementwise_log:
1107   case Builtin::BI__builtin_elementwise_log2:
1108   case Builtin::BI__builtin_elementwise_log10:
1109   case Builtin::BI__builtin_elementwise_pow:
1110   case Builtin::BI__builtin_elementwise_roundeven:
1111   case Builtin::BI__builtin_elementwise_sin:
1112   case Builtin::BI__builtin_elementwise_sinh:
1113   case Builtin::BI__builtin_elementwise_sqrt:
1114   case Builtin::BI__builtin_elementwise_tan:
1115   case Builtin::BI__builtin_elementwise_tanh:
1116   case Builtin::BI__builtin_elementwise_trunc: {
1117     if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1118       return true;
1119     break;
1120   }
1121   }
1122   return false;
1123 }
1124