xref: /freebsd/contrib/llvm-project/clang/lib/Driver/ToolChains/HLSL.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===--- HLSL.cpp - HLSL ToolChain Implementations --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "HLSL.h"
10 #include "clang/Driver/CommonArgs.h"
11 #include "clang/Driver/Compilation.h"
12 #include "clang/Driver/Job.h"
13 #include "llvm/ADT/StringSwitch.h"
14 #include "llvm/TargetParser/Triple.h"
15 #include <regex>
16 
17 using namespace clang::driver;
18 using namespace clang::driver::tools;
19 using namespace clang::driver::toolchains;
20 using namespace clang;
21 using namespace llvm::opt;
22 using namespace llvm;
23 
24 namespace {
25 
// Minor shader-model version used to encode the "x" in offline library
// profiles such as "lib_6_x".
constexpr unsigned OfflineLibMinor = 0xF;
27 
isLegalShaderModel(Triple & T)28 bool isLegalShaderModel(Triple &T) {
29   if (T.getOS() != Triple::OSType::ShaderModel)
30     return false;
31 
32   auto Version = T.getOSVersion();
33   if (Version.getBuild())
34     return false;
35   if (Version.getSubminor())
36     return false;
37 
38   auto Kind = T.getEnvironment();
39 
40   switch (Kind) {
41   default:
42     return false;
43   case Triple::EnvironmentType::Vertex:
44   case Triple::EnvironmentType::Hull:
45   case Triple::EnvironmentType::Domain:
46   case Triple::EnvironmentType::Geometry:
47   case Triple::EnvironmentType::Pixel:
48   case Triple::EnvironmentType::Compute: {
49     VersionTuple MinVer(4, 0);
50     return MinVer <= Version;
51   } break;
52   case Triple::EnvironmentType::Library: {
53     VersionTuple SM6x(6, OfflineLibMinor);
54     if (Version == SM6x)
55       return true;
56 
57     VersionTuple MinVer(6, 3);
58     return MinVer <= Version;
59   } break;
60   case Triple::EnvironmentType::Amplification:
61   case Triple::EnvironmentType::Mesh: {
62     VersionTuple MinVer(6, 5);
63     return MinVer <= Version;
64   } break;
65   }
66   return false;
67 }
68 
tryParseProfile(StringRef Profile)69 std::optional<std::string> tryParseProfile(StringRef Profile) {
70   // [ps|vs|gs|hs|ds|cs|ms|as]_[major]_[minor]
71   SmallVector<StringRef, 3> Parts;
72   Profile.split(Parts, "_");
73   if (Parts.size() != 3)
74     return std::nullopt;
75 
76   Triple::EnvironmentType Kind =
77       StringSwitch<Triple::EnvironmentType>(Parts[0])
78           .Case("ps", Triple::EnvironmentType::Pixel)
79           .Case("vs", Triple::EnvironmentType::Vertex)
80           .Case("gs", Triple::EnvironmentType::Geometry)
81           .Case("hs", Triple::EnvironmentType::Hull)
82           .Case("ds", Triple::EnvironmentType::Domain)
83           .Case("cs", Triple::EnvironmentType::Compute)
84           .Case("lib", Triple::EnvironmentType::Library)
85           .Case("ms", Triple::EnvironmentType::Mesh)
86           .Case("as", Triple::EnvironmentType::Amplification)
87           .Default(Triple::EnvironmentType::UnknownEnvironment);
88   if (Kind == Triple::EnvironmentType::UnknownEnvironment)
89     return std::nullopt;
90 
91   unsigned long long Major = 0;
92   if (llvm::getAsUnsignedInteger(Parts[1], 0, Major))
93     return std::nullopt;
94 
95   unsigned long long Minor = 0;
96   if (Parts[2] == "x" && Kind == Triple::EnvironmentType::Library)
97     Minor = OfflineLibMinor;
98   else if (llvm::getAsUnsignedInteger(Parts[2], 0, Minor))
99     return std::nullopt;
100 
101   // Determine DXIL version using the minor version number of Shader
102   // Model version specified in target profile. Prior to decoupling DXIL version
103   // numbering from that of Shader Model DXIL version 1.Y corresponds to SM 6.Y.
104   // E.g., dxilv1.Y-unknown-shadermodelX.Y-hull
105   llvm::Triple T;
106   Triple::SubArchType SubArch = llvm::Triple::NoSubArch;
107   switch (Minor) {
108   case 0:
109     SubArch = llvm::Triple::DXILSubArch_v1_0;
110     break;
111   case 1:
112     SubArch = llvm::Triple::DXILSubArch_v1_1;
113     break;
114   case 2:
115     SubArch = llvm::Triple::DXILSubArch_v1_2;
116     break;
117   case 3:
118     SubArch = llvm::Triple::DXILSubArch_v1_3;
119     break;
120   case 4:
121     SubArch = llvm::Triple::DXILSubArch_v1_4;
122     break;
123   case 5:
124     SubArch = llvm::Triple::DXILSubArch_v1_5;
125     break;
126   case 6:
127     SubArch = llvm::Triple::DXILSubArch_v1_6;
128     break;
129   case 7:
130     SubArch = llvm::Triple::DXILSubArch_v1_7;
131     break;
132   case 8:
133     SubArch = llvm::Triple::DXILSubArch_v1_8;
134     break;
135   case OfflineLibMinor:
136     // Always consider minor version x as the latest supported DXIL version
137     SubArch = llvm::Triple::LatestDXILSubArch;
138     break;
139   default:
140     // No DXIL Version corresponding to specified Shader Model version found
141     return std::nullopt;
142   }
143   T.setArch(Triple::ArchType::dxil, SubArch);
144   T.setOSName(Triple::getOSTypeName(Triple::OSType::ShaderModel).str() +
145               VersionTuple(Major, Minor).getAsString());
146   T.setEnvironment(Kind);
147   if (isLegalShaderModel(T))
148     return T.getTriple();
149   else
150     return std::nullopt;
151 }
152 
isLegalValidatorVersion(StringRef ValVersionStr,const Driver & D)153 bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) {
154   VersionTuple Version;
155   if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
156       Version.getSubminor() || !Version.getMinor()) {
157     D.Diag(diag::err_drv_invalid_format_dxil_validator_version)
158         << ValVersionStr;
159     return false;
160   }
161 
162   uint64_t Major = Version.getMajor();
163   uint64_t Minor = *Version.getMinor();
164   if (Major == 0 && Minor != 0) {
165     D.Diag(diag::err_drv_invalid_empty_dxil_validator_version) << ValVersionStr;
166     return false;
167   }
168   VersionTuple MinVer(1, 0);
169   if (Version < MinVer) {
170     D.Diag(diag::err_drv_invalid_range_dxil_validator_version) << ValVersionStr;
171     return false;
172   }
173   return true;
174 }
175 
getSpirvExtArg(ArrayRef<std::string> SpvExtensionArgs)176 std::string getSpirvExtArg(ArrayRef<std::string> SpvExtensionArgs) {
177   if (SpvExtensionArgs.empty()) {
178     return "-spirv-ext=all";
179   }
180 
181   std::string LlvmOption =
182       (Twine("-spirv-ext=+") + SpvExtensionArgs.front()).str();
183   SpvExtensionArgs = SpvExtensionArgs.slice(1);
184   for (auto Extension : SpvExtensionArgs) {
185     if (Extension != "KHR")
186       Extension = (Twine("+") + Extension).str();
187     LlvmOption = (Twine(LlvmOption) + "," + Extension).str();
188   }
189   return LlvmOption;
190 }
191 
/// Returns true if \p Name is exactly "KHR" or a SPIR-V extension name of the
/// form `SPV_[a-zA-Z0-9_]+`. The whole string must match.
bool isValidSPIRVExtensionName(const std::string &Name) {
  // Building a std::regex is comparatively expensive and the pattern never
  // changes, so compile it once on first use.
  static const std::regex Pattern("KHR|SPV_[a-zA-Z0-9_]+");
  return std::regex_match(Name, Pattern);
}
196 
197 // SPIRV extension names are of the form `SPV_[a-zA-Z0-9_]+`. We want to
198 // disallow obviously invalid names to avoid issues when parsing `spirv-ext`.
checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs,const Driver & Driver)199 bool checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs,
200                                 const Driver &Driver) {
201   bool AllValid = true;
202   for (auto Extension : SpvExtensionArgs) {
203     if (!isValidSPIRVExtensionName(Extension)) {
204       Driver.Diag(diag::err_drv_invalid_value)
205           << "-fspv-extension" << Extension;
206       AllValid = false;
207     }
208   }
209   return AllValid;
210 }
211 } // namespace
212 
ConstructJob(Compilation & C,const JobAction & JA,const InputInfo & Output,const InputInfoList & Inputs,const ArgList & Args,const char * LinkingOutput) const213 void tools::hlsl::Validator::ConstructJob(Compilation &C, const JobAction &JA,
214                                           const InputInfo &Output,
215                                           const InputInfoList &Inputs,
216                                           const ArgList &Args,
217                                           const char *LinkingOutput) const {
218   std::string DxvPath = getToolChain().GetProgramPath("dxv");
219   assert(DxvPath != "dxv" && "cannot find dxv");
220 
221   ArgStringList CmdArgs;
222   assert(Inputs.size() == 1 && "Unable to handle multiple inputs.");
223   const InputInfo &Input = Inputs[0];
224   CmdArgs.push_back(Input.getFilename());
225   CmdArgs.push_back("-o");
226   CmdArgs.push_back(Output.getFilename());
227 
228   const char *Exec = Args.MakeArgString(DxvPath);
229   C.addCommand(std::make_unique<Command>(JA, *this, ResponseFileSupport::None(),
230                                          Exec, CmdArgs, Inputs, Input));
231 }
232 
ConstructJob(Compilation & C,const JobAction & JA,const InputInfo & Output,const InputInfoList & Inputs,const ArgList & Args,const char * LinkingOutput) const233 void tools::hlsl::MetalConverter::ConstructJob(
234     Compilation &C, const JobAction &JA, const InputInfo &Output,
235     const InputInfoList &Inputs, const ArgList &Args,
236     const char *LinkingOutput) const {
237   std::string MSCPath = getToolChain().GetProgramPath("metal-shaderconverter");
238   ArgStringList CmdArgs;
239   assert(Inputs.size() == 1 && "Unable to handle multiple inputs.");
240   const InputInfo &Input = Inputs[0];
241   CmdArgs.push_back(Input.getFilename());
242   CmdArgs.push_back("-o");
243   CmdArgs.push_back(Output.getFilename());
244 
245   const char *Exec = Args.MakeArgString(MSCPath);
246   C.addCommand(std::make_unique<Command>(JA, *this, ResponseFileSupport::None(),
247                                          Exec, CmdArgs, Inputs, Input));
248 }
249 
250 /// DirectX Toolchain
HLSLToolChain(const Driver & D,const llvm::Triple & Triple,const ArgList & Args)251 HLSLToolChain::HLSLToolChain(const Driver &D, const llvm::Triple &Triple,
252                              const ArgList &Args)
253     : ToolChain(D, Triple, Args) {
254   if (Args.hasArg(options::OPT_dxc_validator_path_EQ))
255     getProgramPaths().push_back(
256         Args.getLastArgValue(options::OPT_dxc_validator_path_EQ).str());
257 }
258 
getTool(Action::ActionClass AC) const259 Tool *clang::driver::toolchains::HLSLToolChain::getTool(
260     Action::ActionClass AC) const {
261   switch (AC) {
262   case Action::BinaryAnalyzeJobClass:
263     if (!Validator)
264       Validator.reset(new tools::hlsl::Validator(*this));
265     return Validator.get();
266   case Action::BinaryTranslatorJobClass:
267     if (!MetalConverter)
268       MetalConverter.reset(new tools::hlsl::MetalConverter(*this));
269     return MetalConverter.get();
270   default:
271     return ToolChain::getTool(AC);
272   }
273 }
274 
275 std::optional<std::string>
parseTargetProfile(StringRef TargetProfile)276 clang::driver::toolchains::HLSLToolChain::parseTargetProfile(
277     StringRef TargetProfile) {
278   return tryParseProfile(TargetProfile);
279 }
280 
281 DerivedArgList *
TranslateArgs(const DerivedArgList & Args,StringRef BoundArch,Action::OffloadKind DeviceOffloadKind) const282 HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
283                              Action::OffloadKind DeviceOffloadKind) const {
284   DerivedArgList *DAL = new DerivedArgList(Args.getBaseArgs());
285 
286   const OptTable &Opts = getDriver().getOpts();
287 
288   for (Arg *A : Args) {
289     if (A->getOption().getID() == options::OPT_dxil_validator_version) {
290       StringRef ValVerStr = A->getValue();
291       if (!isLegalValidatorVersion(ValVerStr, getDriver()))
292         continue;
293     }
294     if (A->getOption().getID() == options::OPT_dxc_entrypoint) {
295       DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_hlsl_entrypoint),
296                           A->getValue());
297       A->claim();
298       continue;
299     }
300     if (A->getOption().getID() == options::OPT_dxc_rootsig_ver) {
301       DAL->AddJoinedArg(nullptr,
302                         Opts.getOption(options::OPT_fdx_rootsignature_version),
303                         A->getValue());
304       A->claim();
305       continue;
306     }
307     if (A->getOption().getID() == options::OPT__SLASH_O) {
308       StringRef OStr = A->getValue();
309       if (OStr == "d") {
310         DAL->AddFlagArg(nullptr, Opts.getOption(options::OPT_O0));
311         A->claim();
312         continue;
313       } else {
314         DAL->AddJoinedArg(nullptr, Opts.getOption(options::OPT_O), OStr);
315         A->claim();
316         continue;
317       }
318     }
319     if (A->getOption().getID() == options::OPT_emit_pristine_llvm) {
320       // Translate -fcgl into -emit-llvm and -disable-llvm-passes.
321       DAL->AddFlagArg(nullptr, Opts.getOption(options::OPT_emit_llvm));
322       DAL->AddFlagArg(nullptr,
323                       Opts.getOption(options::OPT_disable_llvm_passes));
324       A->claim();
325       continue;
326     }
327     if (A->getOption().getID() == options::OPT_dxc_hlsl_version) {
328       // Translate -HV into -std for llvm
329       // depending on the value given
330       LangStandard::Kind LangStd = LangStandard::getHLSLLangKind(A->getValue());
331       if (LangStd != LangStandard::lang_unspecified) {
332         LangStandard l = LangStandard::getLangStandardForKind(LangStd);
333         DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_std_EQ),
334                             l.getName());
335       } else {
336         getDriver().Diag(diag::err_drv_invalid_value) << "HV" << A->getValue();
337       }
338 
339       A->claim();
340       continue;
341     }
342     if (A->getOption().getID() == options::OPT_dxc_gis) {
343       // Translate -Gis into -ffp_model_EQ=strict
344       DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_ffp_model_EQ),
345                           "strict");
346       A->claim();
347       continue;
348     }
349     if (A->getOption().getID() == options::OPT_fvk_use_dx_layout) {
350       // This is the only implemented layout so far.
351       A->claim();
352       continue;
353     }
354 
355     if (A->getOption().getID() == options::OPT_fvk_use_scalar_layout) {
356       getDriver().Diag(diag::err_drv_clang_unsupported) << A->getAsString(Args);
357       A->claim();
358       continue;
359     }
360 
361     if (A->getOption().getID() == options::OPT_fvk_use_gl_layout) {
362       getDriver().Diag(diag::err_drv_clang_unsupported) << A->getAsString(Args);
363       A->claim();
364       continue;
365     }
366 
367     DAL->append(A);
368   }
369 
370   if (getArch() == llvm::Triple::spirv) {
371     std::vector<std::string> SpvExtensionArgs =
372         Args.getAllArgValues(options::OPT_fspv_extension_EQ);
373     if (checkExtensionArgsAreValid(SpvExtensionArgs, getDriver())) {
374       std::string LlvmOption = getSpirvExtArg(SpvExtensionArgs);
375       DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_mllvm),
376                           LlvmOption);
377     }
378     Args.claimAllArgs(options::OPT_fspv_extension_EQ);
379   }
380 
381   if (!DAL->hasArg(options::OPT_O_Group)) {
382     DAL->AddJoinedArg(nullptr, Opts.getOption(options::OPT_O), "3");
383   }
384 
385   return DAL;
386 }
387 
requiresValidation(DerivedArgList & Args) const388 bool HLSLToolChain::requiresValidation(DerivedArgList &Args) const {
389   if (!Args.hasArg(options::OPT_dxc_Fo))
390     return false;
391 
392   if (Args.getLastArg(options::OPT_dxc_disable_validation))
393     return false;
394 
395   std::string DxvPath = GetProgramPath("dxv");
396   if (DxvPath != "dxv")
397     return true;
398 
399   getDriver().Diag(diag::warn_drv_dxc_missing_dxv);
400   return false;
401 }
402 
requiresBinaryTranslation(DerivedArgList & Args) const403 bool HLSLToolChain::requiresBinaryTranslation(DerivedArgList &Args) const {
404   return Args.hasArg(options::OPT_metal) && Args.hasArg(options::OPT_dxc_Fo);
405 }
406 
isLastJob(DerivedArgList & Args,Action::ActionClass AC) const407 bool HLSLToolChain::isLastJob(DerivedArgList &Args,
408                               Action::ActionClass AC) const {
409   bool HasTranslation = requiresBinaryTranslation(Args);
410   bool HasValidation = requiresValidation(Args);
411   // If translation and validation are not required, we should only have one
412   // action.
413   if (!HasTranslation && !HasValidation)
414     return true;
415   if ((HasTranslation && AC == Action::BinaryTranslatorJobClass) ||
416       (!HasTranslation && HasValidation && AC == Action::BinaryAnalyzeJobClass))
417     return true;
418   return false;
419 }
420