xref: /freebsd/contrib/llvm-project/clang/lib/Driver/ToolChains/ROCm.h (revision 6132212808e8dccedc9e5d85fea4390c2f38059a)
1 //===--- ROCm.h - ROCm installation detector --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
10 #define LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
11 
12 #include "clang/Basic/Cuda.h"
13 #include "clang/Basic/LLVM.h"
14 #include "clang/Driver/Driver.h"
15 #include "clang/Driver/Options.h"
16 #include "llvm/ADT/SmallSet.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/StringMap.h"
19 #include "llvm/ADT/Triple.h"
20 #include "llvm/Option/ArgList.h"
21 #include "llvm/Support/VersionTuple.h"
22 
23 namespace clang {
24 namespace driver {
25 
26 /// A class to find a viable ROCM installation
27 /// TODO: Generalize to handle libclc.
28 class RocmInstallationDetector {
29 private:
30   struct ConditionalLibrary {
31     SmallString<0> On;
32     SmallString<0> Off;
33 
34     bool isValid() const { return !On.empty() && !Off.empty(); }
35 
36     StringRef get(bool Enabled) const {
37       assert(isValid());
38       return Enabled ? On : Off;
39     }
40   };
41 
42   // Installation path candidate.
43   struct Candidate {
44     llvm::SmallString<0> Path;
45     bool StrictChecking;
46 
47     Candidate(std::string Path, bool StrictChecking = false)
48         : Path(Path), StrictChecking(StrictChecking) {}
49   };
50 
51   const Driver &D;
52   bool HasHIPRuntime = false;
53   bool HasDeviceLibrary = false;
54 
55   // Default version if not detected or specified.
56   const unsigned DefaultVersionMajor = 3;
57   const unsigned DefaultVersionMinor = 5;
58   const char *DefaultVersionPatch = "0";
59 
60   // The version string in Major.Minor.Patch format.
61   std::string DetectedVersion;
62   // Version containing major and minor.
63   llvm::VersionTuple VersionMajorMinor;
64   // Version containing patch.
65   std::string VersionPatch;
66 
67   // ROCm path specified by --rocm-path.
68   StringRef RocmPathArg;
69   // ROCm device library paths specified by --rocm-device-lib-path.
70   std::vector<std::string> RocmDeviceLibPathArg;
71   // HIP version specified by --hip-version.
72   StringRef HIPVersionArg;
73   // Wheter -nogpulib is specified.
74   bool NoBuiltinLibs = false;
75 
76   // Paths
77   SmallString<0> InstallPath;
78   SmallString<0> BinPath;
79   SmallString<0> LibPath;
80   SmallString<0> LibDevicePath;
81   SmallString<0> IncludePath;
82   llvm::StringMap<std::string> LibDeviceMap;
83 
84   // Libraries that are always linked.
85   SmallString<0> OCML;
86   SmallString<0> OCKL;
87 
88   // Libraries that are always linked depending on the language
89   SmallString<0> OpenCL;
90   SmallString<0> HIP;
91 
92   // Libraries swapped based on compile flags.
93   ConditionalLibrary WavefrontSize64;
94   ConditionalLibrary FiniteOnly;
95   ConditionalLibrary UnsafeMath;
96   ConditionalLibrary DenormalsAreZero;
97   ConditionalLibrary CorrectlyRoundedSqrt;
98 
99   bool allGenericLibsValid() const {
100     return !OCML.empty() && !OCKL.empty() && !OpenCL.empty() && !HIP.empty() &&
101            WavefrontSize64.isValid() && FiniteOnly.isValid() &&
102            UnsafeMath.isValid() && DenormalsAreZero.isValid() &&
103            CorrectlyRoundedSqrt.isValid();
104   }
105 
106   // GPU architectures for which we have raised an error in
107   // CheckRocmVersionSupportsArch.
108   mutable llvm::SmallSet<CudaArch, 4> ArchsWithBadVersion;
109 
110   void scanLibDevicePath(llvm::StringRef Path);
111   void ParseHIPVersionFile(llvm::StringRef V);
112   SmallVector<Candidate, 4> getInstallationPathCandidates();
113 
114 public:
115   RocmInstallationDetector(const Driver &D, const llvm::Triple &HostTriple,
116                            const llvm::opt::ArgList &Args,
117                            bool DetectHIPRuntime = true,
118                            bool DetectDeviceLib = false);
119 
120   /// Add arguments needed to link default bitcode libraries.
121   void addCommonBitcodeLibCC1Args(const llvm::opt::ArgList &DriverArgs,
122                                   llvm::opt::ArgStringList &CC1Args,
123                                   StringRef LibDeviceFile, bool Wave64,
124                                   bool DAZ, bool FiniteOnly, bool UnsafeMathOpt,
125                                   bool FastRelaxedMath, bool CorrectSqrt) const;
126 
127   /// Emit an error if Version does not support the given Arch.
128   ///
129   /// If either Version or Arch is unknown, does not emit an error.  Emits at
130   /// most one error per Arch.
131   void CheckRocmVersionSupportsArch(CudaArch Arch) const;
132 
133   /// Check whether we detected a valid HIP runtime.
134   bool hasHIPRuntime() const { return HasHIPRuntime; }
135 
136   /// Check whether we detected a valid ROCm device library.
137   bool hasDeviceLibrary() const { return HasDeviceLibrary; }
138 
139   /// Print information about the detected ROCm installation.
140   void print(raw_ostream &OS) const;
141 
142   /// Get the detected Rocm install's version.
143   // RocmVersion version() const { return Version; }
144 
145   /// Get the detected Rocm installation path.
146   StringRef getInstallPath() const { return InstallPath; }
147 
148   /// Get the detected path to Rocm's bin directory.
149   // StringRef getBinPath() const { return BinPath; }
150 
151   /// Get the detected Rocm Include path.
152   StringRef getIncludePath() const { return IncludePath; }
153 
154   /// Get the detected Rocm library path.
155   StringRef getLibPath() const { return LibPath; }
156 
157   /// Get the detected Rocm device library path.
158   StringRef getLibDevicePath() const { return LibDevicePath; }
159 
160   StringRef getOCMLPath() const {
161     assert(!OCML.empty());
162     return OCML;
163   }
164 
165   StringRef getOCKLPath() const {
166     assert(!OCKL.empty());
167     return OCKL;
168   }
169 
170   StringRef getOpenCLPath() const {
171     assert(!OpenCL.empty());
172     return OpenCL;
173   }
174 
175   StringRef getHIPPath() const {
176     assert(!HIP.empty());
177     return HIP;
178   }
179 
180   StringRef getWavefrontSize64Path(bool Enabled) const {
181     return WavefrontSize64.get(Enabled);
182   }
183 
184   StringRef getFiniteOnlyPath(bool Enabled) const {
185     return FiniteOnly.get(Enabled);
186   }
187 
188   StringRef getUnsafeMathPath(bool Enabled) const {
189     return UnsafeMath.get(Enabled);
190   }
191 
192   StringRef getDenormalsAreZeroPath(bool Enabled) const {
193     return DenormalsAreZero.get(Enabled);
194   }
195 
196   StringRef getCorrectlyRoundedSqrtPath(bool Enabled) const {
197     return CorrectlyRoundedSqrt.get(Enabled);
198   }
199 
200   /// Get libdevice file for given architecture
201   std::string getLibDeviceFile(StringRef Gpu) const {
202     return LibDeviceMap.lookup(Gpu);
203   }
204 
205   void AddHIPIncludeArgs(const llvm::opt::ArgList &DriverArgs,
206                          llvm::opt::ArgStringList &CC1Args) const;
207 
208   void detectDeviceLibrary();
209   void detectHIPRuntime();
210 
211   /// Get the values for --rocm-device-lib-path arguments
212   std::vector<std::string> getRocmDeviceLibPathArg() const {
213     return RocmDeviceLibPathArg;
214   }
215 
216   /// Get the value for --rocm-path argument
217   StringRef getRocmPathArg() const { return RocmPathArg; }
218 
219   /// Get the value for --hip-version argument
220   StringRef getHIPVersionArg() const { return HIPVersionArg; }
221 
222   std::string getHIPVersion() const { return DetectedVersion; }
223 };
224 
225 } // end namespace driver
226 } // end namespace clang
227 
228 #endif // LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
229