1 //===--- HLSL.cpp - HLSL ToolChain Implementations --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "HLSL.h"
10 #include "clang/Driver/CommonArgs.h"
11 #include "clang/Driver/Compilation.h"
12 #include "clang/Driver/Job.h"
13 #include "llvm/ADT/StringSwitch.h"
14 #include "llvm/TargetParser/Triple.h"
15 #include <regex>
16
17 using namespace clang::driver;
18 using namespace clang::driver::tools;
19 using namespace clang::driver::toolchains;
20 using namespace clang;
21 using namespace llvm::opt;
22 using namespace llvm;
23
24 namespace {
25
// Minor version used to encode the "x" of an offline library profile such as
// lib_6_x (see tryParseProfile and isLegalShaderModel).
constexpr unsigned OfflineLibMinor = 0xF;
27
isLegalShaderModel(Triple & T)28 bool isLegalShaderModel(Triple &T) {
29 if (T.getOS() != Triple::OSType::ShaderModel)
30 return false;
31
32 auto Version = T.getOSVersion();
33 if (Version.getBuild())
34 return false;
35 if (Version.getSubminor())
36 return false;
37
38 auto Kind = T.getEnvironment();
39
40 switch (Kind) {
41 default:
42 return false;
43 case Triple::EnvironmentType::Vertex:
44 case Triple::EnvironmentType::Hull:
45 case Triple::EnvironmentType::Domain:
46 case Triple::EnvironmentType::Geometry:
47 case Triple::EnvironmentType::Pixel:
48 case Triple::EnvironmentType::Compute: {
49 VersionTuple MinVer(4, 0);
50 return MinVer <= Version;
51 } break;
52 case Triple::EnvironmentType::Library: {
53 VersionTuple SM6x(6, OfflineLibMinor);
54 if (Version == SM6x)
55 return true;
56
57 VersionTuple MinVer(6, 3);
58 return MinVer <= Version;
59 } break;
60 case Triple::EnvironmentType::Amplification:
61 case Triple::EnvironmentType::Mesh: {
62 VersionTuple MinVer(6, 5);
63 return MinVer <= Version;
64 } break;
65 }
66 return false;
67 }
68
tryParseProfile(StringRef Profile)69 std::optional<std::string> tryParseProfile(StringRef Profile) {
70 // [ps|vs|gs|hs|ds|cs|ms|as]_[major]_[minor]
71 SmallVector<StringRef, 3> Parts;
72 Profile.split(Parts, "_");
73 if (Parts.size() != 3)
74 return std::nullopt;
75
76 Triple::EnvironmentType Kind =
77 StringSwitch<Triple::EnvironmentType>(Parts[0])
78 .Case("ps", Triple::EnvironmentType::Pixel)
79 .Case("vs", Triple::EnvironmentType::Vertex)
80 .Case("gs", Triple::EnvironmentType::Geometry)
81 .Case("hs", Triple::EnvironmentType::Hull)
82 .Case("ds", Triple::EnvironmentType::Domain)
83 .Case("cs", Triple::EnvironmentType::Compute)
84 .Case("lib", Triple::EnvironmentType::Library)
85 .Case("ms", Triple::EnvironmentType::Mesh)
86 .Case("as", Triple::EnvironmentType::Amplification)
87 .Default(Triple::EnvironmentType::UnknownEnvironment);
88 if (Kind == Triple::EnvironmentType::UnknownEnvironment)
89 return std::nullopt;
90
91 unsigned long long Major = 0;
92 if (llvm::getAsUnsignedInteger(Parts[1], 0, Major))
93 return std::nullopt;
94
95 unsigned long long Minor = 0;
96 if (Parts[2] == "x" && Kind == Triple::EnvironmentType::Library)
97 Minor = OfflineLibMinor;
98 else if (llvm::getAsUnsignedInteger(Parts[2], 0, Minor))
99 return std::nullopt;
100
101 // Determine DXIL version using the minor version number of Shader
102 // Model version specified in target profile. Prior to decoupling DXIL version
103 // numbering from that of Shader Model DXIL version 1.Y corresponds to SM 6.Y.
104 // E.g., dxilv1.Y-unknown-shadermodelX.Y-hull
105 llvm::Triple T;
106 Triple::SubArchType SubArch = llvm::Triple::NoSubArch;
107 switch (Minor) {
108 case 0:
109 SubArch = llvm::Triple::DXILSubArch_v1_0;
110 break;
111 case 1:
112 SubArch = llvm::Triple::DXILSubArch_v1_1;
113 break;
114 case 2:
115 SubArch = llvm::Triple::DXILSubArch_v1_2;
116 break;
117 case 3:
118 SubArch = llvm::Triple::DXILSubArch_v1_3;
119 break;
120 case 4:
121 SubArch = llvm::Triple::DXILSubArch_v1_4;
122 break;
123 case 5:
124 SubArch = llvm::Triple::DXILSubArch_v1_5;
125 break;
126 case 6:
127 SubArch = llvm::Triple::DXILSubArch_v1_6;
128 break;
129 case 7:
130 SubArch = llvm::Triple::DXILSubArch_v1_7;
131 break;
132 case 8:
133 SubArch = llvm::Triple::DXILSubArch_v1_8;
134 break;
135 case OfflineLibMinor:
136 // Always consider minor version x as the latest supported DXIL version
137 SubArch = llvm::Triple::LatestDXILSubArch;
138 break;
139 default:
140 // No DXIL Version corresponding to specified Shader Model version found
141 return std::nullopt;
142 }
143 T.setArch(Triple::ArchType::dxil, SubArch);
144 T.setOSName(Triple::getOSTypeName(Triple::OSType::ShaderModel).str() +
145 VersionTuple(Major, Minor).getAsString());
146 T.setEnvironment(Kind);
147 if (isLegalShaderModel(T))
148 return T.getTriple();
149 else
150 return std::nullopt;
151 }
152
isLegalValidatorVersion(StringRef ValVersionStr,const Driver & D)153 bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) {
154 VersionTuple Version;
155 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
156 Version.getSubminor() || !Version.getMinor()) {
157 D.Diag(diag::err_drv_invalid_format_dxil_validator_version)
158 << ValVersionStr;
159 return false;
160 }
161
162 uint64_t Major = Version.getMajor();
163 uint64_t Minor = *Version.getMinor();
164 if (Major == 0 && Minor != 0) {
165 D.Diag(diag::err_drv_invalid_empty_dxil_validator_version) << ValVersionStr;
166 return false;
167 }
168 VersionTuple MinVer(1, 0);
169 if (Version < MinVer) {
170 D.Diag(diag::err_drv_invalid_range_dxil_validator_version) << ValVersionStr;
171 return false;
172 }
173 return true;
174 }
175
getSpirvExtArg(ArrayRef<std::string> SpvExtensionArgs)176 std::string getSpirvExtArg(ArrayRef<std::string> SpvExtensionArgs) {
177 if (SpvExtensionArgs.empty()) {
178 return "-spirv-ext=all";
179 }
180
181 std::string LlvmOption =
182 (Twine("-spirv-ext=+") + SpvExtensionArgs.front()).str();
183 SpvExtensionArgs = SpvExtensionArgs.slice(1);
184 for (auto Extension : SpvExtensionArgs) {
185 if (Extension != "KHR")
186 Extension = (Twine("+") + Extension).str();
187 LlvmOption = (Twine(LlvmOption) + "," + Extension).str();
188 }
189 return LlvmOption;
190 }
191
/// Returns true if \p Str is either the special token `KHR` or a name of the
/// form `SPV_[a-zA-Z0-9_]+`.
///
/// This is a hand-written equivalent of a full match against the regex
/// `KHR|SPV_[a-zA-Z0-9_]+`. It avoids constructing a std::regex on every call.
bool isValidSPIRVExtensionName(const std::string &Str) {
  if (Str == "KHR")
    return true;

  static constexpr char Prefix[] = "SPV_";
  constexpr std::string::size_type PrefixLen = sizeof(Prefix) - 1;
  // At least one character must follow the prefix.
  if (Str.size() <= PrefixLen || Str.compare(0, PrefixLen, Prefix) != 0)
    return false;

  return Str.find_first_not_of("abcdefghijklmnopqrstuvwxyz"
                               "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                               "0123456789_",
                               PrefixLen) == std::string::npos;
}
196
197 // SPIRV extension names are of the form `SPV_[a-zA-Z0-9_]+`. We want to
198 // disallow obviously invalid names to avoid issues when parsing `spirv-ext`.
checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs,const Driver & Driver)199 bool checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs,
200 const Driver &Driver) {
201 bool AllValid = true;
202 for (auto Extension : SpvExtensionArgs) {
203 if (!isValidSPIRVExtensionName(Extension)) {
204 Driver.Diag(diag::err_drv_invalid_value)
205 << "-fspv-extension" << Extension;
206 AllValid = false;
207 }
208 }
209 return AllValid;
210 }
211 } // namespace
212
ConstructJob(Compilation & C,const JobAction & JA,const InputInfo & Output,const InputInfoList & Inputs,const ArgList & Args,const char * LinkingOutput) const213 void tools::hlsl::Validator::ConstructJob(Compilation &C, const JobAction &JA,
214 const InputInfo &Output,
215 const InputInfoList &Inputs,
216 const ArgList &Args,
217 const char *LinkingOutput) const {
218 std::string DxvPath = getToolChain().GetProgramPath("dxv");
219 assert(DxvPath != "dxv" && "cannot find dxv");
220
221 ArgStringList CmdArgs;
222 assert(Inputs.size() == 1 && "Unable to handle multiple inputs.");
223 const InputInfo &Input = Inputs[0];
224 CmdArgs.push_back(Input.getFilename());
225 CmdArgs.push_back("-o");
226 CmdArgs.push_back(Output.getFilename());
227
228 const char *Exec = Args.MakeArgString(DxvPath);
229 C.addCommand(std::make_unique<Command>(JA, *this, ResponseFileSupport::None(),
230 Exec, CmdArgs, Inputs, Input));
231 }
232
ConstructJob(Compilation & C,const JobAction & JA,const InputInfo & Output,const InputInfoList & Inputs,const ArgList & Args,const char * LinkingOutput) const233 void tools::hlsl::MetalConverter::ConstructJob(
234 Compilation &C, const JobAction &JA, const InputInfo &Output,
235 const InputInfoList &Inputs, const ArgList &Args,
236 const char *LinkingOutput) const {
237 std::string MSCPath = getToolChain().GetProgramPath("metal-shaderconverter");
238 ArgStringList CmdArgs;
239 assert(Inputs.size() == 1 && "Unable to handle multiple inputs.");
240 const InputInfo &Input = Inputs[0];
241 CmdArgs.push_back(Input.getFilename());
242 CmdArgs.push_back("-o");
243 CmdArgs.push_back(Output.getFilename());
244
245 const char *Exec = Args.MakeArgString(MSCPath);
246 C.addCommand(std::make_unique<Command>(JA, *this, ResponseFileSupport::None(),
247 Exec, CmdArgs, Inputs, Input));
248 }
249
250 /// DirectX Toolchain
HLSLToolChain(const Driver & D,const llvm::Triple & Triple,const ArgList & Args)251 HLSLToolChain::HLSLToolChain(const Driver &D, const llvm::Triple &Triple,
252 const ArgList &Args)
253 : ToolChain(D, Triple, Args) {
254 if (Args.hasArg(options::OPT_dxc_validator_path_EQ))
255 getProgramPaths().push_back(
256 Args.getLastArgValue(options::OPT_dxc_validator_path_EQ).str());
257 }
258
getTool(Action::ActionClass AC) const259 Tool *clang::driver::toolchains::HLSLToolChain::getTool(
260 Action::ActionClass AC) const {
261 switch (AC) {
262 case Action::BinaryAnalyzeJobClass:
263 if (!Validator)
264 Validator.reset(new tools::hlsl::Validator(*this));
265 return Validator.get();
266 case Action::BinaryTranslatorJobClass:
267 if (!MetalConverter)
268 MetalConverter.reset(new tools::hlsl::MetalConverter(*this));
269 return MetalConverter.get();
270 default:
271 return ToolChain::getTool(AC);
272 }
273 }
274
275 std::optional<std::string>
parseTargetProfile(StringRef TargetProfile)276 clang::driver::toolchains::HLSLToolChain::parseTargetProfile(
277 StringRef TargetProfile) {
278 return tryParseProfile(TargetProfile);
279 }
280
281 DerivedArgList *
TranslateArgs(const DerivedArgList & Args,StringRef BoundArch,Action::OffloadKind DeviceOffloadKind) const282 HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
283 Action::OffloadKind DeviceOffloadKind) const {
284 DerivedArgList *DAL = new DerivedArgList(Args.getBaseArgs());
285
286 const OptTable &Opts = getDriver().getOpts();
287
288 for (Arg *A : Args) {
289 if (A->getOption().getID() == options::OPT_dxil_validator_version) {
290 StringRef ValVerStr = A->getValue();
291 if (!isLegalValidatorVersion(ValVerStr, getDriver()))
292 continue;
293 }
294 if (A->getOption().getID() == options::OPT_dxc_entrypoint) {
295 DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_hlsl_entrypoint),
296 A->getValue());
297 A->claim();
298 continue;
299 }
300 if (A->getOption().getID() == options::OPT_dxc_rootsig_ver) {
301 DAL->AddJoinedArg(nullptr,
302 Opts.getOption(options::OPT_fdx_rootsignature_version),
303 A->getValue());
304 A->claim();
305 continue;
306 }
307 if (A->getOption().getID() == options::OPT__SLASH_O) {
308 StringRef OStr = A->getValue();
309 if (OStr == "d") {
310 DAL->AddFlagArg(nullptr, Opts.getOption(options::OPT_O0));
311 A->claim();
312 continue;
313 } else {
314 DAL->AddJoinedArg(nullptr, Opts.getOption(options::OPT_O), OStr);
315 A->claim();
316 continue;
317 }
318 }
319 if (A->getOption().getID() == options::OPT_emit_pristine_llvm) {
320 // Translate -fcgl into -emit-llvm and -disable-llvm-passes.
321 DAL->AddFlagArg(nullptr, Opts.getOption(options::OPT_emit_llvm));
322 DAL->AddFlagArg(nullptr,
323 Opts.getOption(options::OPT_disable_llvm_passes));
324 A->claim();
325 continue;
326 }
327 if (A->getOption().getID() == options::OPT_dxc_hlsl_version) {
328 // Translate -HV into -std for llvm
329 // depending on the value given
330 LangStandard::Kind LangStd = LangStandard::getHLSLLangKind(A->getValue());
331 if (LangStd != LangStandard::lang_unspecified) {
332 LangStandard l = LangStandard::getLangStandardForKind(LangStd);
333 DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_std_EQ),
334 l.getName());
335 } else {
336 getDriver().Diag(diag::err_drv_invalid_value) << "HV" << A->getValue();
337 }
338
339 A->claim();
340 continue;
341 }
342 if (A->getOption().getID() == options::OPT_dxc_gis) {
343 // Translate -Gis into -ffp_model_EQ=strict
344 DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_ffp_model_EQ),
345 "strict");
346 A->claim();
347 continue;
348 }
349 if (A->getOption().getID() == options::OPT_fvk_use_dx_layout) {
350 // This is the only implemented layout so far.
351 A->claim();
352 continue;
353 }
354
355 if (A->getOption().getID() == options::OPT_fvk_use_scalar_layout) {
356 getDriver().Diag(diag::err_drv_clang_unsupported) << A->getAsString(Args);
357 A->claim();
358 continue;
359 }
360
361 if (A->getOption().getID() == options::OPT_fvk_use_gl_layout) {
362 getDriver().Diag(diag::err_drv_clang_unsupported) << A->getAsString(Args);
363 A->claim();
364 continue;
365 }
366
367 DAL->append(A);
368 }
369
370 if (getArch() == llvm::Triple::spirv) {
371 std::vector<std::string> SpvExtensionArgs =
372 Args.getAllArgValues(options::OPT_fspv_extension_EQ);
373 if (checkExtensionArgsAreValid(SpvExtensionArgs, getDriver())) {
374 std::string LlvmOption = getSpirvExtArg(SpvExtensionArgs);
375 DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_mllvm),
376 LlvmOption);
377 }
378 Args.claimAllArgs(options::OPT_fspv_extension_EQ);
379 }
380
381 if (!DAL->hasArg(options::OPT_O_Group)) {
382 DAL->AddJoinedArg(nullptr, Opts.getOption(options::OPT_O), "3");
383 }
384
385 return DAL;
386 }
387
requiresValidation(DerivedArgList & Args) const388 bool HLSLToolChain::requiresValidation(DerivedArgList &Args) const {
389 if (!Args.hasArg(options::OPT_dxc_Fo))
390 return false;
391
392 if (Args.getLastArg(options::OPT_dxc_disable_validation))
393 return false;
394
395 std::string DxvPath = GetProgramPath("dxv");
396 if (DxvPath != "dxv")
397 return true;
398
399 getDriver().Diag(diag::warn_drv_dxc_missing_dxv);
400 return false;
401 }
402
requiresBinaryTranslation(DerivedArgList & Args) const403 bool HLSLToolChain::requiresBinaryTranslation(DerivedArgList &Args) const {
404 return Args.hasArg(options::OPT_metal) && Args.hasArg(options::OPT_dxc_Fo);
405 }
406
isLastJob(DerivedArgList & Args,Action::ActionClass AC) const407 bool HLSLToolChain::isLastJob(DerivedArgList &Args,
408 Action::ActionClass AC) const {
409 bool HasTranslation = requiresBinaryTranslation(Args);
410 bool HasValidation = requiresValidation(Args);
411 // If translation and validation are not required, we should only have one
412 // action.
413 if (!HasTranslation && !HasValidation)
414 return true;
415 if ((HasTranslation && AC == Action::BinaryTranslatorJobClass) ||
416 (!HasTranslation && HasValidation && AC == Action::BinaryAnalyzeJobClass))
417 return true;
418 return false;
419 }
420