xref: /freebsd/contrib/llvm-project/clang/lib/Parse/ParseHLSLRootSignature.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //=== ParseHLSLRootSignature.cpp - Parse Root Signature -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Parse/ParseHLSLRootSignature.h"
10 
11 #include "clang/Lex/LiteralSupport.h"
12 
13 using namespace llvm::hlsl::rootsig;
14 
15 namespace clang {
16 namespace hlsl {
17 
18 using TokenKind = RootSignatureToken::Kind;
19 
20 static const TokenKind RootElementKeywords[] = {
21     TokenKind::kw_RootFlags,
22     TokenKind::kw_CBV,
23     TokenKind::kw_UAV,
24     TokenKind::kw_SRV,
25     TokenKind::kw_DescriptorTable,
26     TokenKind::kw_StaticSampler,
27 };
28 
RootSignatureParser(llvm::dxbc::RootSignatureVersion Version,SmallVector<RootSignatureElement> & Elements,StringLiteral * Signature,Preprocessor & PP)29 RootSignatureParser::RootSignatureParser(
30     llvm::dxbc::RootSignatureVersion Version,
31     SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
32     Preprocessor &PP)
33     : Version(Version), Elements(Elements), Signature(Signature),
34       Lexer(Signature->getString()), PP(PP), CurToken(0) {}
35 
parse()36 bool RootSignatureParser::parse() {
37   // Iterate as many RootSignatureElements as possible, until we hit the
38   // end of the stream
39   bool HadError = false;
40   while (!peekExpectedToken(TokenKind::end_of_stream)) {
41     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
42       SourceLocation ElementLoc = getTokenLocation(CurToken);
43       auto Flags = parseRootFlags();
44       if (!Flags.has_value()) {
45         HadError = true;
46         skipUntilExpectedToken(RootElementKeywords);
47         continue;
48       }
49 
50       Elements.emplace_back(ElementLoc, *Flags);
51     } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
52       SourceLocation ElementLoc = getTokenLocation(CurToken);
53       auto Constants = parseRootConstants();
54       if (!Constants.has_value()) {
55         HadError = true;
56         skipUntilExpectedToken(RootElementKeywords);
57         continue;
58       }
59       Elements.emplace_back(ElementLoc, *Constants);
60     } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
61       SourceLocation ElementLoc = getTokenLocation(CurToken);
62       auto Table = parseDescriptorTable();
63       if (!Table.has_value()) {
64         HadError = true;
65         // We are within a DescriptorTable, we will do our best to recover
66         // by skipping until we encounter the expected closing ')'.
67         skipUntilClosedParens();
68         consumeNextToken();
69         skipUntilExpectedToken(RootElementKeywords);
70         continue;
71       }
72       Elements.emplace_back(ElementLoc, *Table);
73     } else if (tryConsumeExpectedToken(
74                    {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
75       SourceLocation ElementLoc = getTokenLocation(CurToken);
76       auto Descriptor = parseRootDescriptor();
77       if (!Descriptor.has_value()) {
78         HadError = true;
79         skipUntilExpectedToken(RootElementKeywords);
80         continue;
81       }
82       Elements.emplace_back(ElementLoc, *Descriptor);
83     } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
84       SourceLocation ElementLoc = getTokenLocation(CurToken);
85       auto Sampler = parseStaticSampler();
86       if (!Sampler.has_value()) {
87         HadError = true;
88         skipUntilExpectedToken(RootElementKeywords);
89         continue;
90       }
91       Elements.emplace_back(ElementLoc, *Sampler);
92     } else {
93       HadError = true;
94       consumeNextToken(); // let diagnostic be at the start of invalid token
95       reportDiag(diag::err_hlsl_invalid_token)
96           << /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
97       skipUntilExpectedToken(RootElementKeywords);
98       continue;
99     }
100 
101     if (!tryConsumeExpectedToken(TokenKind::pu_comma)) {
102       // ',' denotes another element, otherwise, expected to be at end of stream
103       break;
104     }
105   }
106 
107   return HadError ||
108          consumeExpectedToken(TokenKind::end_of_stream,
109                               diag::err_expected_either, TokenKind::pu_comma);
110 }
111 
112 template <typename FlagType>
maybeOrFlag(std::optional<FlagType> Flags,FlagType Flag)113 static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
114   if (!Flags.has_value())
115     return Flag;
116 
117   return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
118                                llvm::to_underlying(Flag));
119 }
120 
parseRootFlags()121 std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
122   assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
123          "Expects to only be invoked starting at given keyword");
124 
125   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
126                            CurToken.TokKind))
127     return std::nullopt;
128 
129   std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
130 
131   // Handle valid empty case
132   if (tryConsumeExpectedToken(TokenKind::pu_r_paren))
133     return Flags;
134 
135   // Handle the edge-case of '0' to specify no flags set
136   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
137     if (!verifyZeroFlag()) {
138       reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
139       return std::nullopt;
140     }
141   } else {
142     // Otherwise, parse as many flags as possible
143     TokenKind Expected[] = {
144 #define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
145 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
146     };
147 
148     do {
149       if (tryConsumeExpectedToken(Expected)) {
150         switch (CurToken.TokKind) {
151 #define ROOT_FLAG_ENUM(NAME, LIT)                                              \
152   case TokenKind::en_##NAME:                                                   \
153     Flags = maybeOrFlag<llvm::dxbc::RootFlags>(Flags,                          \
154                                                llvm::dxbc::RootFlags::NAME);   \
155     break;
156 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
157         default:
158           llvm_unreachable("Switch for consumed enum token was not provided");
159         }
160       } else {
161         consumeNextToken(); // consume token to point at invalid token
162         reportDiag(diag::err_hlsl_invalid_token)
163             << /*value=*/1 << /*value of*/ TokenKind::kw_RootFlags;
164         return std::nullopt;
165       }
166     } while (tryConsumeExpectedToken(TokenKind::pu_or));
167   }
168 
169   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_either,
170                            TokenKind::pu_comma))
171     return std::nullopt;
172 
173   return Flags;
174 }
175 
parseRootConstants()176 std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
177   assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
178          "Expects to only be invoked starting at given keyword");
179 
180   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
181                            CurToken.TokKind))
182     return std::nullopt;
183 
184   RootConstants Constants;
185 
186   auto Params = parseRootConstantParams();
187   if (!Params.has_value())
188     return std::nullopt;
189 
190   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_either,
191                            TokenKind::pu_comma))
192     return std::nullopt;
193 
194   // Check mandatory parameters where provided
195   if (!Params->Num32BitConstants.has_value()) {
196     reportDiag(diag::err_hlsl_rootsig_missing_param)
197         << TokenKind::kw_num32BitConstants;
198     return std::nullopt;
199   }
200 
201   Constants.Num32BitConstants = Params->Num32BitConstants.value();
202 
203   if (!Params->Reg.has_value()) {
204     reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
205     return std::nullopt;
206   }
207 
208   Constants.Reg = Params->Reg.value();
209 
210   // Fill in optional parameters
211   if (Params->Visibility.has_value())
212     Constants.Visibility = Params->Visibility.value();
213 
214   if (Params->Space.has_value())
215     Constants.Space = Params->Space.value();
216 
217   return Constants;
218 }
219 
parseRootDescriptor()220 std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
221   assert((CurToken.TokKind == TokenKind::kw_CBV ||
222           CurToken.TokKind == TokenKind::kw_SRV ||
223           CurToken.TokKind == TokenKind::kw_UAV) &&
224          "Expects to only be invoked starting at given keyword");
225 
226   TokenKind DescriptorKind = CurToken.TokKind;
227 
228   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
229                            CurToken.TokKind))
230     return std::nullopt;
231 
232   RootDescriptor Descriptor;
233   TokenKind ExpectedReg;
234   switch (DescriptorKind) {
235   default:
236     llvm_unreachable("Switch for consumed token was not provided");
237   case TokenKind::kw_CBV:
238     Descriptor.Type = DescriptorType::CBuffer;
239     ExpectedReg = TokenKind::bReg;
240     break;
241   case TokenKind::kw_SRV:
242     Descriptor.Type = DescriptorType::SRV;
243     ExpectedReg = TokenKind::tReg;
244     break;
245   case TokenKind::kw_UAV:
246     Descriptor.Type = DescriptorType::UAV;
247     ExpectedReg = TokenKind::uReg;
248     break;
249   }
250   Descriptor.setDefaultFlags(Version);
251 
252   auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
253   if (!Params.has_value())
254     return std::nullopt;
255 
256   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_either,
257                            TokenKind::pu_comma))
258     return std::nullopt;
259 
260   // Check mandatory parameters were provided
261   if (!Params->Reg.has_value()) {
262     reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
263     return std::nullopt;
264   }
265 
266   Descriptor.Reg = Params->Reg.value();
267 
268   // Fill in optional values
269   if (Params->Space.has_value())
270     Descriptor.Space = Params->Space.value();
271 
272   if (Params->Visibility.has_value())
273     Descriptor.Visibility = Params->Visibility.value();
274 
275   if (Params->Flags.has_value())
276     Descriptor.Flags = Params->Flags.value();
277 
278   return Descriptor;
279 }
280 
parseDescriptorTable()281 std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
282   assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
283          "Expects to only be invoked starting at given keyword");
284 
285   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
286                            CurToken.TokKind))
287     return std::nullopt;
288 
289   DescriptorTable Table;
290   std::optional<llvm::dxbc::ShaderVisibility> Visibility;
291 
292   // Iterate as many Clauses as possible, until we hit ')'
293   while (!peekExpectedToken(TokenKind::pu_r_paren)) {
294     if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
295                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
296       // DescriptorTableClause - CBV, SRV, UAV, or Sampler
297       SourceLocation ElementLoc = getTokenLocation(CurToken);
298       auto Clause = parseDescriptorTableClause();
299       if (!Clause.has_value()) {
300         // We are within a DescriptorTableClause, we will do our best to recover
301         // by skipping until we encounter the expected closing ')'
302         skipUntilExpectedToken(TokenKind::pu_r_paren);
303         consumeNextToken();
304         return std::nullopt;
305       }
306       Elements.emplace_back(ElementLoc, *Clause);
307       Table.NumClauses++;
308     } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
309       // visibility = SHADER_VISIBILITY
310       if (Visibility.has_value()) {
311         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
312         return std::nullopt;
313       }
314 
315       if (consumeExpectedToken(TokenKind::pu_equal))
316         return std::nullopt;
317 
318       Visibility = parseShaderVisibility(TokenKind::kw_visibility);
319       if (!Visibility.has_value())
320         return std::nullopt;
321     } else {
322       consumeNextToken(); // let diagnostic be at the start of invalid token
323       reportDiag(diag::err_hlsl_invalid_token)
324           << /*parameter=*/0 << /*param of*/ TokenKind::kw_DescriptorTable;
325       return std::nullopt;
326     }
327 
328     // ',' denotes another element, otherwise, expected to be at ')'
329     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
330       break;
331   }
332 
333   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_either,
334                            TokenKind::pu_comma))
335     return std::nullopt;
336 
337   // Fill in optional visibility
338   if (Visibility.has_value())
339     Table.Visibility = Visibility.value();
340 
341   return Table;
342 }
343 
344 std::optional<DescriptorTableClause>
parseDescriptorTableClause()345 RootSignatureParser::parseDescriptorTableClause() {
346   assert((CurToken.TokKind == TokenKind::kw_CBV ||
347           CurToken.TokKind == TokenKind::kw_SRV ||
348           CurToken.TokKind == TokenKind::kw_UAV ||
349           CurToken.TokKind == TokenKind::kw_Sampler) &&
350          "Expects to only be invoked starting at given keyword");
351 
352   TokenKind ParamKind = CurToken.TokKind;
353 
354   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
355                            CurToken.TokKind))
356     return std::nullopt;
357 
358   DescriptorTableClause Clause;
359   TokenKind ExpectedReg;
360   switch (ParamKind) {
361   default:
362     llvm_unreachable("Switch for consumed token was not provided");
363   case TokenKind::kw_CBV:
364     Clause.Type = ClauseType::CBuffer;
365     ExpectedReg = TokenKind::bReg;
366     break;
367   case TokenKind::kw_SRV:
368     Clause.Type = ClauseType::SRV;
369     ExpectedReg = TokenKind::tReg;
370     break;
371   case TokenKind::kw_UAV:
372     Clause.Type = ClauseType::UAV;
373     ExpectedReg = TokenKind::uReg;
374     break;
375   case TokenKind::kw_Sampler:
376     Clause.Type = ClauseType::Sampler;
377     ExpectedReg = TokenKind::sReg;
378     break;
379   }
380   Clause.setDefaultFlags(Version);
381 
382   auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
383   if (!Params.has_value())
384     return std::nullopt;
385 
386   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_either,
387                            TokenKind::pu_comma))
388     return std::nullopt;
389 
390   // Check mandatory parameters were provided
391   if (!Params->Reg.has_value()) {
392     reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
393     return std::nullopt;
394   }
395 
396   Clause.Reg = Params->Reg.value();
397 
398   // Fill in optional values
399   if (Params->NumDescriptors.has_value())
400     Clause.NumDescriptors = Params->NumDescriptors.value();
401 
402   if (Params->Space.has_value())
403     Clause.Space = Params->Space.value();
404 
405   if (Params->Offset.has_value())
406     Clause.Offset = Params->Offset.value();
407 
408   if (Params->Flags.has_value())
409     Clause.Flags = Params->Flags.value();
410 
411   return Clause;
412 }
413 
parseStaticSampler()414 std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
415   assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
416          "Expects to only be invoked starting at given keyword");
417 
418   if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
419                            CurToken.TokKind))
420     return std::nullopt;
421 
422   StaticSampler Sampler;
423 
424   auto Params = parseStaticSamplerParams();
425   if (!Params.has_value())
426     return std::nullopt;
427 
428   if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_either,
429                            TokenKind::pu_comma))
430     return std::nullopt;
431 
432   // Check mandatory parameters were provided
433   if (!Params->Reg.has_value()) {
434     reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
435     return std::nullopt;
436   }
437 
438   Sampler.Reg = Params->Reg.value();
439 
440   // Fill in optional values
441   if (Params->Filter.has_value())
442     Sampler.Filter = Params->Filter.value();
443 
444   if (Params->AddressU.has_value())
445     Sampler.AddressU = Params->AddressU.value();
446 
447   if (Params->AddressV.has_value())
448     Sampler.AddressV = Params->AddressV.value();
449 
450   if (Params->AddressW.has_value())
451     Sampler.AddressW = Params->AddressW.value();
452 
453   if (Params->MipLODBias.has_value())
454     Sampler.MipLODBias = Params->MipLODBias.value();
455 
456   if (Params->MaxAnisotropy.has_value())
457     Sampler.MaxAnisotropy = Params->MaxAnisotropy.value();
458 
459   if (Params->CompFunc.has_value())
460     Sampler.CompFunc = Params->CompFunc.value();
461 
462   if (Params->BorderColor.has_value())
463     Sampler.BorderColor = Params->BorderColor.value();
464 
465   if (Params->MinLOD.has_value())
466     Sampler.MinLOD = Params->MinLOD.value();
467 
468   if (Params->MaxLOD.has_value())
469     Sampler.MaxLOD = Params->MaxLOD.value();
470 
471   if (Params->Space.has_value())
472     Sampler.Space = Params->Space.value();
473 
474   if (Params->Visibility.has_value())
475     Sampler.Visibility = Params->Visibility.value();
476 
477   return Sampler;
478 }
479 
// Parameter arguments (e.g. `bReg`, `space`, ...) can be specified in any
// order, each at most once. The following methods parse as many arguments as
// possible, reporting an error if a duplicate is seen. Whether the mandatory
// arguments were provided is checked afterwards by their callers.
483 std::optional<RootSignatureParser::ParsedConstantParams>
parseRootConstantParams()484 RootSignatureParser::parseRootConstantParams() {
485   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
486          "Expects to only be invoked starting at given token");
487 
488   ParsedConstantParams Params;
489   while (!peekExpectedToken(TokenKind::pu_r_paren)) {
490     if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
491       // `num32BitConstants` `=` POS_INT
492       if (Params.Num32BitConstants.has_value()) {
493         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
494         return std::nullopt;
495       }
496 
497       if (consumeExpectedToken(TokenKind::pu_equal))
498         return std::nullopt;
499 
500       auto Num32BitConstants = parseUIntParam();
501       if (!Num32BitConstants.has_value())
502         return std::nullopt;
503       Params.Num32BitConstants = Num32BitConstants;
504     } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
505       // `b` POS_INT
506       if (Params.Reg.has_value()) {
507         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
508         return std::nullopt;
509       }
510       auto Reg = parseRegister();
511       if (!Reg.has_value())
512         return std::nullopt;
513       Params.Reg = Reg;
514     } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
515       // `space` `=` POS_INT
516       if (Params.Space.has_value()) {
517         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
518         return std::nullopt;
519       }
520 
521       if (consumeExpectedToken(TokenKind::pu_equal))
522         return std::nullopt;
523 
524       auto Space = parseUIntParam();
525       if (!Space.has_value())
526         return std::nullopt;
527       Params.Space = Space;
528     } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
529       // `visibility` `=` SHADER_VISIBILITY
530       if (Params.Visibility.has_value()) {
531         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
532         return std::nullopt;
533       }
534 
535       if (consumeExpectedToken(TokenKind::pu_equal))
536         return std::nullopt;
537 
538       auto Visibility = parseShaderVisibility(TokenKind::kw_visibility);
539       if (!Visibility.has_value())
540         return std::nullopt;
541       Params.Visibility = Visibility;
542     } else {
543       consumeNextToken(); // let diagnostic be at the start of invalid token
544       reportDiag(diag::err_hlsl_invalid_token)
545           << /*parameter=*/0 << /*param of*/ TokenKind::kw_RootConstants;
546       return std::nullopt;
547     }
548 
549     // ',' denotes another element, otherwise, expected to be at ')'
550     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
551       break;
552   }
553 
554   return Params;
555 }
556 
557 std::optional<RootSignatureParser::ParsedRootDescriptorParams>
parseRootDescriptorParams(TokenKind DescKind,TokenKind RegType)558 RootSignatureParser::parseRootDescriptorParams(TokenKind DescKind,
559                                                TokenKind RegType) {
560   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
561          "Expects to only be invoked starting at given token");
562 
563   ParsedRootDescriptorParams Params;
564   while (!peekExpectedToken(TokenKind::pu_r_paren)) {
565     if (tryConsumeExpectedToken(RegType)) {
566       // ( `b` | `t` | `u`) POS_INT
567       if (Params.Reg.has_value()) {
568         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
569         return std::nullopt;
570       }
571       auto Reg = parseRegister();
572       if (!Reg.has_value())
573         return std::nullopt;
574       Params.Reg = Reg;
575     } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
576       // `space` `=` POS_INT
577       if (Params.Space.has_value()) {
578         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
579         return std::nullopt;
580       }
581 
582       if (consumeExpectedToken(TokenKind::pu_equal))
583         return std::nullopt;
584 
585       auto Space = parseUIntParam();
586       if (!Space.has_value())
587         return std::nullopt;
588       Params.Space = Space;
589     } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
590       // `visibility` `=` SHADER_VISIBILITY
591       if (Params.Visibility.has_value()) {
592         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
593         return std::nullopt;
594       }
595 
596       if (consumeExpectedToken(TokenKind::pu_equal))
597         return std::nullopt;
598 
599       auto Visibility = parseShaderVisibility(TokenKind::kw_visibility);
600       if (!Visibility.has_value())
601         return std::nullopt;
602       Params.Visibility = Visibility;
603     } else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
604       // `flags` `=` ROOT_DESCRIPTOR_FLAGS
605       if (Params.Flags.has_value()) {
606         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
607         return std::nullopt;
608       }
609 
610       if (consumeExpectedToken(TokenKind::pu_equal))
611         return std::nullopt;
612 
613       auto Flags = parseRootDescriptorFlags(TokenKind::kw_flags);
614       if (!Flags.has_value())
615         return std::nullopt;
616       Params.Flags = Flags;
617     } else {
618       consumeNextToken(); // let diagnostic be at the start of invalid token
619       reportDiag(diag::err_hlsl_invalid_token)
620           << /*parameter=*/0 << /*param of*/ DescKind;
621       return std::nullopt;
622     }
623 
624     // ',' denotes another element, otherwise, expected to be at ')'
625     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
626       break;
627   }
628 
629   return Params;
630 }
631 
632 std::optional<RootSignatureParser::ParsedClauseParams>
parseDescriptorTableClauseParams(TokenKind ClauseKind,TokenKind RegType)633 RootSignatureParser::parseDescriptorTableClauseParams(TokenKind ClauseKind,
634                                                       TokenKind RegType) {
635   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
636          "Expects to only be invoked starting at given token");
637 
638   ParsedClauseParams Params;
639   while (!peekExpectedToken(TokenKind::pu_r_paren)) {
640     if (tryConsumeExpectedToken(RegType)) {
641       // ( `b` | `t` | `u` | `s`) POS_INT
642       if (Params.Reg.has_value()) {
643         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
644         return std::nullopt;
645       }
646       auto Reg = parseRegister();
647       if (!Reg.has_value())
648         return std::nullopt;
649       Params.Reg = Reg;
650     } else if (tryConsumeExpectedToken(TokenKind::kw_numDescriptors)) {
651       // `numDescriptors` `=` POS_INT | unbounded
652       if (Params.NumDescriptors.has_value()) {
653         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
654         return std::nullopt;
655       }
656 
657       if (consumeExpectedToken(TokenKind::pu_equal))
658         return std::nullopt;
659 
660       std::optional<uint32_t> NumDescriptors;
661       if (tryConsumeExpectedToken(TokenKind::en_unbounded))
662         NumDescriptors = NumDescriptorsUnbounded;
663       else {
664         NumDescriptors = parseUIntParam();
665         if (!NumDescriptors.has_value())
666           return std::nullopt;
667       }
668 
669       Params.NumDescriptors = NumDescriptors;
670     } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
671       // `space` `=` POS_INT
672       if (Params.Space.has_value()) {
673         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
674         return std::nullopt;
675       }
676 
677       if (consumeExpectedToken(TokenKind::pu_equal))
678         return std::nullopt;
679 
680       auto Space = parseUIntParam();
681       if (!Space.has_value())
682         return std::nullopt;
683       Params.Space = Space;
684     } else if (tryConsumeExpectedToken(TokenKind::kw_offset)) {
685       // `offset` `=` POS_INT | DESCRIPTOR_RANGE_OFFSET_APPEND
686       if (Params.Offset.has_value()) {
687         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
688         return std::nullopt;
689       }
690 
691       if (consumeExpectedToken(TokenKind::pu_equal))
692         return std::nullopt;
693 
694       std::optional<uint32_t> Offset;
695       if (tryConsumeExpectedToken(TokenKind::en_DescriptorRangeOffsetAppend))
696         Offset = DescriptorTableOffsetAppend;
697       else {
698         Offset = parseUIntParam();
699         if (!Offset.has_value())
700           return std::nullopt;
701       }
702 
703       Params.Offset = Offset;
704     } else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
705       // `flags` `=` DESCRIPTOR_RANGE_FLAGS
706       if (Params.Flags.has_value()) {
707         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
708         return std::nullopt;
709       }
710 
711       if (consumeExpectedToken(TokenKind::pu_equal))
712         return std::nullopt;
713 
714       auto Flags = parseDescriptorRangeFlags(TokenKind::kw_flags);
715       if (!Flags.has_value())
716         return std::nullopt;
717       Params.Flags = Flags;
718     } else {
719       consumeNextToken(); // let diagnostic be at the start of invalid token
720       reportDiag(diag::err_hlsl_invalid_token)
721           << /*parameter=*/0 << /*param of*/ ClauseKind;
722       return std::nullopt;
723     }
724 
725     // ',' denotes another element, otherwise, expected to be at ')'
726     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
727       break;
728   }
729 
730   return Params;
731 }
732 
733 std::optional<RootSignatureParser::ParsedStaticSamplerParams>
parseStaticSamplerParams()734 RootSignatureParser::parseStaticSamplerParams() {
735   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
736          "Expects to only be invoked starting at given token");
737 
738   ParsedStaticSamplerParams Params;
739   while (!peekExpectedToken(TokenKind::pu_r_paren)) {
740     if (tryConsumeExpectedToken(TokenKind::sReg)) {
741       // `s` POS_INT
742       if (Params.Reg.has_value()) {
743         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
744         return std::nullopt;
745       }
746       auto Reg = parseRegister();
747       if (!Reg.has_value())
748         return std::nullopt;
749       Params.Reg = Reg;
750     } else if (tryConsumeExpectedToken(TokenKind::kw_filter)) {
751       // `filter` `=` FILTER
752       if (Params.Filter.has_value()) {
753         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
754         return std::nullopt;
755       }
756 
757       if (consumeExpectedToken(TokenKind::pu_equal))
758         return std::nullopt;
759 
760       auto Filter = parseSamplerFilter(TokenKind::kw_filter);
761       if (!Filter.has_value())
762         return std::nullopt;
763       Params.Filter = Filter;
764     } else if (tryConsumeExpectedToken(TokenKind::kw_addressU)) {
765       // `addressU` `=` TEXTURE_ADDRESS
766       if (Params.AddressU.has_value()) {
767         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
768         return std::nullopt;
769       }
770 
771       if (consumeExpectedToken(TokenKind::pu_equal))
772         return std::nullopt;
773 
774       auto AddressU = parseTextureAddressMode(TokenKind::kw_addressU);
775       if (!AddressU.has_value())
776         return std::nullopt;
777       Params.AddressU = AddressU;
778     } else if (tryConsumeExpectedToken(TokenKind::kw_addressV)) {
779       // `addressV` `=` TEXTURE_ADDRESS
780       if (Params.AddressV.has_value()) {
781         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
782         return std::nullopt;
783       }
784 
785       if (consumeExpectedToken(TokenKind::pu_equal))
786         return std::nullopt;
787 
788       auto AddressV = parseTextureAddressMode(TokenKind::kw_addressV);
789       if (!AddressV.has_value())
790         return std::nullopt;
791       Params.AddressV = AddressV;
792     } else if (tryConsumeExpectedToken(TokenKind::kw_addressW)) {
793       // `addressW` `=` TEXTURE_ADDRESS
794       if (Params.AddressW.has_value()) {
795         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
796         return std::nullopt;
797       }
798 
799       if (consumeExpectedToken(TokenKind::pu_equal))
800         return std::nullopt;
801 
802       auto AddressW = parseTextureAddressMode(TokenKind::kw_addressW);
803       if (!AddressW.has_value())
804         return std::nullopt;
805       Params.AddressW = AddressW;
806     } else if (tryConsumeExpectedToken(TokenKind::kw_mipLODBias)) {
807       // `mipLODBias` `=` NUMBER
808       if (Params.MipLODBias.has_value()) {
809         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
810         return std::nullopt;
811       }
812 
813       if (consumeExpectedToken(TokenKind::pu_equal))
814         return std::nullopt;
815 
816       auto MipLODBias = parseFloatParam();
817       if (!MipLODBias.has_value())
818         return std::nullopt;
819       Params.MipLODBias = MipLODBias;
820     } else if (tryConsumeExpectedToken(TokenKind::kw_maxAnisotropy)) {
821       // `maxAnisotropy` `=` POS_INT
822       if (Params.MaxAnisotropy.has_value()) {
823         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
824         return std::nullopt;
825       }
826 
827       if (consumeExpectedToken(TokenKind::pu_equal))
828         return std::nullopt;
829 
830       auto MaxAnisotropy = parseUIntParam();
831       if (!MaxAnisotropy.has_value())
832         return std::nullopt;
833       Params.MaxAnisotropy = MaxAnisotropy;
834     } else if (tryConsumeExpectedToken(TokenKind::kw_comparisonFunc)) {
835       // `comparisonFunc` `=` COMPARISON_FUNC
836       if (Params.CompFunc.has_value()) {
837         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
838         return std::nullopt;
839       }
840 
841       if (consumeExpectedToken(TokenKind::pu_equal))
842         return std::nullopt;
843 
844       auto CompFunc = parseComparisonFunc(TokenKind::kw_comparisonFunc);
845       if (!CompFunc.has_value())
846         return std::nullopt;
847       Params.CompFunc = CompFunc;
848     } else if (tryConsumeExpectedToken(TokenKind::kw_borderColor)) {
849       // `borderColor` `=` STATIC_BORDER_COLOR
850       if (Params.BorderColor.has_value()) {
851         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
852         return std::nullopt;
853       }
854 
855       if (consumeExpectedToken(TokenKind::pu_equal))
856         return std::nullopt;
857 
858       auto BorderColor = parseStaticBorderColor(TokenKind::kw_borderColor);
859       if (!BorderColor.has_value())
860         return std::nullopt;
861       Params.BorderColor = BorderColor;
862     } else if (tryConsumeExpectedToken(TokenKind::kw_minLOD)) {
863       // `minLOD` `=` NUMBER
864       if (Params.MinLOD.has_value()) {
865         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
866         return std::nullopt;
867       }
868 
869       if (consumeExpectedToken(TokenKind::pu_equal))
870         return std::nullopt;
871 
872       auto MinLOD = parseFloatParam();
873       if (!MinLOD.has_value())
874         return std::nullopt;
875       Params.MinLOD = MinLOD;
876     } else if (tryConsumeExpectedToken(TokenKind::kw_maxLOD)) {
877       // `maxLOD` `=` NUMBER
878       if (Params.MaxLOD.has_value()) {
879         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
880         return std::nullopt;
881       }
882 
883       if (consumeExpectedToken(TokenKind::pu_equal))
884         return std::nullopt;
885 
886       auto MaxLOD = parseFloatParam();
887       if (!MaxLOD.has_value())
888         return std::nullopt;
889       Params.MaxLOD = MaxLOD;
890     } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
891       // `space` `=` POS_INT
892       if (Params.Space.has_value()) {
893         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
894         return std::nullopt;
895       }
896 
897       if (consumeExpectedToken(TokenKind::pu_equal))
898         return std::nullopt;
899 
900       auto Space = parseUIntParam();
901       if (!Space.has_value())
902         return std::nullopt;
903       Params.Space = Space;
904     } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
905       // `visibility` `=` SHADER_VISIBILITY
906       if (Params.Visibility.has_value()) {
907         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
908         return std::nullopt;
909       }
910 
911       if (consumeExpectedToken(TokenKind::pu_equal))
912         return std::nullopt;
913 
914       auto Visibility = parseShaderVisibility(TokenKind::kw_visibility);
915       if (!Visibility.has_value())
916         return std::nullopt;
917       Params.Visibility = Visibility;
918     } else {
919       consumeNextToken(); // let diagnostic be at the start of invalid token
920       reportDiag(diag::err_hlsl_invalid_token)
921           << /*parameter=*/0 << /*param of*/ TokenKind::kw_StaticSampler;
922       return std::nullopt;
923     }
924 
925     // ',' denotes another element, otherwise, expected to be at ')'
926     if (!tryConsumeExpectedToken(TokenKind::pu_comma))
927       break;
928   }
929 
930   return Params;
931 }
932 
parseUIntParam()933 std::optional<uint32_t> RootSignatureParser::parseUIntParam() {
934   assert(CurToken.TokKind == TokenKind::pu_equal &&
935          "Expects to only be invoked starting at given keyword");
936   tryConsumeExpectedToken(TokenKind::pu_plus);
937   if (consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
938                            CurToken.TokKind))
939     return std::nullopt;
940   return handleUIntLiteral();
941 }
942 
parseRegister()943 std::optional<Register> RootSignatureParser::parseRegister() {
944   assert((CurToken.TokKind == TokenKind::bReg ||
945           CurToken.TokKind == TokenKind::tReg ||
946           CurToken.TokKind == TokenKind::uReg ||
947           CurToken.TokKind == TokenKind::sReg) &&
948          "Expects to only be invoked starting at given keyword");
949 
950   Register Reg;
951   switch (CurToken.TokKind) {
952   default:
953     llvm_unreachable("Switch for consumed token was not provided");
954   case TokenKind::bReg:
955     Reg.ViewType = RegisterType::BReg;
956     break;
957   case TokenKind::tReg:
958     Reg.ViewType = RegisterType::TReg;
959     break;
960   case TokenKind::uReg:
961     Reg.ViewType = RegisterType::UReg;
962     break;
963   case TokenKind::sReg:
964     Reg.ViewType = RegisterType::SReg;
965     break;
966   }
967 
968   auto Number = handleUIntLiteral();
969   if (!Number.has_value())
970     return std::nullopt; // propogate NumericLiteralParser error
971 
972   Reg.Number = *Number;
973   return Reg;
974 }
975 
parseFloatParam()976 std::optional<float> RootSignatureParser::parseFloatParam() {
977   assert(CurToken.TokKind == TokenKind::pu_equal &&
978          "Expects to only be invoked starting at given keyword");
979   // Consume sign modifier
980   bool Signed =
981       tryConsumeExpectedToken({TokenKind::pu_plus, TokenKind::pu_minus});
982   bool Negated = Signed && CurToken.TokKind == TokenKind::pu_minus;
983 
984   // DXC will treat a postive signed integer as unsigned
985   if (!Negated && tryConsumeExpectedToken(TokenKind::int_literal)) {
986     std::optional<uint32_t> UInt = handleUIntLiteral();
987     if (!UInt.has_value())
988       return std::nullopt;
989     return float(UInt.value());
990   }
991 
992   if (Negated && tryConsumeExpectedToken(TokenKind::int_literal)) {
993     std::optional<int32_t> Int = handleIntLiteral(Negated);
994     if (!Int.has_value())
995       return std::nullopt;
996     return float(Int.value());
997   }
998 
999   if (tryConsumeExpectedToken(TokenKind::float_literal)) {
1000     std::optional<float> Float = handleFloatLiteral(Negated);
1001     if (!Float.has_value())
1002       return std::nullopt;
1003     return Float.value();
1004   }
1005 
1006   return std::nullopt;
1007 }
1008 
1009 std::optional<llvm::dxbc::ShaderVisibility>
parseShaderVisibility(TokenKind Context)1010 RootSignatureParser::parseShaderVisibility(TokenKind Context) {
1011   assert(CurToken.TokKind == TokenKind::pu_equal &&
1012          "Expects to only be invoked starting at given keyword");
1013 
1014   TokenKind Expected[] = {
1015 #define SHADER_VISIBILITY_ENUM(NAME, LIT) TokenKind::en_##NAME,
1016 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1017   };
1018 
1019   if (!tryConsumeExpectedToken(Expected)) {
1020     consumeNextToken(); // consume token to point at invalid token
1021     reportDiag(diag::err_hlsl_invalid_token)
1022         << /*value=*/1 << /*value of*/ Context;
1023     return std::nullopt;
1024   }
1025 
1026   switch (CurToken.TokKind) {
1027 #define SHADER_VISIBILITY_ENUM(NAME, LIT)                                      \
1028   case TokenKind::en_##NAME:                                                   \
1029     return llvm::dxbc::ShaderVisibility::NAME;                                 \
1030     break;
1031 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1032   default:
1033     llvm_unreachable("Switch for consumed enum token was not provided");
1034   }
1035 
1036   return std::nullopt;
1037 }
1038 
1039 std::optional<llvm::dxbc::SamplerFilter>
parseSamplerFilter(TokenKind Context)1040 RootSignatureParser::parseSamplerFilter(TokenKind Context) {
1041   assert(CurToken.TokKind == TokenKind::pu_equal &&
1042          "Expects to only be invoked starting at given keyword");
1043 
1044   TokenKind Expected[] = {
1045 #define FILTER_ENUM(NAME, LIT) TokenKind::en_##NAME,
1046 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1047   };
1048 
1049   if (!tryConsumeExpectedToken(Expected)) {
1050     consumeNextToken(); // consume token to point at invalid token
1051     reportDiag(diag::err_hlsl_invalid_token)
1052         << /*value=*/1 << /*value of*/ Context;
1053     return std::nullopt;
1054   }
1055 
1056   switch (CurToken.TokKind) {
1057 #define FILTER_ENUM(NAME, LIT)                                                 \
1058   case TokenKind::en_##NAME:                                                   \
1059     return llvm::dxbc::SamplerFilter::NAME;                                    \
1060     break;
1061 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1062   default:
1063     llvm_unreachable("Switch for consumed enum token was not provided");
1064   }
1065 
1066   return std::nullopt;
1067 }
1068 
1069 std::optional<llvm::dxbc::TextureAddressMode>
parseTextureAddressMode(TokenKind Context)1070 RootSignatureParser::parseTextureAddressMode(TokenKind Context) {
1071   assert(CurToken.TokKind == TokenKind::pu_equal &&
1072          "Expects to only be invoked starting at given keyword");
1073 
1074   TokenKind Expected[] = {
1075 #define TEXTURE_ADDRESS_MODE_ENUM(NAME, LIT) TokenKind::en_##NAME,
1076 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1077   };
1078 
1079   if (!tryConsumeExpectedToken(Expected)) {
1080     consumeNextToken(); // consume token to point at invalid token
1081     reportDiag(diag::err_hlsl_invalid_token)
1082         << /*value=*/1 << /*value of*/ Context;
1083     return std::nullopt;
1084   }
1085 
1086   switch (CurToken.TokKind) {
1087 #define TEXTURE_ADDRESS_MODE_ENUM(NAME, LIT)                                   \
1088   case TokenKind::en_##NAME:                                                   \
1089     return llvm::dxbc::TextureAddressMode::NAME;                               \
1090     break;
1091 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1092   default:
1093     llvm_unreachable("Switch for consumed enum token was not provided");
1094   }
1095 
1096   return std::nullopt;
1097 }
1098 
1099 std::optional<llvm::dxbc::ComparisonFunc>
parseComparisonFunc(TokenKind Context)1100 RootSignatureParser::parseComparisonFunc(TokenKind Context) {
1101   assert(CurToken.TokKind == TokenKind::pu_equal &&
1102          "Expects to only be invoked starting at given keyword");
1103 
1104   TokenKind Expected[] = {
1105 #define COMPARISON_FUNC_ENUM(NAME, LIT) TokenKind::en_##NAME,
1106 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1107   };
1108 
1109   if (!tryConsumeExpectedToken(Expected)) {
1110     consumeNextToken(); // consume token to point at invalid token
1111     reportDiag(diag::err_hlsl_invalid_token)
1112         << /*value=*/1 << /*value of*/ Context;
1113     return std::nullopt;
1114   }
1115 
1116   switch (CurToken.TokKind) {
1117 #define COMPARISON_FUNC_ENUM(NAME, LIT)                                        \
1118   case TokenKind::en_##NAME:                                                   \
1119     return llvm::dxbc::ComparisonFunc::NAME;                                   \
1120     break;
1121 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1122   default:
1123     llvm_unreachable("Switch for consumed enum token was not provided");
1124   }
1125 
1126   return std::nullopt;
1127 }
1128 
1129 std::optional<llvm::dxbc::StaticBorderColor>
parseStaticBorderColor(TokenKind Context)1130 RootSignatureParser::parseStaticBorderColor(TokenKind Context) {
1131   assert(CurToken.TokKind == TokenKind::pu_equal &&
1132          "Expects to only be invoked starting at given keyword");
1133 
1134   TokenKind Expected[] = {
1135 #define STATIC_BORDER_COLOR_ENUM(NAME, LIT) TokenKind::en_##NAME,
1136 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1137   };
1138 
1139   if (!tryConsumeExpectedToken(Expected)) {
1140     consumeNextToken(); // consume token to point at invalid token
1141     reportDiag(diag::err_hlsl_invalid_token)
1142         << /*value=*/1 << /*value of*/ Context;
1143     return std::nullopt;
1144   }
1145 
1146   switch (CurToken.TokKind) {
1147 #define STATIC_BORDER_COLOR_ENUM(NAME, LIT)                                    \
1148   case TokenKind::en_##NAME:                                                   \
1149     return llvm::dxbc::StaticBorderColor::NAME;                                \
1150     break;
1151 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1152   default:
1153     llvm_unreachable("Switch for consumed enum token was not provided");
1154   }
1155 
1156   return std::nullopt;
1157 }
1158 
1159 std::optional<llvm::dxbc::RootDescriptorFlags>
parseRootDescriptorFlags(TokenKind Context)1160 RootSignatureParser::parseRootDescriptorFlags(TokenKind Context) {
1161   assert(CurToken.TokKind == TokenKind::pu_equal &&
1162          "Expects to only be invoked starting at given keyword");
1163 
1164   // Handle the edge-case of '0' to specify no flags set
1165   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
1166     if (!verifyZeroFlag()) {
1167       reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
1168       return std::nullopt;
1169     }
1170     return llvm::dxbc::RootDescriptorFlags::None;
1171   }
1172 
1173   TokenKind Expected[] = {
1174 #define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
1175 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1176   };
1177 
1178   std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
1179 
1180   do {
1181     if (tryConsumeExpectedToken(Expected)) {
1182       switch (CurToken.TokKind) {
1183 #define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT)                                   \
1184   case TokenKind::en_##NAME:                                                   \
1185     Flags = maybeOrFlag<llvm::dxbc::RootDescriptorFlags>(                      \
1186         Flags, llvm::dxbc::RootDescriptorFlags::NAME);                         \
1187     break;
1188 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1189       default:
1190         llvm_unreachable("Switch for consumed enum token was not provided");
1191       }
1192     } else {
1193       consumeNextToken(); // consume token to point at invalid token
1194       reportDiag(diag::err_hlsl_invalid_token)
1195           << /*value=*/1 << /*value of*/ Context;
1196       return std::nullopt;
1197     }
1198   } while (tryConsumeExpectedToken(TokenKind::pu_or));
1199 
1200   return Flags;
1201 }
1202 
1203 std::optional<llvm::dxbc::DescriptorRangeFlags>
parseDescriptorRangeFlags(TokenKind Context)1204 RootSignatureParser::parseDescriptorRangeFlags(TokenKind Context) {
1205   assert(CurToken.TokKind == TokenKind::pu_equal &&
1206          "Expects to only be invoked starting at given keyword");
1207 
1208   // Handle the edge-case of '0' to specify no flags set
1209   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
1210     if (!verifyZeroFlag()) {
1211       reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
1212       return std::nullopt;
1213     }
1214     return llvm::dxbc::DescriptorRangeFlags::None;
1215   }
1216 
1217   TokenKind Expected[] = {
1218 #define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
1219 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1220   };
1221 
1222   std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
1223 
1224   do {
1225     if (tryConsumeExpectedToken(Expected)) {
1226       switch (CurToken.TokKind) {
1227 #define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
1228   case TokenKind::en_##NAME:                                                   \
1229     Flags = maybeOrFlag<llvm::dxbc::DescriptorRangeFlags>(                     \
1230         Flags, llvm::dxbc::DescriptorRangeFlags::NAME);                        \
1231     break;
1232 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1233       default:
1234         llvm_unreachable("Switch for consumed enum token was not provided");
1235       }
1236     } else {
1237       consumeNextToken(); // consume token to point at invalid token
1238       reportDiag(diag::err_hlsl_invalid_token)
1239           << /*value=*/1 << /*value of*/ Context;
1240       return std::nullopt;
1241     }
1242   } while (tryConsumeExpectedToken(TokenKind::pu_or));
1243 
1244   return Flags;
1245 }
1246 
handleUIntLiteral()1247 std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
1248   // Parse the numeric value and do semantic checks on its specification
1249   clang::NumericLiteralParser Literal(
1250       CurToken.NumSpelling, getTokenLocation(CurToken), PP.getSourceManager(),
1251       PP.getLangOpts(), PP.getTargetInfo(), PP.getDiagnostics());
1252   if (Literal.hadError)
1253     return std::nullopt; // Error has already been reported so just return
1254 
1255   assert(Literal.isIntegerLiteral() &&
1256          "NumSpelling can only consist of digits");
1257 
1258   llvm::APSInt Val(32, /*IsUnsigned=*/true);
1259   if (Literal.GetIntegerValue(Val)) {
1260     // Report that the value has overflowed
1261     reportDiag(diag::err_hlsl_number_literal_overflow)
1262         << /*integer type*/ 0 << /*is signed*/ 0;
1263     return std::nullopt;
1264   }
1265 
1266   return Val.getExtValue();
1267 }
1268 
handleIntLiteral(bool Negated)1269 std::optional<int32_t> RootSignatureParser::handleIntLiteral(bool Negated) {
1270   // Parse the numeric value and do semantic checks on its specification
1271   clang::NumericLiteralParser Literal(
1272       CurToken.NumSpelling, getTokenLocation(CurToken), PP.getSourceManager(),
1273       PP.getLangOpts(), PP.getTargetInfo(), PP.getDiagnostics());
1274   if (Literal.hadError)
1275     return std::nullopt; // Error has already been reported so just return
1276 
1277   assert(Literal.isIntegerLiteral() &&
1278          "NumSpelling can only consist of digits");
1279 
1280   llvm::APSInt Val(32, /*IsUnsigned=*/true);
1281   // GetIntegerValue will overwrite Val from the parsed Literal and return
1282   // true if it overflows as a 32-bit unsigned int
1283   bool Overflowed = Literal.GetIntegerValue(Val);
1284 
1285   // So we then need to check that it doesn't overflow as a 32-bit signed int:
1286   int64_t MaxNegativeMagnitude = -int64_t(std::numeric_limits<int32_t>::min());
1287   Overflowed |= (Negated && MaxNegativeMagnitude < Val.getExtValue());
1288 
1289   int64_t MaxPositiveMagnitude = int64_t(std::numeric_limits<int32_t>::max());
1290   Overflowed |= (!Negated && MaxPositiveMagnitude < Val.getExtValue());
1291 
1292   if (Overflowed) {
1293     // Report that the value has overflowed
1294     reportDiag(diag::err_hlsl_number_literal_overflow)
1295         << /*integer type*/ 0 << /*is signed*/ 1;
1296     return std::nullopt;
1297   }
1298 
1299   if (Negated)
1300     Val = -Val;
1301 
1302   return int32_t(Val.getExtValue());
1303 }
1304 
handleFloatLiteral(bool Negated)1305 std::optional<float> RootSignatureParser::handleFloatLiteral(bool Negated) {
1306   // Parse the numeric value and do semantic checks on its specification
1307   clang::NumericLiteralParser Literal(
1308       CurToken.NumSpelling, getTokenLocation(CurToken), PP.getSourceManager(),
1309       PP.getLangOpts(), PP.getTargetInfo(), PP.getDiagnostics());
1310   if (Literal.hadError)
1311     return std::nullopt; // Error has already been reported so just return
1312 
1313   assert(Literal.isFloatingLiteral() &&
1314          "NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
1315          "will be caught and reported by NumericLiteralParser.");
1316 
1317   // DXC used `strtod` to convert the token string to a float which corresponds
1318   // to:
1319   auto DXCSemantics = llvm::APFloat::Semantics::S_IEEEdouble;
1320   auto DXCRoundingMode = llvm::RoundingMode::NearestTiesToEven;
1321 
1322   llvm::APFloat Val(llvm::APFloat::EnumToSemantics(DXCSemantics));
1323   llvm::APFloat::opStatus Status(Literal.GetFloatValue(Val, DXCRoundingMode));
1324 
1325   // Note: we do not error when opStatus::opInexact by itself as this just
1326   // denotes that rounding occured but not that it is invalid
1327   assert(!(Status & llvm::APFloat::opStatus::opInvalidOp) &&
1328          "NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
1329          "will be caught and reported by NumericLiteralParser.");
1330 
1331   assert(!(Status & llvm::APFloat::opStatus::opDivByZero) &&
1332          "It is not possible for a division to be performed when "
1333          "constructing an APFloat from a string");
1334 
1335   if (Status & llvm::APFloat::opStatus::opUnderflow) {
1336     // Report that the value has underflowed
1337     reportDiag(diag::err_hlsl_number_literal_underflow);
1338     return std::nullopt;
1339   }
1340 
1341   if (Status & llvm::APFloat::opStatus::opOverflow) {
1342     // Report that the value has overflowed
1343     reportDiag(diag::err_hlsl_number_literal_overflow) << /*float type*/ 1;
1344     return std::nullopt;
1345   }
1346 
1347   if (Negated)
1348     Val = -Val;
1349 
1350   double DoubleVal = Val.convertToDouble();
1351   double FloatMax = double(std::numeric_limits<float>::max());
1352   if (FloatMax < DoubleVal || DoubleVal < -FloatMax) {
1353     // Report that the value has overflowed
1354     reportDiag(diag::err_hlsl_number_literal_overflow) << /*float type*/ 1;
1355     return std::nullopt;
1356   }
1357 
1358   return static_cast<float>(DoubleVal);
1359 }
1360 
verifyZeroFlag()1361 bool RootSignatureParser::verifyZeroFlag() {
1362   assert(CurToken.TokKind == TokenKind::int_literal);
1363   auto X = handleUIntLiteral();
1364   return X.has_value() && X.value() == 0;
1365 }
1366 
peekExpectedToken(TokenKind Expected)1367 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
1368   return peekExpectedToken(ArrayRef{Expected});
1369 }
1370 
peekExpectedToken(ArrayRef<TokenKind> AnyExpected)1371 bool RootSignatureParser::peekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
1372   RootSignatureToken Result = Lexer.peekNextToken();
1373   return llvm::is_contained(AnyExpected, Result.TokKind);
1374 }
1375 
consumeExpectedToken(TokenKind Expected,unsigned DiagID,TokenKind Context)1376 bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
1377                                                unsigned DiagID,
1378                                                TokenKind Context) {
1379   if (tryConsumeExpectedToken(Expected))
1380     return false;
1381 
1382   // Report unexpected token kind error
1383   DiagnosticBuilder DB = reportDiag(DiagID);
1384   switch (DiagID) {
1385   case diag::err_expected:
1386     DB << Expected;
1387     break;
1388   case diag::err_expected_either:
1389     DB << Expected << Context;
1390     break;
1391   case diag::err_expected_after:
1392     DB << Context << Expected;
1393     break;
1394   default:
1395     break;
1396   }
1397   return true;
1398 }
1399 
tryConsumeExpectedToken(TokenKind Expected)1400 bool RootSignatureParser::tryConsumeExpectedToken(TokenKind Expected) {
1401   return tryConsumeExpectedToken(ArrayRef{Expected});
1402 }
1403 
tryConsumeExpectedToken(ArrayRef<TokenKind> AnyExpected)1404 bool RootSignatureParser::tryConsumeExpectedToken(
1405     ArrayRef<TokenKind> AnyExpected) {
1406   // If not the expected token just return
1407   if (!peekExpectedToken(AnyExpected))
1408     return false;
1409   consumeNextToken();
1410   return true;
1411 }
1412 
skipUntilExpectedToken(TokenKind Expected)1413 bool RootSignatureParser::skipUntilExpectedToken(TokenKind Expected) {
1414   return skipUntilExpectedToken(ArrayRef{Expected});
1415 }
1416 
skipUntilExpectedToken(ArrayRef<TokenKind> AnyExpected)1417 bool RootSignatureParser::skipUntilExpectedToken(
1418     ArrayRef<TokenKind> AnyExpected) {
1419 
1420   while (!peekExpectedToken(AnyExpected)) {
1421     if (peekExpectedToken(TokenKind::end_of_stream))
1422       return false;
1423     consumeNextToken();
1424   }
1425 
1426   return true;
1427 }
1428 
skipUntilClosedParens(uint32_t NumParens)1429 bool RootSignatureParser::skipUntilClosedParens(uint32_t NumParens) {
1430   TokenKind ParenKinds[] = {
1431       TokenKind::pu_l_paren,
1432       TokenKind::pu_r_paren,
1433   };
1434   while (skipUntilExpectedToken(ParenKinds)) {
1435     consumeNextToken();
1436     if (CurToken.TokKind == TokenKind::pu_r_paren)
1437       NumParens--;
1438     else
1439       NumParens++;
1440     if (NumParens == 0)
1441       return true;
1442   }
1443 
1444   return false;
1445 }
1446 
getTokenLocation(RootSignatureToken Tok)1447 SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) {
1448   return Signature->getLocationOfByte(Tok.LocOffset, PP.getSourceManager(),
1449                                       PP.getLangOpts(), PP.getTargetInfo());
1450 }
1451 
1452 } // namespace hlsl
1453 } // namespace clang
1454