xref: /freebsd/contrib/llvm-project/clang/lib/Driver/OffloadBundler.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- OffloadBundler.cpp - File Bundling and Unbundling ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements an offload bundling API that bundles different files
11 /// that relate with the same source code but different targets into a single
12 /// one. It also implements the opposite functionality, i.e. unbundling files
13 /// previously created by this API.
14 ///
15 //===----------------------------------------------------------------------===//
16 
17 #include "clang/Driver/OffloadBundler.h"
18 #include "clang/Basic/Cuda.h"
19 #include "clang/Basic/TargetID.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/StringMap.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/BinaryFormat/Magic.h"
27 #include "llvm/Object/Archive.h"
28 #include "llvm/Object/ArchiveWriter.h"
29 #include "llvm/Object/Binary.h"
30 #include "llvm/Object/ObjectFile.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/Compiler.h"
33 #include "llvm/Support/Compression.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/EndianStream.h"
36 #include "llvm/Support/Errc.h"
37 #include "llvm/Support/Error.h"
38 #include "llvm/Support/ErrorOr.h"
39 #include "llvm/Support/FileSystem.h"
40 #include "llvm/Support/MD5.h"
41 #include "llvm/Support/ManagedStatic.h"
42 #include "llvm/Support/MemoryBuffer.h"
43 #include "llvm/Support/Path.h"
44 #include "llvm/Support/Program.h"
45 #include "llvm/Support/Signals.h"
46 #include "llvm/Support/StringSaver.h"
47 #include "llvm/Support/Timer.h"
48 #include "llvm/Support/WithColor.h"
49 #include "llvm/Support/raw_ostream.h"
50 #include "llvm/TargetParser/Host.h"
51 #include "llvm/TargetParser/Triple.h"
52 #include <algorithm>
53 #include <cassert>
54 #include <cstddef>
55 #include <cstdint>
56 #include <forward_list>
57 #include <llvm/Support/Process.h>
58 #include <memory>
59 #include <set>
60 #include <string>
61 #include <system_error>
62 #include <utility>
63 
64 using namespace llvm;
65 using namespace llvm::object;
66 using namespace clang;
67 
68 namespace {
69 struct CreateClangOffloadBundlerTimerGroup {
call__anoncaeeea5c0111::CreateClangOffloadBundlerTimerGroup70   static void *call() {
71     return new TimerGroup("Clang Offload Bundler Timer Group",
72                           "Timer group for clang offload bundler");
73   }
74 };
75 } // namespace
76 static llvm::ManagedStatic<llvm::TimerGroup,
77                            CreateClangOffloadBundlerTimerGroup>
78     ClangOffloadBundlerTimerGroup;
79 
80 /// Magic string that marks the existence of offloading data.
81 #define OFFLOAD_BUNDLER_MAGIC_STR "__CLANG_OFFLOAD_BUNDLE__"
82 
OffloadTargetInfo(const StringRef Target,const OffloadBundlerConfig & BC)83 OffloadTargetInfo::OffloadTargetInfo(const StringRef Target,
84                                      const OffloadBundlerConfig &BC)
85     : BundlerConfig(BC) {
86 
87   // <kind>-<triple>[-<target id>[:target features]]
88   // <triple> := <arch>-<vendor>-<os>-<env>
89   SmallVector<StringRef, 6> Components;
90   Target.split(Components, '-', /*MaxSplit=*/5);
91   assert((Components.size() == 5 || Components.size() == 6) &&
92          "malformed target string");
93 
94   StringRef TargetIdWithFeature =
95       Components.size() == 6 ? Components.back() : "";
96   StringRef TargetId = TargetIdWithFeature.split(':').first;
97   if (!TargetId.empty() &&
98       clang::StringToOffloadArch(TargetId) != clang::OffloadArch::UNKNOWN)
99     this->TargetID = TargetIdWithFeature;
100   else
101     this->TargetID = "";
102 
103   this->OffloadKind = Components.front();
104   ArrayRef<StringRef> TripleSlice{&Components[1], /*length=*/4};
105   llvm::Triple T = llvm::Triple(llvm::join(TripleSlice, "-"));
106   this->Triple = llvm::Triple(T.getArchName(), T.getVendorName(), T.getOSName(),
107                               T.getEnvironmentName());
108 }
109 
hasHostKind() const110 bool OffloadTargetInfo::hasHostKind() const {
111   return this->OffloadKind == "host";
112 }
113 
isOffloadKindValid() const114 bool OffloadTargetInfo::isOffloadKindValid() const {
115   return OffloadKind == "host" || OffloadKind == "openmp" ||
116          OffloadKind == "hip" || OffloadKind == "hipv4";
117 }
118 
isOffloadKindCompatible(const StringRef TargetOffloadKind) const119 bool OffloadTargetInfo::isOffloadKindCompatible(
120     const StringRef TargetOffloadKind) const {
121   if ((OffloadKind == TargetOffloadKind) ||
122       (OffloadKind == "hip" && TargetOffloadKind == "hipv4") ||
123       (OffloadKind == "hipv4" && TargetOffloadKind == "hip"))
124     return true;
125 
126   if (BundlerConfig.HipOpenmpCompatible) {
127     bool HIPCompatibleWithOpenMP = OffloadKind.starts_with_insensitive("hip") &&
128                                    TargetOffloadKind == "openmp";
129     bool OpenMPCompatibleWithHIP =
130         OffloadKind == "openmp" &&
131         TargetOffloadKind.starts_with_insensitive("hip");
132     return HIPCompatibleWithOpenMP || OpenMPCompatibleWithHIP;
133   }
134   return false;
135 }
136 
isTripleValid() const137 bool OffloadTargetInfo::isTripleValid() const {
138   return !Triple.str().empty() && Triple.getArch() != Triple::UnknownArch;
139 }
140 
operator ==(const OffloadTargetInfo & Target) const141 bool OffloadTargetInfo::operator==(const OffloadTargetInfo &Target) const {
142   return OffloadKind == Target.OffloadKind &&
143          Triple.isCompatibleWith(Target.Triple) && TargetID == Target.TargetID;
144 }
145 
str() const146 std::string OffloadTargetInfo::str() const {
147   std::string NormalizedTriple;
148   // Unfortunately we need some special sauce for AMDHSA because all the runtime
149   // assumes the triple to be "amdgcn/spirv64-amd-amdhsa-" (empty environment)
150   // instead of "amdgcn/spirv64-amd-amdhsa-unknown". It's gonna be very tricky
151   // to patch different layers of runtime.
152   if (Triple.getOS() == Triple::OSType::AMDHSA) {
153     NormalizedTriple = Triple.normalize(Triple::CanonicalForm::THREE_IDENT);
154     NormalizedTriple.push_back('-');
155   } else {
156     NormalizedTriple = Triple.normalize(Triple::CanonicalForm::FOUR_IDENT);
157   }
158   return Twine(OffloadKind + "-" + NormalizedTriple + "-" + TargetID).str();
159 }
160 
getDeviceFileExtension(StringRef Device,StringRef BundleFileName)161 static StringRef getDeviceFileExtension(StringRef Device,
162                                         StringRef BundleFileName) {
163   if (Device.contains("gfx"))
164     return ".bc";
165   if (Device.contains("sm_"))
166     return ".cubin";
167   return sys::path::extension(BundleFileName);
168 }
169 
getDeviceLibraryFileName(StringRef BundleFileName,StringRef Device)170 static std::string getDeviceLibraryFileName(StringRef BundleFileName,
171                                             StringRef Device) {
172   StringRef LibName = sys::path::stem(BundleFileName);
173   StringRef Extension = getDeviceFileExtension(Device, BundleFileName);
174 
175   std::string Result;
176   Result += LibName;
177   Result += Extension;
178   return Result;
179 }
180 
181 namespace {
182 /// Generic file handler interface.
183 class FileHandler {
184 public:
185   struct BundleInfo {
186     StringRef BundleID;
187   };
188 
FileHandler()189   FileHandler() {}
190 
~FileHandler()191   virtual ~FileHandler() {}
192 
193   /// Update the file handler with information from the header of the bundled
194   /// file.
195   virtual Error ReadHeader(MemoryBuffer &Input) = 0;
196 
197   /// Read the marker of the next bundled to be read in the file. The bundle
198   /// name is returned if there is one in the file, or `std::nullopt` if there
199   /// are no more bundles to be read.
200   virtual Expected<std::optional<StringRef>>
201   ReadBundleStart(MemoryBuffer &Input) = 0;
202 
203   /// Read the marker that closes the current bundle.
204   virtual Error ReadBundleEnd(MemoryBuffer &Input) = 0;
205 
206   /// Read the current bundle and write the result into the stream \a OS.
207   virtual Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;
208 
209   /// Write the header of the bundled file to \a OS based on the information
210   /// gathered from \a Inputs.
211   virtual Error WriteHeader(raw_ostream &OS,
212                             ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) = 0;
213 
214   /// Write the marker that initiates a bundle for the triple \a TargetTriple to
215   /// \a OS.
216   virtual Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) = 0;
217 
218   /// Write the marker that closes a bundle for the triple \a TargetTriple to \a
219   /// OS.
220   virtual Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) = 0;
221 
222   /// Write the bundle from \a Input into \a OS.
223   virtual Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;
224 
225   /// Finalize output file.
finalizeOutputFile()226   virtual Error finalizeOutputFile() { return Error::success(); }
227 
228   /// List bundle IDs in \a Input.
listBundleIDs(MemoryBuffer & Input)229   virtual Error listBundleIDs(MemoryBuffer &Input) {
230     if (Error Err = ReadHeader(Input))
231       return Err;
232     return forEachBundle(Input, [&](const BundleInfo &Info) -> Error {
233       llvm::outs() << Info.BundleID << '\n';
234       Error Err = listBundleIDsCallback(Input, Info);
235       if (Err)
236         return Err;
237       return Error::success();
238     });
239   }
240 
241   /// Get bundle IDs in \a Input in \a BundleIds.
getBundleIDs(MemoryBuffer & Input,std::set<StringRef> & BundleIds)242   virtual Error getBundleIDs(MemoryBuffer &Input,
243                              std::set<StringRef> &BundleIds) {
244     if (Error Err = ReadHeader(Input))
245       return Err;
246     return forEachBundle(Input, [&](const BundleInfo &Info) -> Error {
247       BundleIds.insert(Info.BundleID);
248       Error Err = listBundleIDsCallback(Input, Info);
249       if (Err)
250         return Err;
251       return Error::success();
252     });
253   }
254 
255   /// For each bundle in \a Input, do \a Func.
forEachBundle(MemoryBuffer & Input,std::function<Error (const BundleInfo &)> Func)256   Error forEachBundle(MemoryBuffer &Input,
257                       std::function<Error(const BundleInfo &)> Func) {
258     while (true) {
259       Expected<std::optional<StringRef>> CurTripleOrErr =
260           ReadBundleStart(Input);
261       if (!CurTripleOrErr)
262         return CurTripleOrErr.takeError();
263 
264       // No more bundles.
265       if (!*CurTripleOrErr)
266         break;
267 
268       StringRef CurTriple = **CurTripleOrErr;
269       assert(!CurTriple.empty());
270 
271       BundleInfo Info{CurTriple};
272       if (Error Err = Func(Info))
273         return Err;
274     }
275     return Error::success();
276   }
277 
278 protected:
listBundleIDsCallback(MemoryBuffer & Input,const BundleInfo & Info)279   virtual Error listBundleIDsCallback(MemoryBuffer &Input,
280                                       const BundleInfo &Info) {
281     return Error::success();
282   }
283 };
284 
285 /// Handler for binary files. The bundled file will have the following format
286 /// (all integers are stored in little-endian format):
287 ///
288 /// "OFFLOAD_BUNDLER_MAGIC_STR" (ASCII encoding of the string)
289 ///
290 /// NumberOfOffloadBundles (8-byte integer)
291 ///
292 /// OffsetOfBundle1 (8-byte integer)
293 /// SizeOfBundle1 (8-byte integer)
294 /// NumberOfBytesInTripleOfBundle1 (8-byte integer)
295 /// TripleOfBundle1 (byte length defined before)
296 ///
297 /// ...
298 ///
299 /// OffsetOfBundleN (8-byte integer)
300 /// SizeOfBundleN (8-byte integer)
301 /// NumberOfBytesInTripleOfBundleN (8-byte integer)
302 /// TripleOfBundleN (byte length defined before)
303 ///
304 /// Bundle1
305 /// ...
306 /// BundleN
307 
308 /// Read 8-byte integers from a buffer in little-endian format.
Read8byteIntegerFromBuffer(StringRef Buffer,size_t pos)309 static uint64_t Read8byteIntegerFromBuffer(StringRef Buffer, size_t pos) {
310   return llvm::support::endian::read64le(Buffer.data() + pos);
311 }
312 
313 /// Write 8-byte integers to a buffer in little-endian format.
Write8byteIntegerToBuffer(raw_ostream & OS,uint64_t Val)314 static void Write8byteIntegerToBuffer(raw_ostream &OS, uint64_t Val) {
315   llvm::support::endian::write(OS, Val, llvm::endianness::little);
316 }
317 
318 class BinaryFileHandler final : public FileHandler {
319   /// Information about the bundles extracted from the header.
320   struct BinaryBundleInfo final : public BundleInfo {
321     /// Size of the bundle.
322     uint64_t Size = 0u;
323     /// Offset at which the bundle starts in the bundled file.
324     uint64_t Offset = 0u;
325 
BinaryBundleInfo__anoncaeeea5c0211::BinaryFileHandler::BinaryBundleInfo326     BinaryBundleInfo() {}
BinaryBundleInfo__anoncaeeea5c0211::BinaryFileHandler::BinaryBundleInfo327     BinaryBundleInfo(uint64_t Size, uint64_t Offset)
328         : Size(Size), Offset(Offset) {}
329   };
330 
331   /// Map between a triple and the corresponding bundle information.
332   StringMap<BinaryBundleInfo> BundlesInfo;
333 
334   /// Iterator for the bundle information that is being read.
335   StringMap<BinaryBundleInfo>::iterator CurBundleInfo;
336   StringMap<BinaryBundleInfo>::iterator NextBundleInfo;
337 
338   /// Current bundle target to be written.
339   std::string CurWriteBundleTarget;
340 
341   /// Configuration options and arrays for this bundler job
342   const OffloadBundlerConfig &BundlerConfig;
343 
344 public:
345   // TODO: Add error checking from ClangOffloadBundler.cpp
BinaryFileHandler(const OffloadBundlerConfig & BC)346   BinaryFileHandler(const OffloadBundlerConfig &BC) : BundlerConfig(BC) {}
347 
~BinaryFileHandler()348   ~BinaryFileHandler() final {}
349 
ReadHeader(MemoryBuffer & Input)350   Error ReadHeader(MemoryBuffer &Input) final {
351     StringRef FC = Input.getBuffer();
352 
353     // Initialize the current bundle with the end of the container.
354     CurBundleInfo = BundlesInfo.end();
355 
356     // Check if buffer is smaller than magic string.
357     size_t ReadChars = sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1;
358     if (ReadChars > FC.size())
359       return Error::success();
360 
361     // Check if no magic was found.
362     if (llvm::identify_magic(FC) != llvm::file_magic::offload_bundle)
363       return Error::success();
364 
365     // Read number of bundles.
366     if (ReadChars + 8 > FC.size())
367       return Error::success();
368 
369     uint64_t NumberOfBundles = Read8byteIntegerFromBuffer(FC, ReadChars);
370     ReadChars += 8;
371 
372     // Read bundle offsets, sizes and triples.
373     for (uint64_t i = 0; i < NumberOfBundles; ++i) {
374 
375       // Read offset.
376       if (ReadChars + 8 > FC.size())
377         return Error::success();
378 
379       uint64_t Offset = Read8byteIntegerFromBuffer(FC, ReadChars);
380       ReadChars += 8;
381 
382       // Read size.
383       if (ReadChars + 8 > FC.size())
384         return Error::success();
385 
386       uint64_t Size = Read8byteIntegerFromBuffer(FC, ReadChars);
387       ReadChars += 8;
388 
389       // Read triple size.
390       if (ReadChars + 8 > FC.size())
391         return Error::success();
392 
393       uint64_t TripleSize = Read8byteIntegerFromBuffer(FC, ReadChars);
394       ReadChars += 8;
395 
396       // Read triple.
397       if (ReadChars + TripleSize > FC.size())
398         return Error::success();
399 
400       StringRef Triple(&FC.data()[ReadChars], TripleSize);
401       ReadChars += TripleSize;
402 
403       // Check if the offset and size make sense.
404       if (!Offset || Offset + Size > FC.size())
405         return Error::success();
406 
407       assert(!BundlesInfo.contains(Triple) && "Triple is duplicated??");
408       BundlesInfo[Triple] = BinaryBundleInfo(Size, Offset);
409     }
410     // Set the iterator to where we will start to read.
411     CurBundleInfo = BundlesInfo.end();
412     NextBundleInfo = BundlesInfo.begin();
413     return Error::success();
414   }
415 
416   Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer & Input)417   ReadBundleStart(MemoryBuffer &Input) final {
418     if (NextBundleInfo == BundlesInfo.end())
419       return std::nullopt;
420     CurBundleInfo = NextBundleInfo++;
421     return CurBundleInfo->first();
422   }
423 
ReadBundleEnd(MemoryBuffer & Input)424   Error ReadBundleEnd(MemoryBuffer &Input) final {
425     assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
426     return Error::success();
427   }
428 
ReadBundle(raw_ostream & OS,MemoryBuffer & Input)429   Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
430     assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
431     StringRef FC = Input.getBuffer();
432     OS.write(FC.data() + CurBundleInfo->second.Offset,
433              CurBundleInfo->second.Size);
434     return Error::success();
435   }
436 
WriteHeader(raw_ostream & OS,ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs)437   Error WriteHeader(raw_ostream &OS,
438                     ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
439 
440     // Compute size of the header.
441     uint64_t HeaderSize = 0;
442 
443     HeaderSize += sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1;
444     HeaderSize += 8; // Number of Bundles
445 
446     for (auto &T : BundlerConfig.TargetNames) {
447       HeaderSize += 3 * 8; // Bundle offset, Size of bundle and size of triple.
448       HeaderSize += T.size(); // The triple.
449     }
450 
451     // Write to the buffer the header.
452     OS << OFFLOAD_BUNDLER_MAGIC_STR;
453 
454     Write8byteIntegerToBuffer(OS, BundlerConfig.TargetNames.size());
455 
456     unsigned Idx = 0;
457     for (auto &T : BundlerConfig.TargetNames) {
458       MemoryBuffer &MB = *Inputs[Idx++];
459       HeaderSize = alignTo(HeaderSize, BundlerConfig.BundleAlignment);
460       // Bundle offset.
461       Write8byteIntegerToBuffer(OS, HeaderSize);
462       // Size of the bundle (adds to the next bundle's offset)
463       Write8byteIntegerToBuffer(OS, MB.getBufferSize());
464       BundlesInfo[T] = BinaryBundleInfo(MB.getBufferSize(), HeaderSize);
465       HeaderSize += MB.getBufferSize();
466       // Size of the triple
467       Write8byteIntegerToBuffer(OS, T.size());
468       // Triple
469       OS << T;
470     }
471     return Error::success();
472   }
473 
WriteBundleStart(raw_ostream & OS,StringRef TargetTriple)474   Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
475     CurWriteBundleTarget = TargetTriple.str();
476     return Error::success();
477   }
478 
WriteBundleEnd(raw_ostream & OS,StringRef TargetTriple)479   Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
480     return Error::success();
481   }
482 
WriteBundle(raw_ostream & OS,MemoryBuffer & Input)483   Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
484     auto BI = BundlesInfo[CurWriteBundleTarget];
485 
486     // Pad with 0 to reach specified offset.
487     size_t CurrentPos = OS.tell();
488     size_t PaddingSize = BI.Offset > CurrentPos ? BI.Offset - CurrentPos : 0;
489     for (size_t I = 0; I < PaddingSize; ++I)
490       OS.write('\0');
491     assert(OS.tell() == BI.Offset);
492 
493     OS.write(Input.getBufferStart(), Input.getBufferSize());
494 
495     return Error::success();
496   }
497 };
498 
499 // This class implements a list of temporary files that are removed upon
500 // object destruction.
501 class TempFileHandlerRAII {
502 public:
~TempFileHandlerRAII()503   ~TempFileHandlerRAII() {
504     for (const auto &File : Files)
505       sys::fs::remove(File);
506   }
507 
508   // Creates temporary file with given contents.
Create(std::optional<ArrayRef<char>> Contents)509   Expected<StringRef> Create(std::optional<ArrayRef<char>> Contents) {
510     SmallString<128u> File;
511     if (std::error_code EC =
512             sys::fs::createTemporaryFile("clang-offload-bundler", "tmp", File))
513       return createFileError(File, EC);
514     Files.push_front(File);
515 
516     if (Contents) {
517       std::error_code EC;
518       raw_fd_ostream OS(File, EC);
519       if (EC)
520         return createFileError(File, EC);
521       OS.write(Contents->data(), Contents->size());
522     }
523     return Files.front().str();
524   }
525 
526 private:
527   std::forward_list<SmallString<128u>> Files;
528 };
529 
530 /// Handler for object files. The bundles are organized by sections with a
531 /// designated name.
532 ///
533 /// To unbundle, we just copy the contents of the designated section.
534 class ObjectFileHandler final : public FileHandler {
535 
536   /// The object file we are currently dealing with.
537   std::unique_ptr<ObjectFile> Obj;
538 
539   /// Return the input file contents.
getInputFileContents() const540   StringRef getInputFileContents() const { return Obj->getData(); }
541 
542   /// Return bundle name (<kind>-<triple>) if the provided section is an offload
543   /// section.
544   static Expected<std::optional<StringRef>>
IsOffloadSection(SectionRef CurSection)545   IsOffloadSection(SectionRef CurSection) {
546     Expected<StringRef> NameOrErr = CurSection.getName();
547     if (!NameOrErr)
548       return NameOrErr.takeError();
549 
550     // If it does not start with the reserved suffix, just skip this section.
551     if (llvm::identify_magic(*NameOrErr) != llvm::file_magic::offload_bundle)
552       return std::nullopt;
553 
554     // Return the triple that is right after the reserved prefix.
555     return NameOrErr->substr(sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1);
556   }
557 
558   /// Total number of inputs.
559   unsigned NumberOfInputs = 0;
560 
561   /// Total number of processed inputs, i.e, inputs that were already
562   /// read from the buffers.
563   unsigned NumberOfProcessedInputs = 0;
564 
565   /// Iterator of the current and next section.
566   section_iterator CurrentSection;
567   section_iterator NextSection;
568 
569   /// Configuration options and arrays for this bundler job
570   const OffloadBundlerConfig &BundlerConfig;
571 
572 public:
573   // TODO: Add error checking from ClangOffloadBundler.cpp
ObjectFileHandler(std::unique_ptr<ObjectFile> ObjIn,const OffloadBundlerConfig & BC)574   ObjectFileHandler(std::unique_ptr<ObjectFile> ObjIn,
575                     const OffloadBundlerConfig &BC)
576       : Obj(std::move(ObjIn)), CurrentSection(Obj->section_begin()),
577         NextSection(Obj->section_begin()), BundlerConfig(BC) {}
578 
~ObjectFileHandler()579   ~ObjectFileHandler() final {}
580 
ReadHeader(MemoryBuffer & Input)581   Error ReadHeader(MemoryBuffer &Input) final { return Error::success(); }
582 
583   Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer & Input)584   ReadBundleStart(MemoryBuffer &Input) final {
585     while (NextSection != Obj->section_end()) {
586       CurrentSection = NextSection;
587       ++NextSection;
588 
589       // Check if the current section name starts with the reserved prefix. If
590       // so, return the triple.
591       Expected<std::optional<StringRef>> TripleOrErr =
592           IsOffloadSection(*CurrentSection);
593       if (!TripleOrErr)
594         return TripleOrErr.takeError();
595       if (*TripleOrErr)
596         return **TripleOrErr;
597     }
598     return std::nullopt;
599   }
600 
ReadBundleEnd(MemoryBuffer & Input)601   Error ReadBundleEnd(MemoryBuffer &Input) final { return Error::success(); }
602 
ReadBundle(raw_ostream & OS,MemoryBuffer & Input)603   Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
604     Expected<StringRef> ContentOrErr = CurrentSection->getContents();
605     if (!ContentOrErr)
606       return ContentOrErr.takeError();
607     StringRef Content = *ContentOrErr;
608 
609     // Copy fat object contents to the output when extracting host bundle.
610     std::string ModifiedContent;
611     if (Content.size() == 1u && Content.front() == 0) {
612       auto HostBundleOrErr = getHostBundle(
613           StringRef(Input.getBufferStart(), Input.getBufferSize()));
614       if (!HostBundleOrErr)
615         return HostBundleOrErr.takeError();
616 
617       ModifiedContent = std::move(*HostBundleOrErr);
618       Content = ModifiedContent;
619     }
620 
621     OS.write(Content.data(), Content.size());
622     return Error::success();
623   }
624 
WriteHeader(raw_ostream & OS,ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs)625   Error WriteHeader(raw_ostream &OS,
626                     ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
627     assert(BundlerConfig.HostInputIndex != ~0u &&
628            "Host input index not defined.");
629 
630     // Record number of inputs.
631     NumberOfInputs = Inputs.size();
632     return Error::success();
633   }
634 
WriteBundleStart(raw_ostream & OS,StringRef TargetTriple)635   Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
636     ++NumberOfProcessedInputs;
637     return Error::success();
638   }
639 
WriteBundleEnd(raw_ostream & OS,StringRef TargetTriple)640   Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
641     return Error::success();
642   }
643 
finalizeOutputFile()644   Error finalizeOutputFile() final {
645     assert(NumberOfProcessedInputs <= NumberOfInputs &&
646            "Processing more inputs that actually exist!");
647     assert(BundlerConfig.HostInputIndex != ~0u &&
648            "Host input index not defined.");
649 
650     // If this is not the last output, we don't have to do anything.
651     if (NumberOfProcessedInputs != NumberOfInputs)
652       return Error::success();
653 
654     // We will use llvm-objcopy to add target objects sections to the output
655     // fat object. These sections should have 'exclude' flag set which tells
656     // link editor to remove them from linker inputs when linking executable or
657     // shared library.
658 
659     assert(BundlerConfig.ObjcopyPath != "" &&
660            "llvm-objcopy path not specified");
661 
662     // Temporary files that need to be removed.
663     TempFileHandlerRAII TempFiles;
664 
665     // Compose llvm-objcopy command line for add target objects' sections with
666     // appropriate flags.
667     BumpPtrAllocator Alloc;
668     StringSaver SS{Alloc};
669     SmallVector<StringRef, 8u> ObjcopyArgs{"llvm-objcopy"};
670 
671     for (unsigned I = 0; I < NumberOfInputs; ++I) {
672       StringRef InputFile = BundlerConfig.InputFileNames[I];
673       if (I == BundlerConfig.HostInputIndex) {
674         // Special handling for the host bundle. We do not need to add a
675         // standard bundle for the host object since we are going to use fat
676         // object as a host object. Therefore use dummy contents (one zero byte)
677         // when creating section for the host bundle.
678         Expected<StringRef> TempFileOrErr = TempFiles.Create(ArrayRef<char>(0));
679         if (!TempFileOrErr)
680           return TempFileOrErr.takeError();
681         InputFile = *TempFileOrErr;
682       }
683 
684       ObjcopyArgs.push_back(
685           SS.save(Twine("--add-section=") + OFFLOAD_BUNDLER_MAGIC_STR +
686                   BundlerConfig.TargetNames[I] + "=" + InputFile));
687       ObjcopyArgs.push_back(
688           SS.save(Twine("--set-section-flags=") + OFFLOAD_BUNDLER_MAGIC_STR +
689                   BundlerConfig.TargetNames[I] + "=readonly,exclude"));
690     }
691     ObjcopyArgs.push_back("--");
692     ObjcopyArgs.push_back(
693         BundlerConfig.InputFileNames[BundlerConfig.HostInputIndex]);
694     ObjcopyArgs.push_back(BundlerConfig.OutputFileNames.front());
695 
696     if (Error Err = executeObjcopy(BundlerConfig.ObjcopyPath, ObjcopyArgs))
697       return Err;
698 
699     return Error::success();
700   }
701 
WriteBundle(raw_ostream & OS,MemoryBuffer & Input)702   Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
703     return Error::success();
704   }
705 
706 private:
executeObjcopy(StringRef Objcopy,ArrayRef<StringRef> Args)707   Error executeObjcopy(StringRef Objcopy, ArrayRef<StringRef> Args) {
708     // If the user asked for the commands to be printed out, we do that
709     // instead of executing it.
710     if (BundlerConfig.PrintExternalCommands) {
711       errs() << "\"" << Objcopy << "\"";
712       for (StringRef Arg : drop_begin(Args, 1))
713         errs() << " \"" << Arg << "\"";
714       errs() << "\n";
715     } else {
716       if (sys::ExecuteAndWait(Objcopy, Args))
717         return createStringError(inconvertibleErrorCode(),
718                                  "'llvm-objcopy' tool failed");
719     }
720     return Error::success();
721   }
722 
getHostBundle(StringRef Input)723   Expected<std::string> getHostBundle(StringRef Input) {
724     TempFileHandlerRAII TempFiles;
725 
726     auto ModifiedObjPathOrErr = TempFiles.Create(std::nullopt);
727     if (!ModifiedObjPathOrErr)
728       return ModifiedObjPathOrErr.takeError();
729     StringRef ModifiedObjPath = *ModifiedObjPathOrErr;
730 
731     BumpPtrAllocator Alloc;
732     StringSaver SS{Alloc};
733     SmallVector<StringRef, 16> ObjcopyArgs{"llvm-objcopy"};
734 
735     ObjcopyArgs.push_back("--regex");
736     ObjcopyArgs.push_back("--remove-section=__CLANG_OFFLOAD_BUNDLE__.*");
737     ObjcopyArgs.push_back("--");
738 
739     StringRef ObjcopyInputFileName;
740     // When unbundling an archive, the content of each object file in the
741     // archive is passed to this function by parameter Input, which is different
742     // from the content of the original input archive file, therefore it needs
743     // to be saved to a temporary file before passed to llvm-objcopy. Otherwise,
744     // Input is the same as the content of the original input file, therefore
745     // temporary file is not needed.
746     if (StringRef(BundlerConfig.FilesType).starts_with("a")) {
747       auto InputFileOrErr = TempFiles.Create(ArrayRef<char>(Input));
748       if (!InputFileOrErr)
749         return InputFileOrErr.takeError();
750       ObjcopyInputFileName = *InputFileOrErr;
751     } else
752       ObjcopyInputFileName = BundlerConfig.InputFileNames.front();
753 
754     ObjcopyArgs.push_back(ObjcopyInputFileName);
755     ObjcopyArgs.push_back(ModifiedObjPath);
756 
757     if (Error Err = executeObjcopy(BundlerConfig.ObjcopyPath, ObjcopyArgs))
758       return std::move(Err);
759 
760     auto BufOrErr = MemoryBuffer::getFile(ModifiedObjPath);
761     if (!BufOrErr)
762       return createStringError(BufOrErr.getError(),
763                                "Failed to read back the modified object file");
764 
765     return BufOrErr->get()->getBuffer().str();
766   }
767 };
768 
769 /// Handler for text files. The bundled file will have the following format.
770 ///
771 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__START__ triple"
772 /// Bundle 1
773 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__END__ triple"
774 /// ...
775 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__START__ triple"
776 /// Bundle N
777 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__END__ triple"
778 class TextFileHandler final : public FileHandler {
779   /// String that begins a line comment.
780   StringRef Comment;
781 
782   /// String that initiates a bundle.
783   std::string BundleStartString;
784 
785   /// String that closes a bundle.
786   std::string BundleEndString;
787 
788   /// Number of chars read from input.
789   size_t ReadChars = 0u;
790 
791 protected:
ReadHeader(MemoryBuffer & Input)792   Error ReadHeader(MemoryBuffer &Input) final { return Error::success(); }
793 
794   Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer & Input)795   ReadBundleStart(MemoryBuffer &Input) final {
796     StringRef FC = Input.getBuffer();
797 
798     // Find start of the bundle.
799     ReadChars = FC.find(BundleStartString, ReadChars);
800     if (ReadChars == FC.npos)
801       return std::nullopt;
802 
803     // Get position of the triple.
804     size_t TripleStart = ReadChars = ReadChars + BundleStartString.size();
805 
806     // Get position that closes the triple.
807     size_t TripleEnd = ReadChars = FC.find("\n", ReadChars);
808     if (TripleEnd == FC.npos)
809       return std::nullopt;
810 
811     // Next time we read after the new line.
812     ++ReadChars;
813 
814     return StringRef(&FC.data()[TripleStart], TripleEnd - TripleStart);
815   }
816 
ReadBundleEnd(MemoryBuffer & Input)817   Error ReadBundleEnd(MemoryBuffer &Input) final {
818     StringRef FC = Input.getBuffer();
819 
820     // Read up to the next new line.
821     assert(FC[ReadChars] == '\n' && "The bundle should end with a new line.");
822 
823     size_t TripleEnd = ReadChars = FC.find("\n", ReadChars + 1);
824     if (TripleEnd != FC.npos)
825       // Next time we read after the new line.
826       ++ReadChars;
827 
828     return Error::success();
829   }
830 
ReadBundle(raw_ostream & OS,MemoryBuffer & Input)831   Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
832     StringRef FC = Input.getBuffer();
833     size_t BundleStart = ReadChars;
834 
835     // Find end of the bundle.
836     size_t BundleEnd = ReadChars = FC.find(BundleEndString, ReadChars);
837 
838     StringRef Bundle(&FC.data()[BundleStart], BundleEnd - BundleStart);
839     OS << Bundle;
840 
841     return Error::success();
842   }
843 
WriteHeader(raw_ostream & OS,ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs)844   Error WriteHeader(raw_ostream &OS,
845                     ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
846     return Error::success();
847   }
848 
WriteBundleStart(raw_ostream & OS,StringRef TargetTriple)849   Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
850     OS << BundleStartString << TargetTriple << "\n";
851     return Error::success();
852   }
853 
WriteBundleEnd(raw_ostream & OS,StringRef TargetTriple)854   Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
855     OS << BundleEndString << TargetTriple << "\n";
856     return Error::success();
857   }
858 
WriteBundle(raw_ostream & OS,MemoryBuffer & Input)859   Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
860     OS << Input.getBuffer();
861     return Error::success();
862   }
863 
864 public:
TextFileHandler(StringRef Comment)865   TextFileHandler(StringRef Comment) : Comment(Comment), ReadChars(0) {
866     BundleStartString =
867         "\n" + Comment.str() + " " OFFLOAD_BUNDLER_MAGIC_STR "__START__ ";
868     BundleEndString =
869         "\n" + Comment.str() + " " OFFLOAD_BUNDLER_MAGIC_STR "__END__ ";
870   }
871 
listBundleIDsCallback(MemoryBuffer & Input,const BundleInfo & Info)872   Error listBundleIDsCallback(MemoryBuffer &Input,
873                               const BundleInfo &Info) final {
874     // TODO: To list bundle IDs in a bundled text file we need to go through
875     // all bundles. The format of bundled text file may need to include a
876     // header if the performance of listing bundle IDs of bundled text file is
877     // important.
878     ReadChars = Input.getBuffer().find(BundleEndString, ReadChars);
879     if (Error Err = ReadBundleEnd(Input))
880       return Err;
881     return Error::success();
882   }
883 };
884 } // namespace
885 
886 /// Return an appropriate object file handler. We use the specific object
887 /// handler if we know how to deal with that format, otherwise we use a default
888 /// binary file handler.
889 static std::unique_ptr<FileHandler>
CreateObjectFileHandler(MemoryBuffer & FirstInput,const OffloadBundlerConfig & BundlerConfig)890 CreateObjectFileHandler(MemoryBuffer &FirstInput,
891                         const OffloadBundlerConfig &BundlerConfig) {
892   // Check if the input file format is one that we know how to deal with.
893   Expected<std::unique_ptr<Binary>> BinaryOrErr = createBinary(FirstInput);
894 
895   // We only support regular object files. If failed to open the input as a
896   // known binary or this is not an object file use the default binary handler.
897   if (errorToBool(BinaryOrErr.takeError()) || !isa<ObjectFile>(*BinaryOrErr))
898     return std::make_unique<BinaryFileHandler>(BundlerConfig);
899 
900   // Otherwise create an object file handler. The handler will be owned by the
901   // client of this function.
902   return std::make_unique<ObjectFileHandler>(
903       std::unique_ptr<ObjectFile>(cast<ObjectFile>(BinaryOrErr->release())),
904       BundlerConfig);
905 }
906 
907 /// Return an appropriate handler given the input files and options.
908 static Expected<std::unique_ptr<FileHandler>>
CreateFileHandler(MemoryBuffer & FirstInput,const OffloadBundlerConfig & BundlerConfig)909 CreateFileHandler(MemoryBuffer &FirstInput,
910                   const OffloadBundlerConfig &BundlerConfig) {
911   std::string FilesType = BundlerConfig.FilesType;
912 
913   if (FilesType == "i")
914     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
915   if (FilesType == "ii")
916     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
917   if (FilesType == "cui")
918     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
919   if (FilesType == "hipi")
920     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
921   // TODO: `.d` should be eventually removed once `-M` and its variants are
922   // handled properly in offload compilation.
923   if (FilesType == "d")
924     return std::make_unique<TextFileHandler>(/*Comment=*/"#");
925   if (FilesType == "ll")
926     return std::make_unique<TextFileHandler>(/*Comment=*/";");
927   if (FilesType == "bc")
928     return std::make_unique<BinaryFileHandler>(BundlerConfig);
929   if (FilesType == "s")
930     return std::make_unique<TextFileHandler>(/*Comment=*/"#");
931   if (FilesType == "o")
932     return CreateObjectFileHandler(FirstInput, BundlerConfig);
933   if (FilesType == "a")
934     return CreateObjectFileHandler(FirstInput, BundlerConfig);
935   if (FilesType == "gch")
936     return std::make_unique<BinaryFileHandler>(BundlerConfig);
937   if (FilesType == "ast")
938     return std::make_unique<BinaryFileHandler>(BundlerConfig);
939 
940   return createStringError(errc::invalid_argument,
941                            "'" + FilesType + "': invalid file type specified");
942 }
943 
OffloadBundlerConfig()944 OffloadBundlerConfig::OffloadBundlerConfig()
945     : CompressedBundleVersion(CompressedOffloadBundle::DefaultVersion) {
946   if (llvm::compression::zstd::isAvailable()) {
947     CompressionFormat = llvm::compression::Format::Zstd;
948     // Compression level 3 is usually sufficient for zstd since long distance
949     // matching is enabled.
950     CompressionLevel = 3;
951   } else if (llvm::compression::zlib::isAvailable()) {
952     CompressionFormat = llvm::compression::Format::Zlib;
953     // Use default level for zlib since higher level does not have significant
954     // improvement.
955     CompressionLevel = llvm::compression::zlib::DefaultCompression;
956   }
957   auto IgnoreEnvVarOpt =
958       llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_IGNORE_ENV_VAR");
959   if (IgnoreEnvVarOpt.has_value() && IgnoreEnvVarOpt.value() == "1")
960     return;
961   auto VerboseEnvVarOpt = llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_VERBOSE");
962   if (VerboseEnvVarOpt.has_value())
963     Verbose = VerboseEnvVarOpt.value() == "1";
964   auto CompressEnvVarOpt =
965       llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_COMPRESS");
966   if (CompressEnvVarOpt.has_value())
967     Compress = CompressEnvVarOpt.value() == "1";
968   auto CompressionLevelEnvVarOpt =
969       llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_COMPRESSION_LEVEL");
970   if (CompressionLevelEnvVarOpt.has_value()) {
971     llvm::StringRef CompressionLevelStr = CompressionLevelEnvVarOpt.value();
972     int Level;
973     if (!CompressionLevelStr.getAsInteger(10, Level))
974       CompressionLevel = Level;
975     else
976       llvm::errs()
977           << "Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
978           << CompressionLevelStr.str() << ". Ignoring it.\n";
979   }
980   auto CompressedBundleFormatVersionOpt =
981       llvm::sys::Process::GetEnv("COMPRESSED_BUNDLE_FORMAT_VERSION");
982   if (CompressedBundleFormatVersionOpt.has_value()) {
983     llvm::StringRef VersionStr = CompressedBundleFormatVersionOpt.value();
984     uint16_t Version;
985     if (!VersionStr.getAsInteger(10, Version)) {
986       if (Version >= 2 && Version <= 3)
987         CompressedBundleVersion = Version;
988       else
989         llvm::errs()
990             << "Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
991             << VersionStr.str()
992             << ". Valid values are 2 or 3. Using default version "
993             << CompressedBundleVersion << ".\n";
994     } else
995       llvm::errs()
996           << "Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
997           << VersionStr.str() << ". Using default version "
998           << CompressedBundleVersion << ".\n";
999   }
1000 }
1001 
// Render \p Value in decimal with a comma between every group of three
// digits, counted from the right (e.g. 1234567 -> "1,234,567").
static std::string formatWithCommas(unsigned long long Value) {
  std::string Digits = std::to_string(Value);
  // Walk from the right-hand end in steps of three; stop once no digit would
  // remain to the left of the separator.
  for (size_t Pos = Digits.size(); Pos > 3;) {
    Pos -= 3;
    Digits.insert(Pos, 1, ',');
  }
  return Digits;
}
1012 
1013 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
compress(llvm::compression::Params P,const llvm::MemoryBuffer & Input,uint16_t Version,bool Verbose)1014 CompressedOffloadBundle::compress(llvm::compression::Params P,
1015                                   const llvm::MemoryBuffer &Input,
1016                                   uint16_t Version, bool Verbose) {
1017   if (!llvm::compression::zstd::isAvailable() &&
1018       !llvm::compression::zlib::isAvailable())
1019     return createStringError(llvm::inconvertibleErrorCode(),
1020                              "Compression not supported");
1021   llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
1022                         *ClangOffloadBundlerTimerGroup);
1023   if (Verbose)
1024     HashTimer.startTimer();
1025   llvm::MD5 Hash;
1026   llvm::MD5::MD5Result Result;
1027   Hash.update(Input.getBuffer());
1028   Hash.final(Result);
1029   uint64_t TruncatedHash = Result.low();
1030   if (Verbose)
1031     HashTimer.stopTimer();
1032 
1033   SmallVector<uint8_t, 0> CompressedBuffer;
1034   auto BufferUint8 = llvm::ArrayRef<uint8_t>(
1035       reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
1036       Input.getBuffer().size());
1037   llvm::Timer CompressTimer("Compression Timer", "Compression time",
1038                             *ClangOffloadBundlerTimerGroup);
1039   if (Verbose)
1040     CompressTimer.startTimer();
1041   llvm::compression::compress(P, BufferUint8, CompressedBuffer);
1042   if (Verbose)
1043     CompressTimer.stopTimer();
1044 
1045   uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
1046 
1047   // Store sizes in 64-bit variables first
1048   uint64_t UncompressedSize64 = Input.getBuffer().size();
1049   uint64_t TotalFileSize64;
1050 
1051   // Calculate total file size based on version
1052   if (Version == 2) {
1053     // For V2, ensure the sizes don't exceed 32-bit limit
1054     if (UncompressedSize64 > std::numeric_limits<uint32_t>::max())
1055       return createStringError(llvm::inconvertibleErrorCode(),
1056                                "Uncompressed size exceeds version 2 limit");
1057     if ((MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
1058          sizeof(CompressionMethod) + sizeof(uint32_t) + sizeof(TruncatedHash) +
1059          CompressedBuffer.size()) > std::numeric_limits<uint32_t>::max())
1060       return createStringError(llvm::inconvertibleErrorCode(),
1061                                "Total file size exceeds version 2 limit");
1062 
1063     TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
1064                       sizeof(CompressionMethod) + sizeof(uint32_t) +
1065                       sizeof(TruncatedHash) + CompressedBuffer.size();
1066   } else { // Version 3
1067     TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) +
1068                       sizeof(CompressionMethod) + sizeof(uint64_t) +
1069                       sizeof(TruncatedHash) + CompressedBuffer.size();
1070   }
1071 
1072   SmallVector<char, 0> FinalBuffer;
1073   llvm::raw_svector_ostream OS(FinalBuffer);
1074   OS << MagicNumber;
1075   OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
1076   OS.write(reinterpret_cast<const char *>(&CompressionMethod),
1077            sizeof(CompressionMethod));
1078 
1079   // Write size fields according to version
1080   if (Version == 2) {
1081     uint32_t TotalFileSize32 = static_cast<uint32_t>(TotalFileSize64);
1082     uint32_t UncompressedSize32 = static_cast<uint32_t>(UncompressedSize64);
1083     OS.write(reinterpret_cast<const char *>(&TotalFileSize32),
1084              sizeof(TotalFileSize32));
1085     OS.write(reinterpret_cast<const char *>(&UncompressedSize32),
1086              sizeof(UncompressedSize32));
1087   } else { // Version 3
1088     OS.write(reinterpret_cast<const char *>(&TotalFileSize64),
1089              sizeof(TotalFileSize64));
1090     OS.write(reinterpret_cast<const char *>(&UncompressedSize64),
1091              sizeof(UncompressedSize64));
1092   }
1093 
1094   OS.write(reinterpret_cast<const char *>(&TruncatedHash),
1095            sizeof(TruncatedHash));
1096   OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
1097            CompressedBuffer.size());
1098 
1099   if (Verbose) {
1100     auto MethodUsed =
1101         P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
1102     double CompressionRate =
1103         static_cast<double>(UncompressedSize64) / CompressedBuffer.size();
1104     double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
1105     double CompressionSpeedMBs =
1106         (UncompressedSize64 / (1024.0 * 1024.0)) / CompressionTimeSeconds;
1107     llvm::errs() << "Compressed bundle format version: " << Version << "\n"
1108                  << "Total file size (including headers): "
1109                  << formatWithCommas(TotalFileSize64) << " bytes\n"
1110                  << "Compression method used: " << MethodUsed << "\n"
1111                  << "Compression level: " << P.level << "\n"
1112                  << "Binary size before compression: "
1113                  << formatWithCommas(UncompressedSize64) << " bytes\n"
1114                  << "Binary size after compression: "
1115                  << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
1116                  << "Compression rate: "
1117                  << llvm::format("%.2lf", CompressionRate) << "\n"
1118                  << "Compression ratio: "
1119                  << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
1120                  << "Compression speed: "
1121                  << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
1122                  << "Truncated MD5 hash: "
1123                  << llvm::format_hex(TruncatedHash, 16) << "\n";
1124   }
1125 
1126   return llvm::MemoryBuffer::getMemBufferCopy(
1127       llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
1128 }
1129 
1130 // Use packed structs to avoid padding, such that the structs map the serialized
1131 // format.
1132 LLVM_PACKED_START
1133 union RawCompressedBundleHeader {
1134   struct CommonFields {
1135     uint32_t Magic;
1136     uint16_t Version;
1137     uint16_t Method;
1138   };
1139 
1140   struct V1Header {
1141     CommonFields Common;
1142     uint32_t UncompressedFileSize;
1143     uint64_t Hash;
1144   };
1145 
1146   struct V2Header {
1147     CommonFields Common;
1148     uint32_t FileSize;
1149     uint32_t UncompressedFileSize;
1150     uint64_t Hash;
1151   };
1152 
1153   struct V3Header {
1154     CommonFields Common;
1155     uint64_t FileSize;
1156     uint64_t UncompressedFileSize;
1157     uint64_t Hash;
1158   };
1159 
1160   CommonFields Common;
1161   V1Header V1;
1162   V2Header V2;
1163   V3Header V3;
1164 };
1165 LLVM_PACKED_END
1166 
1167 // Helper method to get header size based on version
getHeaderSize(uint16_t Version)1168 static size_t getHeaderSize(uint16_t Version) {
1169   switch (Version) {
1170   case 1:
1171     return sizeof(RawCompressedBundleHeader::V1Header);
1172   case 2:
1173     return sizeof(RawCompressedBundleHeader::V2Header);
1174   case 3:
1175     return sizeof(RawCompressedBundleHeader::V3Header);
1176   default:
1177     llvm_unreachable("Unsupported version");
1178   }
1179 }
1180 
1181 Expected<CompressedOffloadBundle::CompressedBundleHeader>
tryParse(StringRef Blob)1182 CompressedOffloadBundle::CompressedBundleHeader::tryParse(StringRef Blob) {
1183   assert(Blob.size() >= sizeof(RawCompressedBundleHeader::CommonFields));
1184   assert(llvm::identify_magic(Blob) ==
1185          llvm::file_magic::offload_bundle_compressed);
1186 
1187   RawCompressedBundleHeader Header;
1188   memcpy(&Header, Blob.data(), std::min(Blob.size(), sizeof(Header)));
1189 
1190   CompressedBundleHeader Normalized;
1191   Normalized.Version = Header.Common.Version;
1192 
1193   size_t RequiredSize = getHeaderSize(Normalized.Version);
1194   if (Blob.size() < RequiredSize)
1195     return createStringError(inconvertibleErrorCode(),
1196                              "Compressed bundle header size too small");
1197 
1198   switch (Normalized.Version) {
1199   case 1:
1200     Normalized.UncompressedFileSize = Header.V1.UncompressedFileSize;
1201     Normalized.Hash = Header.V1.Hash;
1202     break;
1203   case 2:
1204     Normalized.FileSize = Header.V2.FileSize;
1205     Normalized.UncompressedFileSize = Header.V2.UncompressedFileSize;
1206     Normalized.Hash = Header.V2.Hash;
1207     break;
1208   case 3:
1209     Normalized.FileSize = Header.V3.FileSize;
1210     Normalized.UncompressedFileSize = Header.V3.UncompressedFileSize;
1211     Normalized.Hash = Header.V3.Hash;
1212     break;
1213   default:
1214     return createStringError(inconvertibleErrorCode(),
1215                              "Unknown compressed bundle version");
1216   }
1217 
1218   // Determine compression format
1219   switch (Header.Common.Method) {
1220   case static_cast<uint16_t>(compression::Format::Zlib):
1221   case static_cast<uint16_t>(compression::Format::Zstd):
1222     Normalized.CompressionFormat =
1223         static_cast<compression::Format>(Header.Common.Method);
1224     break;
1225   default:
1226     return createStringError(inconvertibleErrorCode(),
1227                              "Unknown compressing method");
1228   }
1229 
1230   return Normalized;
1231 }
1232 
1233 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
decompress(const llvm::MemoryBuffer & Input,bool Verbose)1234 CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
1235                                     bool Verbose) {
1236   StringRef Blob = Input.getBuffer();
1237 
1238   // Check minimum header size (using V1 as it's the smallest)
1239   if (Blob.size() < sizeof(RawCompressedBundleHeader::CommonFields))
1240     return llvm::MemoryBuffer::getMemBufferCopy(Blob);
1241 
1242   if (llvm::identify_magic(Blob) !=
1243       llvm::file_magic::offload_bundle_compressed) {
1244     if (Verbose)
1245       llvm::errs() << "Uncompressed bundle.\n";
1246     return llvm::MemoryBuffer::getMemBufferCopy(Blob);
1247   }
1248 
1249   Expected<CompressedBundleHeader> HeaderOrErr =
1250       CompressedBundleHeader::tryParse(Blob);
1251   if (!HeaderOrErr)
1252     return HeaderOrErr.takeError();
1253 
1254   const CompressedBundleHeader &Normalized = *HeaderOrErr;
1255   unsigned ThisVersion = Normalized.Version;
1256   size_t HeaderSize = getHeaderSize(ThisVersion);
1257 
1258   llvm::compression::Format CompressionFormat = Normalized.CompressionFormat;
1259 
1260   size_t TotalFileSize = Normalized.FileSize.value_or(0);
1261   size_t UncompressedSize = Normalized.UncompressedFileSize;
1262   auto StoredHash = Normalized.Hash;
1263 
1264   llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
1265                               *ClangOffloadBundlerTimerGroup);
1266   if (Verbose)
1267     DecompressTimer.startTimer();
1268 
1269   SmallVector<uint8_t, 0> DecompressedData;
1270   StringRef CompressedData = Blob.substr(HeaderSize);
1271   if (llvm::Error DecompressionError = llvm::compression::decompress(
1272           CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
1273           DecompressedData, UncompressedSize))
1274     return createStringError(inconvertibleErrorCode(),
1275                              "Could not decompress embedded file contents: " +
1276                                  llvm::toString(std::move(DecompressionError)));
1277 
1278   if (Verbose) {
1279     DecompressTimer.stopTimer();
1280 
1281     double DecompressionTimeSeconds =
1282         DecompressTimer.getTotalTime().getWallTime();
1283 
1284     // Recalculate MD5 hash for integrity check
1285     llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
1286                                 "Hash recalculation time",
1287                                 *ClangOffloadBundlerTimerGroup);
1288     HashRecalcTimer.startTimer();
1289     llvm::MD5 Hash;
1290     llvm::MD5::MD5Result Result;
1291     Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData));
1292     Hash.final(Result);
1293     uint64_t RecalculatedHash = Result.low();
1294     HashRecalcTimer.stopTimer();
1295     bool HashMatch = (StoredHash == RecalculatedHash);
1296 
1297     double CompressionRate =
1298         static_cast<double>(UncompressedSize) / CompressedData.size();
1299     double DecompressionSpeedMBs =
1300         (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
1301 
1302     llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
1303     if (ThisVersion >= 2)
1304       llvm::errs() << "Total file size (from header): "
1305                    << formatWithCommas(TotalFileSize) << " bytes\n";
1306     llvm::errs() << "Decompression method: "
1307                  << (CompressionFormat == llvm::compression::Format::Zlib
1308                          ? "zlib"
1309                          : "zstd")
1310                  << "\n"
1311                  << "Size before decompression: "
1312                  << formatWithCommas(CompressedData.size()) << " bytes\n"
1313                  << "Size after decompression: "
1314                  << formatWithCommas(UncompressedSize) << " bytes\n"
1315                  << "Compression rate: "
1316                  << llvm::format("%.2lf", CompressionRate) << "\n"
1317                  << "Compression ratio: "
1318                  << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
1319                  << "Decompression speed: "
1320                  << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
1321                  << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"
1322                  << "Recalculated hash: "
1323                  << llvm::format_hex(RecalculatedHash, 16) << "\n"
1324                  << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
1325   }
1326 
1327   return llvm::MemoryBuffer::getMemBufferCopy(
1328       llvm::toStringRef(DecompressedData));
1329 }
1330 
1331 // List bundle IDs. Return true if an error was found.
ListBundleIDsInFile(StringRef InputFileName,const OffloadBundlerConfig & BundlerConfig)1332 Error OffloadBundler::ListBundleIDsInFile(
1333     StringRef InputFileName, const OffloadBundlerConfig &BundlerConfig) {
1334   // Open Input file.
1335   ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
1336       MemoryBuffer::getFileOrSTDIN(InputFileName, /*IsText=*/true);
1337   if (std::error_code EC = CodeOrErr.getError())
1338     return createFileError(InputFileName, EC);
1339 
1340   // Decompress the input if necessary.
1341   Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
1342       CompressedOffloadBundle::decompress(**CodeOrErr, BundlerConfig.Verbose);
1343   if (!DecompressedBufferOrErr)
1344     return createStringError(
1345         inconvertibleErrorCode(),
1346         "Failed to decompress input: " +
1347             llvm::toString(DecompressedBufferOrErr.takeError()));
1348 
1349   MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
1350 
1351   // Select the right files handler.
1352   Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1353       CreateFileHandler(DecompressedInput, BundlerConfig);
1354   if (!FileHandlerOrErr)
1355     return FileHandlerOrErr.takeError();
1356 
1357   std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
1358   assert(FH);
1359   return FH->listBundleIDs(DecompressedInput);
1360 }
1361 
1362 /// @brief Checks if a code object \p CodeObjectInfo is compatible with a given
1363 /// target \p TargetInfo.
1364 /// @link https://clang.llvm.org/docs/ClangOffloadBundler.html#bundle-entry-id
isCodeObjectCompatible(const OffloadTargetInfo & CodeObjectInfo,const OffloadTargetInfo & TargetInfo)1365 bool isCodeObjectCompatible(const OffloadTargetInfo &CodeObjectInfo,
1366                             const OffloadTargetInfo &TargetInfo) {
1367 
1368   // Compatible in case of exact match.
1369   if (CodeObjectInfo == TargetInfo) {
1370     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1371                     dbgs() << "Compatible: Exact match: \t[CodeObject: "
1372                            << CodeObjectInfo.str()
1373                            << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1374     return true;
1375   }
1376 
1377   // Incompatible if Kinds or Triples mismatch.
1378   if (!CodeObjectInfo.isOffloadKindCompatible(TargetInfo.OffloadKind) ||
1379       !CodeObjectInfo.Triple.isCompatibleWith(TargetInfo.Triple)) {
1380     DEBUG_WITH_TYPE(
1381         "CodeObjectCompatibility",
1382         dbgs() << "Incompatible: Kind/Triple mismatch \t[CodeObject: "
1383                << CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
1384                << "]\n");
1385     return false;
1386   }
1387 
1388   // Incompatible if Processors mismatch.
1389   llvm::StringMap<bool> CodeObjectFeatureMap, TargetFeatureMap;
1390   std::optional<StringRef> CodeObjectProc = clang::parseTargetID(
1391       CodeObjectInfo.Triple, CodeObjectInfo.TargetID, &CodeObjectFeatureMap);
1392   std::optional<StringRef> TargetProc = clang::parseTargetID(
1393       TargetInfo.Triple, TargetInfo.TargetID, &TargetFeatureMap);
1394 
1395   // Both TargetProc and CodeObjectProc can't be empty here.
1396   if (!TargetProc || !CodeObjectProc ||
1397       CodeObjectProc.value() != TargetProc.value()) {
1398     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1399                     dbgs() << "Incompatible: Processor mismatch \t[CodeObject: "
1400                            << CodeObjectInfo.str()
1401                            << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1402     return false;
1403   }
1404 
1405   // Incompatible if CodeObject has more features than Target, irrespective of
1406   // type or sign of features.
1407   if (CodeObjectFeatureMap.getNumItems() > TargetFeatureMap.getNumItems()) {
1408     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1409                     dbgs() << "Incompatible: CodeObject has more features "
1410                               "than target \t[CodeObject: "
1411                            << CodeObjectInfo.str()
1412                            << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1413     return false;
1414   }
1415 
1416   // Compatible if each target feature specified by target is compatible with
1417   // target feature of code object. The target feature is compatible if the
1418   // code object does not specify it (meaning Any), or if it specifies it
1419   // with the same value (meaning On or Off).
1420   for (const auto &CodeObjectFeature : CodeObjectFeatureMap) {
1421     auto TargetFeature = TargetFeatureMap.find(CodeObjectFeature.getKey());
1422     if (TargetFeature == TargetFeatureMap.end()) {
1423       DEBUG_WITH_TYPE(
1424           "CodeObjectCompatibility",
1425           dbgs()
1426               << "Incompatible: Value of CodeObject's non-ANY feature is "
1427                  "not matching with Target feature's ANY value \t[CodeObject: "
1428               << CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
1429               << "]\n");
1430       return false;
1431     } else if (TargetFeature->getValue() != CodeObjectFeature.getValue()) {
1432       DEBUG_WITH_TYPE(
1433           "CodeObjectCompatibility",
1434           dbgs() << "Incompatible: Value of CodeObject's non-ANY feature is "
1435                     "not matching with Target feature's non-ANY value "
1436                     "\t[CodeObject: "
1437                  << CodeObjectInfo.str()
1438                  << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1439       return false;
1440     }
1441   }
1442 
1443   // CodeObject is compatible if all features of Target are:
1444   //   - either, present in the Code Object's features map with the same sign,
1445   //   - or, the feature is missing from CodeObjects's features map i.e. it is
1446   //   set to ANY
1447   DEBUG_WITH_TYPE(
1448       "CodeObjectCompatibility",
1449       dbgs() << "Compatible: Target IDs are compatible \t[CodeObject: "
1450              << CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
1451              << "]\n");
1452   return true;
1453 }
1454 
1455 /// Bundle the files. Return true if an error was found.
BundleFiles()1456 Error OffloadBundler::BundleFiles() {
1457   std::error_code EC;
1458 
1459   // Create a buffer to hold the content before compressing.
1460   SmallVector<char, 0> Buffer;
1461   llvm::raw_svector_ostream BufferStream(Buffer);
1462 
1463   // Open input files.
1464   SmallVector<std::unique_ptr<MemoryBuffer>, 8u> InputBuffers;
1465   InputBuffers.reserve(BundlerConfig.InputFileNames.size());
1466   for (auto &I : BundlerConfig.InputFileNames) {
1467     ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
1468         MemoryBuffer::getFileOrSTDIN(I, /*IsText=*/true);
1469     if (std::error_code EC = CodeOrErr.getError())
1470       return createFileError(I, EC);
1471     InputBuffers.emplace_back(std::move(*CodeOrErr));
1472   }
1473 
1474   // Get the file handler. We use the host buffer as reference.
1475   assert((BundlerConfig.HostInputIndex != ~0u || BundlerConfig.AllowNoHost) &&
1476          "Host input index undefined??");
1477   Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr = CreateFileHandler(
1478       *InputBuffers[BundlerConfig.AllowNoHost ? 0
1479                                               : BundlerConfig.HostInputIndex],
1480       BundlerConfig);
1481   if (!FileHandlerOrErr)
1482     return FileHandlerOrErr.takeError();
1483 
1484   std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
1485   assert(FH);
1486 
1487   // Write header.
1488   if (Error Err = FH->WriteHeader(BufferStream, InputBuffers))
1489     return Err;
1490 
1491   // Write all bundles along with the start/end markers. If an error was found
1492   // writing the end of the bundle component, abort the bundle writing.
1493   auto Input = InputBuffers.begin();
1494   for (auto &Triple : BundlerConfig.TargetNames) {
1495     if (Error Err = FH->WriteBundleStart(BufferStream, Triple))
1496       return Err;
1497     if (Error Err = FH->WriteBundle(BufferStream, **Input))
1498       return Err;
1499     if (Error Err = FH->WriteBundleEnd(BufferStream, Triple))
1500       return Err;
1501     ++Input;
1502   }
1503 
1504   raw_fd_ostream OutputFile(BundlerConfig.OutputFileNames.front(), EC,
1505                             sys::fs::OF_None);
1506   if (EC)
1507     return createFileError(BundlerConfig.OutputFileNames.front(), EC);
1508 
1509   SmallVector<char, 0> CompressedBuffer;
1510   if (BundlerConfig.Compress) {
1511     std::unique_ptr<llvm::MemoryBuffer> BufferMemory =
1512         llvm::MemoryBuffer::getMemBufferCopy(
1513             llvm::StringRef(Buffer.data(), Buffer.size()));
1514     auto CompressionResult = CompressedOffloadBundle::compress(
1515         {BundlerConfig.CompressionFormat, BundlerConfig.CompressionLevel,
1516          /*zstdEnableLdm=*/true},
1517         *BufferMemory, BundlerConfig.CompressedBundleVersion,
1518         BundlerConfig.Verbose);
1519     if (auto Error = CompressionResult.takeError())
1520       return Error;
1521 
1522     auto CompressedMemBuffer = std::move(CompressionResult.get());
1523     CompressedBuffer.assign(CompressedMemBuffer->getBufferStart(),
1524                             CompressedMemBuffer->getBufferEnd());
1525   } else
1526     CompressedBuffer = Buffer;
1527 
1528   OutputFile.write(CompressedBuffer.data(), CompressedBuffer.size());
1529 
1530   return FH->finalizeOutputFile();
1531 }
1532 
1533 // Unbundle the files. Return true if an error was found.
UnbundleFiles()1534 Error OffloadBundler::UnbundleFiles() {
1535   // Open Input file.
1536   ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
1537       MemoryBuffer::getFileOrSTDIN(BundlerConfig.InputFileNames.front(),
1538                                    /*IsText=*/true);
1539   if (std::error_code EC = CodeOrErr.getError())
1540     return createFileError(BundlerConfig.InputFileNames.front(), EC);
1541 
1542   // Decompress the input if necessary.
1543   Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
1544       CompressedOffloadBundle::decompress(**CodeOrErr, BundlerConfig.Verbose);
1545   if (!DecompressedBufferOrErr)
1546     return createStringError(
1547         inconvertibleErrorCode(),
1548         "Failed to decompress input: " +
1549             llvm::toString(DecompressedBufferOrErr.takeError()));
1550 
1551   MemoryBuffer &Input = **DecompressedBufferOrErr;
1552 
1553   // Select the right files handler.
1554   Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1555       CreateFileHandler(Input, BundlerConfig);
1556   if (!FileHandlerOrErr)
1557     return FileHandlerOrErr.takeError();
1558 
1559   std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
1560   assert(FH);
1561 
1562   // Read the header of the bundled file.
1563   if (Error Err = FH->ReadHeader(Input))
1564     return Err;
1565 
1566   // Create a work list that consist of the map triple/output file.
1567   StringMap<StringRef> Worklist;
1568   auto Output = BundlerConfig.OutputFileNames.begin();
1569   for (auto &Triple : BundlerConfig.TargetNames) {
1570     if (!checkOffloadBundleID(Triple))
1571       return createStringError(errc::invalid_argument,
1572                                "invalid bundle id from bundle config");
1573     Worklist[Triple] = *Output;
1574     ++Output;
1575   }
1576 
1577   // Read all the bundles that are in the work list. If we find no bundles we
1578   // assume the file is meant for the host target.
1579   bool FoundHostBundle = false;
1580   while (!Worklist.empty()) {
1581     Expected<std::optional<StringRef>> CurTripleOrErr =
1582         FH->ReadBundleStart(Input);
1583     if (!CurTripleOrErr)
1584       return CurTripleOrErr.takeError();
1585 
1586     // We don't have more bundles.
1587     if (!*CurTripleOrErr)
1588       break;
1589 
1590     StringRef CurTriple = **CurTripleOrErr;
1591     assert(!CurTriple.empty());
1592     if (!checkOffloadBundleID(CurTriple))
1593       return createStringError(errc::invalid_argument,
1594                                "invalid bundle id read from the bundle");
1595 
1596     auto Output = Worklist.begin();
1597     for (auto E = Worklist.end(); Output != E; Output++) {
1598       if (isCodeObjectCompatible(
1599               OffloadTargetInfo(CurTriple, BundlerConfig),
1600               OffloadTargetInfo((*Output).first(), BundlerConfig))) {
1601         break;
1602       }
1603     }
1604 
1605     if (Output == Worklist.end())
1606       continue;
1607     // Check if the output file can be opened and copy the bundle to it.
1608     std::error_code EC;
1609     raw_fd_ostream OutputFile((*Output).second, EC, sys::fs::OF_None);
1610     if (EC)
1611       return createFileError((*Output).second, EC);
1612     if (Error Err = FH->ReadBundle(OutputFile, Input))
1613       return Err;
1614     if (Error Err = FH->ReadBundleEnd(Input))
1615       return Err;
1616     Worklist.erase(Output);
1617 
1618     // Record if we found the host bundle.
1619     auto OffloadInfo = OffloadTargetInfo(CurTriple, BundlerConfig);
1620     if (OffloadInfo.hasHostKind())
1621       FoundHostBundle = true;
1622   }
1623 
1624   if (!BundlerConfig.AllowMissingBundles && !Worklist.empty()) {
1625     std::string ErrMsg = "Can't find bundles for";
1626     std::set<StringRef> Sorted;
1627     for (auto &E : Worklist)
1628       Sorted.insert(E.first());
1629     unsigned I = 0;
1630     unsigned Last = Sorted.size() - 1;
1631     for (auto &E : Sorted) {
1632       if (I != 0 && Last > 1)
1633         ErrMsg += ",";
1634       ErrMsg += " ";
1635       if (I == Last && I != 0)
1636         ErrMsg += "and ";
1637       ErrMsg += E.str();
1638       ++I;
1639     }
1640     return createStringError(inconvertibleErrorCode(), ErrMsg);
1641   }
1642 
1643   // If no bundles were found, assume the input file is the host bundle and
1644   // create empty files for the remaining targets.
1645   if (Worklist.size() == BundlerConfig.TargetNames.size()) {
1646     for (auto &E : Worklist) {
1647       std::error_code EC;
1648       raw_fd_ostream OutputFile(E.second, EC, sys::fs::OF_None);
1649       if (EC)
1650         return createFileError(E.second, EC);
1651 
1652       // If this entry has a host kind, copy the input file to the output file.
1653       // We don't need to check E.getKey() here through checkOffloadBundleID
1654       // because the entire WorkList has been checked above.
1655       auto OffloadInfo = OffloadTargetInfo(E.getKey(), BundlerConfig);
1656       if (OffloadInfo.hasHostKind())
1657         OutputFile.write(Input.getBufferStart(), Input.getBufferSize());
1658     }
1659     return Error::success();
1660   }
1661 
1662   // If we found elements, we emit an error if none of those were for the host
1663   // in case host bundle name was provided in command line.
1664   if (!(FoundHostBundle || BundlerConfig.HostInputIndex == ~0u ||
1665         BundlerConfig.AllowMissingBundles))
1666     return createStringError(inconvertibleErrorCode(),
1667                              "Can't find bundle for the host target");
1668 
1669   // If we still have any elements in the worklist, create empty files for them.
1670   for (auto &E : Worklist) {
1671     std::error_code EC;
1672     raw_fd_ostream OutputFile(E.second, EC, sys::fs::OF_None);
1673     if (EC)
1674       return createFileError(E.second, EC);
1675   }
1676 
1677   return Error::success();
1678 }
1679 
getDefaultArchiveKindForHost()1680 static Archive::Kind getDefaultArchiveKindForHost() {
1681   return Triple(sys::getDefaultTargetTriple()).isOSDarwin() ? Archive::K_DARWIN
1682                                                             : Archive::K_GNU;
1683 }
1684 
1685 /// @brief Computes a list of targets among all given targets which are
1686 /// compatible with this code object
1687 /// @param [in] CodeObjectInfo Code Object
1688 /// @param [out] CompatibleTargets List of all compatible targets among all
1689 /// given targets
1690 /// @return false, if no compatible target is found.
1691 static bool
getCompatibleOffloadTargets(OffloadTargetInfo & CodeObjectInfo,SmallVectorImpl<StringRef> & CompatibleTargets,const OffloadBundlerConfig & BundlerConfig)1692 getCompatibleOffloadTargets(OffloadTargetInfo &CodeObjectInfo,
1693                             SmallVectorImpl<StringRef> &CompatibleTargets,
1694                             const OffloadBundlerConfig &BundlerConfig) {
1695   if (!CompatibleTargets.empty()) {
1696     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1697                     dbgs() << "CompatibleTargets list should be empty\n");
1698     return false;
1699   }
1700   for (auto &Target : BundlerConfig.TargetNames) {
1701     auto TargetInfo = OffloadTargetInfo(Target, BundlerConfig);
1702     if (isCodeObjectCompatible(CodeObjectInfo, TargetInfo))
1703       CompatibleTargets.push_back(Target);
1704   }
1705   return !CompatibleTargets.empty();
1706 }
1707 
1708 // Check that each code object file in the input archive conforms to following
1709 // rule: for a specific processor, a feature either shows up in all target IDs,
1710 // or does not show up in any target IDs. Otherwise the target ID combination is
1711 // invalid.
1712 static Error
CheckHeterogeneousArchive(StringRef ArchiveName,const OffloadBundlerConfig & BundlerConfig)1713 CheckHeterogeneousArchive(StringRef ArchiveName,
1714                           const OffloadBundlerConfig &BundlerConfig) {
1715   std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
1716   ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
1717       MemoryBuffer::getFileOrSTDIN(ArchiveName, true, false);
1718   if (std::error_code EC = BufOrErr.getError())
1719     return createFileError(ArchiveName, EC);
1720 
1721   ArchiveBuffers.push_back(std::move(*BufOrErr));
1722   Expected<std::unique_ptr<llvm::object::Archive>> LibOrErr =
1723       Archive::create(ArchiveBuffers.back()->getMemBufferRef());
1724   if (!LibOrErr)
1725     return LibOrErr.takeError();
1726 
1727   auto Archive = std::move(*LibOrErr);
1728 
1729   Error ArchiveErr = Error::success();
1730   auto ChildEnd = Archive->child_end();
1731 
1732   /// Iterate over all bundled code object files in the input archive.
1733   for (auto ArchiveIter = Archive->child_begin(ArchiveErr);
1734        ArchiveIter != ChildEnd; ++ArchiveIter) {
1735     if (ArchiveErr)
1736       return ArchiveErr;
1737     auto ArchiveChildNameOrErr = (*ArchiveIter).getName();
1738     if (!ArchiveChildNameOrErr)
1739       return ArchiveChildNameOrErr.takeError();
1740 
1741     auto CodeObjectBufferRefOrErr = (*ArchiveIter).getMemoryBufferRef();
1742     if (!CodeObjectBufferRefOrErr)
1743       return CodeObjectBufferRefOrErr.takeError();
1744 
1745     auto CodeObjectBuffer =
1746         MemoryBuffer::getMemBuffer(*CodeObjectBufferRefOrErr, false);
1747 
1748     Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1749         CreateFileHandler(*CodeObjectBuffer, BundlerConfig);
1750     if (!FileHandlerOrErr)
1751       return FileHandlerOrErr.takeError();
1752 
1753     std::unique_ptr<FileHandler> &FileHandler = *FileHandlerOrErr;
1754     assert(FileHandler);
1755 
1756     std::set<StringRef> BundleIds;
1757     auto CodeObjectFileError =
1758         FileHandler->getBundleIDs(*CodeObjectBuffer, BundleIds);
1759     if (CodeObjectFileError)
1760       return CodeObjectFileError;
1761 
1762     auto &&ConflictingArchs = clang::getConflictTargetIDCombination(BundleIds);
1763     if (ConflictingArchs) {
1764       std::string ErrMsg =
1765           Twine("conflicting TargetIDs [" + ConflictingArchs.value().first +
1766                 ", " + ConflictingArchs.value().second + "] found in " +
1767                 ArchiveChildNameOrErr.get() + " of " + ArchiveName)
1768               .str();
1769       return createStringError(inconvertibleErrorCode(), ErrMsg);
1770     }
1771   }
1772 
1773   return ArchiveErr;
1774 }
1775 
1776 /// UnbundleArchive takes an archive file (".a") as input containing bundled
1777 /// code object files, and a list of offload targets (not host), and extracts
1778 /// the code objects into a new archive file for each offload target. Each
1779 /// resulting archive file contains all code object files corresponding to that
1780 /// particular offload target. The created archive file does not
1781 /// contain an index of the symbols and code object files are named as
1782 /// <<Parent Bundle Name>-<CodeObject's TargetID>>, with ':' replaced with '_'.
UnbundleArchive()1783 Error OffloadBundler::UnbundleArchive() {
1784   std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
1785 
1786   /// Map of target names with list of object files that will form the device
1787   /// specific archive for that target
1788   StringMap<std::vector<NewArchiveMember>> OutputArchivesMap;
1789 
1790   // Map of target names and output archive filenames
1791   StringMap<StringRef> TargetOutputFileNameMap;
1792 
1793   auto Output = BundlerConfig.OutputFileNames.begin();
1794   for (auto &Target : BundlerConfig.TargetNames) {
1795     TargetOutputFileNameMap[Target] = *Output;
1796     ++Output;
1797   }
1798 
1799   StringRef IFName = BundlerConfig.InputFileNames.front();
1800 
1801   if (BundlerConfig.CheckInputArchive) {
1802     // For a specific processor, a feature either shows up in all target IDs, or
1803     // does not show up in any target IDs. Otherwise the target ID combination
1804     // is invalid.
1805     auto ArchiveError = CheckHeterogeneousArchive(IFName, BundlerConfig);
1806     if (ArchiveError) {
1807       return ArchiveError;
1808     }
1809   }
1810 
1811   ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
1812       MemoryBuffer::getFileOrSTDIN(IFName, true, false);
1813   if (std::error_code EC = BufOrErr.getError())
1814     return createFileError(BundlerConfig.InputFileNames.front(), EC);
1815 
1816   ArchiveBuffers.push_back(std::move(*BufOrErr));
1817   Expected<std::unique_ptr<llvm::object::Archive>> LibOrErr =
1818       Archive::create(ArchiveBuffers.back()->getMemBufferRef());
1819   if (!LibOrErr)
1820     return LibOrErr.takeError();
1821 
1822   auto Archive = std::move(*LibOrErr);
1823 
1824   Error ArchiveErr = Error::success();
1825   auto ChildEnd = Archive->child_end();
1826 
1827   /// Iterate over all bundled code object files in the input archive.
1828   for (auto ArchiveIter = Archive->child_begin(ArchiveErr);
1829        ArchiveIter != ChildEnd; ++ArchiveIter) {
1830     if (ArchiveErr)
1831       return ArchiveErr;
1832     auto ArchiveChildNameOrErr = (*ArchiveIter).getName();
1833     if (!ArchiveChildNameOrErr)
1834       return ArchiveChildNameOrErr.takeError();
1835 
1836     StringRef BundledObjectFile = sys::path::filename(*ArchiveChildNameOrErr);
1837 
1838     auto CodeObjectBufferRefOrErr = (*ArchiveIter).getMemoryBufferRef();
1839     if (!CodeObjectBufferRefOrErr)
1840       return CodeObjectBufferRefOrErr.takeError();
1841 
1842     auto TempCodeObjectBuffer =
1843         MemoryBuffer::getMemBuffer(*CodeObjectBufferRefOrErr, false);
1844 
1845     // Decompress the buffer if necessary.
1846     Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
1847         CompressedOffloadBundle::decompress(*TempCodeObjectBuffer,
1848                                             BundlerConfig.Verbose);
1849     if (!DecompressedBufferOrErr)
1850       return createStringError(
1851           inconvertibleErrorCode(),
1852           "Failed to decompress code object: " +
1853               llvm::toString(DecompressedBufferOrErr.takeError()));
1854 
1855     MemoryBuffer &CodeObjectBuffer = **DecompressedBufferOrErr;
1856 
1857     Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1858         CreateFileHandler(CodeObjectBuffer, BundlerConfig);
1859     if (!FileHandlerOrErr)
1860       return FileHandlerOrErr.takeError();
1861 
1862     std::unique_ptr<FileHandler> &FileHandler = *FileHandlerOrErr;
1863     assert(FileHandler &&
1864            "FileHandle creation failed for file in the archive!");
1865 
1866     if (Error ReadErr = FileHandler->ReadHeader(CodeObjectBuffer))
1867       return ReadErr;
1868 
1869     Expected<std::optional<StringRef>> CurBundleIDOrErr =
1870         FileHandler->ReadBundleStart(CodeObjectBuffer);
1871     if (!CurBundleIDOrErr)
1872       return CurBundleIDOrErr.takeError();
1873 
1874     std::optional<StringRef> OptionalCurBundleID = *CurBundleIDOrErr;
1875     // No device code in this child, skip.
1876     if (!OptionalCurBundleID)
1877       continue;
1878     StringRef CodeObject = *OptionalCurBundleID;
1879 
1880     // Process all bundle entries (CodeObjects) found in this child of input
1881     // archive.
1882     while (!CodeObject.empty()) {
1883       SmallVector<StringRef> CompatibleTargets;
1884       if (!checkOffloadBundleID(CodeObject)) {
1885         return createStringError(errc::invalid_argument,
1886                                  "Invalid bundle id read from code object");
1887       }
1888       auto CodeObjectInfo = OffloadTargetInfo(CodeObject, BundlerConfig);
1889       if (getCompatibleOffloadTargets(CodeObjectInfo, CompatibleTargets,
1890                                       BundlerConfig)) {
1891         std::string BundleData;
1892         raw_string_ostream DataStream(BundleData);
1893         if (Error Err = FileHandler->ReadBundle(DataStream, CodeObjectBuffer))
1894           return Err;
1895 
1896         for (auto &CompatibleTarget : CompatibleTargets) {
1897           SmallString<128> BundledObjectFileName;
1898           BundledObjectFileName.assign(BundledObjectFile);
1899           auto OutputBundleName =
1900               Twine(llvm::sys::path::stem(BundledObjectFileName) + "-" +
1901                     CodeObject +
1902                     getDeviceLibraryFileName(BundledObjectFileName,
1903                                              CodeObjectInfo.TargetID))
1904                   .str();
1905           // Replace ':' in optional target feature list with '_' to ensure
1906           // cross-platform validity.
1907           llvm::replace(OutputBundleName, ':', '_');
1908 
1909           std::unique_ptr<MemoryBuffer> MemBuf = MemoryBuffer::getMemBufferCopy(
1910               DataStream.str(), OutputBundleName);
1911           ArchiveBuffers.push_back(std::move(MemBuf));
1912           llvm::MemoryBufferRef MemBufRef =
1913               MemoryBufferRef(*(ArchiveBuffers.back()));
1914 
1915           // For inserting <CompatibleTarget, list<CodeObject>> entry in
1916           // OutputArchivesMap.
1917           OutputArchivesMap[CompatibleTarget].push_back(
1918               NewArchiveMember(MemBufRef));
1919         }
1920       }
1921 
1922       if (Error Err = FileHandler->ReadBundleEnd(CodeObjectBuffer))
1923         return Err;
1924 
1925       Expected<std::optional<StringRef>> NextTripleOrErr =
1926           FileHandler->ReadBundleStart(CodeObjectBuffer);
1927       if (!NextTripleOrErr)
1928         return NextTripleOrErr.takeError();
1929 
1930       CodeObject = ((*NextTripleOrErr).has_value()) ? **NextTripleOrErr : "";
1931     } // End of processing of all bundle entries of this child of input archive.
1932   }   // End of while over children of input archive.
1933 
1934   assert(!ArchiveErr && "Error occurred while reading archive!");
1935 
1936   /// Write out an archive for each target
1937   for (auto &Target : BundlerConfig.TargetNames) {
1938     StringRef FileName = TargetOutputFileNameMap[Target];
1939     StringMapIterator<std::vector<llvm::NewArchiveMember>> CurArchiveMembers =
1940         OutputArchivesMap.find(Target);
1941     if (CurArchiveMembers != OutputArchivesMap.end()) {
1942       if (Error WriteErr = writeArchive(FileName, CurArchiveMembers->getValue(),
1943                                         SymtabWritingMode::NormalSymtab,
1944                                         getDefaultArchiveKindForHost(), true,
1945                                         false, nullptr))
1946         return WriteErr;
1947     } else if (!BundlerConfig.AllowMissingBundles) {
1948       std::string ErrMsg =
1949           Twine("no compatible code object found for the target '" + Target +
1950                 "' in heterogeneous archive library: " + IFName)
1951               .str();
1952       return createStringError(inconvertibleErrorCode(), ErrMsg);
1953     } else { // Create an empty archive file if no compatible code object is
1954              // found and "allow-missing-bundles" is enabled. It ensures that
1955              // the linker using output of this step doesn't complain about
1956              // the missing input file.
1957       std::vector<llvm::NewArchiveMember> EmptyArchive;
1958       EmptyArchive.clear();
1959       if (Error WriteErr = writeArchive(
1960               FileName, EmptyArchive, SymtabWritingMode::NormalSymtab,
1961               getDefaultArchiveKindForHost(), true, false, nullptr))
1962         return WriteErr;
1963     }
1964   }
1965 
1966   return Error::success();
1967 }
1968 
checkOffloadBundleID(const llvm::StringRef Str)1969 bool clang::checkOffloadBundleID(const llvm::StringRef Str) {
1970   // <kind>-<triple>[-<target id>[:target features]]
1971   // <triple> := <arch>-<vendor>-<os>-<env>
1972   SmallVector<StringRef, 6> Components;
1973   Str.split(Components, '-', /*MaxSplit=*/5);
1974   return Components.size() == 5 || Components.size() == 6;
1975 }
1976