xref: /freebsd/contrib/llvm-project/clang/lib/Driver/OffloadBundler.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- OffloadBundler.cpp - File Bundling and Unbundling ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements an offload bundling API that bundles different files
11 /// that relate with the same source code but different targets into a single
12 /// one. It also implements the opposite functionality, i.e. unbundling files
13 /// previously created by this API.
14 ///
15 //===----------------------------------------------------------------------===//
16 
17 #include "clang/Driver/OffloadBundler.h"
18 #include "clang/Basic/Cuda.h"
19 #include "clang/Basic/TargetID.h"
20 #include "clang/Basic/Version.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/BinaryFormat/Magic.h"
28 #include "llvm/Object/Archive.h"
29 #include "llvm/Object/ArchiveWriter.h"
30 #include "llvm/Object/Binary.h"
31 #include "llvm/Object/ObjectFile.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Compression.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/EndianStream.h"
36 #include "llvm/Support/Errc.h"
37 #include "llvm/Support/Error.h"
38 #include "llvm/Support/ErrorOr.h"
39 #include "llvm/Support/FileSystem.h"
40 #include "llvm/Support/MD5.h"
41 #include "llvm/Support/MemoryBuffer.h"
42 #include "llvm/Support/Path.h"
43 #include "llvm/Support/Program.h"
44 #include "llvm/Support/Signals.h"
45 #include "llvm/Support/StringSaver.h"
46 #include "llvm/Support/Timer.h"
47 #include "llvm/Support/WithColor.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "llvm/TargetParser/Host.h"
50 #include "llvm/TargetParser/Triple.h"
51 #include <algorithm>
52 #include <cassert>
53 #include <cstddef>
54 #include <cstdint>
55 #include <forward_list>
56 #include <llvm/Support/Process.h>
57 #include <memory>
58 #include <set>
59 #include <string>
60 #include <system_error>
61 #include <utility>
62 
63 using namespace llvm;
64 using namespace llvm::object;
65 using namespace clang;
66 
67 static llvm::TimerGroup
68     ClangOffloadBundlerTimerGroup("Clang Offload Bundler Timer Group",
69                                   "Timer group for clang offload bundler");
70 
71 /// Magic string that marks the existence of offloading data.
72 #define OFFLOAD_BUNDLER_MAGIC_STR "__CLANG_OFFLOAD_BUNDLE__"
73 
OffloadTargetInfo(const StringRef Target,const OffloadBundlerConfig & BC)74 OffloadTargetInfo::OffloadTargetInfo(const StringRef Target,
75                                      const OffloadBundlerConfig &BC)
76     : BundlerConfig(BC) {
77 
78   // TODO: Add error checking from ClangOffloadBundler.cpp
79   auto TargetFeatures = Target.split(':');
80   auto TripleOrGPU = TargetFeatures.first.rsplit('-');
81 
82   if (clang::StringToOffloadArch(TripleOrGPU.second) !=
83       clang::OffloadArch::UNKNOWN) {
84     auto KindTriple = TripleOrGPU.first.split('-');
85     this->OffloadKind = KindTriple.first;
86 
87     // Enforce optional env field to standardize bundles
88     llvm::Triple t = llvm::Triple(KindTriple.second);
89     this->Triple = llvm::Triple(t.getArchName(), t.getVendorName(),
90                                 t.getOSName(), t.getEnvironmentName());
91 
92     this->TargetID = Target.substr(Target.find(TripleOrGPU.second));
93   } else {
94     auto KindTriple = TargetFeatures.first.split('-');
95     this->OffloadKind = KindTriple.first;
96 
97     // Enforce optional env field to standardize bundles
98     llvm::Triple t = llvm::Triple(KindTriple.second);
99     this->Triple = llvm::Triple(t.getArchName(), t.getVendorName(),
100                                 t.getOSName(), t.getEnvironmentName());
101 
102     this->TargetID = "";
103   }
104 }
105 
hasHostKind() const106 bool OffloadTargetInfo::hasHostKind() const {
107   return this->OffloadKind == "host";
108 }
109 
isOffloadKindValid() const110 bool OffloadTargetInfo::isOffloadKindValid() const {
111   return OffloadKind == "host" || OffloadKind == "openmp" ||
112          OffloadKind == "hip" || OffloadKind == "hipv4";
113 }
114 
isOffloadKindCompatible(const StringRef TargetOffloadKind) const115 bool OffloadTargetInfo::isOffloadKindCompatible(
116     const StringRef TargetOffloadKind) const {
117   if ((OffloadKind == TargetOffloadKind) ||
118       (OffloadKind == "hip" && TargetOffloadKind == "hipv4") ||
119       (OffloadKind == "hipv4" && TargetOffloadKind == "hip"))
120     return true;
121 
122   if (BundlerConfig.HipOpenmpCompatible) {
123     bool HIPCompatibleWithOpenMP = OffloadKind.starts_with_insensitive("hip") &&
124                                    TargetOffloadKind == "openmp";
125     bool OpenMPCompatibleWithHIP =
126         OffloadKind == "openmp" &&
127         TargetOffloadKind.starts_with_insensitive("hip");
128     return HIPCompatibleWithOpenMP || OpenMPCompatibleWithHIP;
129   }
130   return false;
131 }
132 
isTripleValid() const133 bool OffloadTargetInfo::isTripleValid() const {
134   return !Triple.str().empty() && Triple.getArch() != Triple::UnknownArch;
135 }
136 
operator ==(const OffloadTargetInfo & Target) const137 bool OffloadTargetInfo::operator==(const OffloadTargetInfo &Target) const {
138   return OffloadKind == Target.OffloadKind &&
139          Triple.isCompatibleWith(Target.Triple) && TargetID == Target.TargetID;
140 }
141 
str() const142 std::string OffloadTargetInfo::str() const {
143   return Twine(OffloadKind + "-" + Triple.str() + "-" + TargetID).str();
144 }
145 
getDeviceFileExtension(StringRef Device,StringRef BundleFileName)146 static StringRef getDeviceFileExtension(StringRef Device,
147                                         StringRef BundleFileName) {
148   if (Device.contains("gfx"))
149     return ".bc";
150   if (Device.contains("sm_"))
151     return ".cubin";
152   return sys::path::extension(BundleFileName);
153 }
154 
getDeviceLibraryFileName(StringRef BundleFileName,StringRef Device)155 static std::string getDeviceLibraryFileName(StringRef BundleFileName,
156                                             StringRef Device) {
157   StringRef LibName = sys::path::stem(BundleFileName);
158   StringRef Extension = getDeviceFileExtension(Device, BundleFileName);
159 
160   std::string Result;
161   Result += LibName;
162   Result += Extension;
163   return Result;
164 }
165 
166 namespace {
167 /// Generic file handler interface.
168 class FileHandler {
169 public:
170   struct BundleInfo {
171     StringRef BundleID;
172   };
173 
FileHandler()174   FileHandler() {}
175 
~FileHandler()176   virtual ~FileHandler() {}
177 
178   /// Update the file handler with information from the header of the bundled
179   /// file.
180   virtual Error ReadHeader(MemoryBuffer &Input) = 0;
181 
182   /// Read the marker of the next bundled to be read in the file. The bundle
183   /// name is returned if there is one in the file, or `std::nullopt` if there
184   /// are no more bundles to be read.
185   virtual Expected<std::optional<StringRef>>
186   ReadBundleStart(MemoryBuffer &Input) = 0;
187 
188   /// Read the marker that closes the current bundle.
189   virtual Error ReadBundleEnd(MemoryBuffer &Input) = 0;
190 
191   /// Read the current bundle and write the result into the stream \a OS.
192   virtual Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;
193 
194   /// Write the header of the bundled file to \a OS based on the information
195   /// gathered from \a Inputs.
196   virtual Error WriteHeader(raw_ostream &OS,
197                             ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) = 0;
198 
199   /// Write the marker that initiates a bundle for the triple \a TargetTriple to
200   /// \a OS.
201   virtual Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) = 0;
202 
203   /// Write the marker that closes a bundle for the triple \a TargetTriple to \a
204   /// OS.
205   virtual Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) = 0;
206 
207   /// Write the bundle from \a Input into \a OS.
208   virtual Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;
209 
210   /// Finalize output file.
finalizeOutputFile()211   virtual Error finalizeOutputFile() { return Error::success(); }
212 
213   /// List bundle IDs in \a Input.
listBundleIDs(MemoryBuffer & Input)214   virtual Error listBundleIDs(MemoryBuffer &Input) {
215     if (Error Err = ReadHeader(Input))
216       return Err;
217     return forEachBundle(Input, [&](const BundleInfo &Info) -> Error {
218       llvm::outs() << Info.BundleID << '\n';
219       Error Err = listBundleIDsCallback(Input, Info);
220       if (Err)
221         return Err;
222       return Error::success();
223     });
224   }
225 
226   /// Get bundle IDs in \a Input in \a BundleIds.
getBundleIDs(MemoryBuffer & Input,std::set<StringRef> & BundleIds)227   virtual Error getBundleIDs(MemoryBuffer &Input,
228                              std::set<StringRef> &BundleIds) {
229     if (Error Err = ReadHeader(Input))
230       return Err;
231     return forEachBundle(Input, [&](const BundleInfo &Info) -> Error {
232       BundleIds.insert(Info.BundleID);
233       Error Err = listBundleIDsCallback(Input, Info);
234       if (Err)
235         return Err;
236       return Error::success();
237     });
238   }
239 
240   /// For each bundle in \a Input, do \a Func.
forEachBundle(MemoryBuffer & Input,std::function<Error (const BundleInfo &)> Func)241   Error forEachBundle(MemoryBuffer &Input,
242                       std::function<Error(const BundleInfo &)> Func) {
243     while (true) {
244       Expected<std::optional<StringRef>> CurTripleOrErr =
245           ReadBundleStart(Input);
246       if (!CurTripleOrErr)
247         return CurTripleOrErr.takeError();
248 
249       // No more bundles.
250       if (!*CurTripleOrErr)
251         break;
252 
253       StringRef CurTriple = **CurTripleOrErr;
254       assert(!CurTriple.empty());
255 
256       BundleInfo Info{CurTriple};
257       if (Error Err = Func(Info))
258         return Err;
259     }
260     return Error::success();
261   }
262 
263 protected:
listBundleIDsCallback(MemoryBuffer & Input,const BundleInfo & Info)264   virtual Error listBundleIDsCallback(MemoryBuffer &Input,
265                                       const BundleInfo &Info) {
266     return Error::success();
267   }
268 };
269 
270 /// Handler for binary files. The bundled file will have the following format
271 /// (all integers are stored in little-endian format):
272 ///
273 /// "OFFLOAD_BUNDLER_MAGIC_STR" (ASCII encoding of the string)
274 ///
275 /// NumberOfOffloadBundles (8-byte integer)
276 ///
277 /// OffsetOfBundle1 (8-byte integer)
278 /// SizeOfBundle1 (8-byte integer)
279 /// NumberOfBytesInTripleOfBundle1 (8-byte integer)
280 /// TripleOfBundle1 (byte length defined before)
281 ///
282 /// ...
283 ///
284 /// OffsetOfBundleN (8-byte integer)
285 /// SizeOfBundleN (8-byte integer)
286 /// NumberOfBytesInTripleOfBundleN (8-byte integer)
287 /// TripleOfBundleN (byte length defined before)
288 ///
289 /// Bundle1
290 /// ...
291 /// BundleN
292 
293 /// Read 8-byte integers from a buffer in little-endian format.
Read8byteIntegerFromBuffer(StringRef Buffer,size_t pos)294 static uint64_t Read8byteIntegerFromBuffer(StringRef Buffer, size_t pos) {
295   return llvm::support::endian::read64le(Buffer.data() + pos);
296 }
297 
298 /// Write 8-byte integers to a buffer in little-endian format.
Write8byteIntegerToBuffer(raw_ostream & OS,uint64_t Val)299 static void Write8byteIntegerToBuffer(raw_ostream &OS, uint64_t Val) {
300   llvm::support::endian::write(OS, Val, llvm::endianness::little);
301 }
302 
303 class BinaryFileHandler final : public FileHandler {
304   /// Information about the bundles extracted from the header.
305   struct BinaryBundleInfo final : public BundleInfo {
306     /// Size of the bundle.
307     uint64_t Size = 0u;
308     /// Offset at which the bundle starts in the bundled file.
309     uint64_t Offset = 0u;
310 
BinaryBundleInfo__anoncaeeea5c0111::BinaryFileHandler::BinaryBundleInfo311     BinaryBundleInfo() {}
BinaryBundleInfo__anoncaeeea5c0111::BinaryFileHandler::BinaryBundleInfo312     BinaryBundleInfo(uint64_t Size, uint64_t Offset)
313         : Size(Size), Offset(Offset) {}
314   };
315 
316   /// Map between a triple and the corresponding bundle information.
317   StringMap<BinaryBundleInfo> BundlesInfo;
318 
319   /// Iterator for the bundle information that is being read.
320   StringMap<BinaryBundleInfo>::iterator CurBundleInfo;
321   StringMap<BinaryBundleInfo>::iterator NextBundleInfo;
322 
323   /// Current bundle target to be written.
324   std::string CurWriteBundleTarget;
325 
326   /// Configuration options and arrays for this bundler job
327   const OffloadBundlerConfig &BundlerConfig;
328 
329 public:
330   // TODO: Add error checking from ClangOffloadBundler.cpp
BinaryFileHandler(const OffloadBundlerConfig & BC)331   BinaryFileHandler(const OffloadBundlerConfig &BC) : BundlerConfig(BC) {}
332 
~BinaryFileHandler()333   ~BinaryFileHandler() final {}
334 
ReadHeader(MemoryBuffer & Input)335   Error ReadHeader(MemoryBuffer &Input) final {
336     StringRef FC = Input.getBuffer();
337 
338     // Initialize the current bundle with the end of the container.
339     CurBundleInfo = BundlesInfo.end();
340 
341     // Check if buffer is smaller than magic string.
342     size_t ReadChars = sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1;
343     if (ReadChars > FC.size())
344       return Error::success();
345 
346     // Check if no magic was found.
347     if (llvm::identify_magic(FC) != llvm::file_magic::offload_bundle)
348       return Error::success();
349 
350     // Read number of bundles.
351     if (ReadChars + 8 > FC.size())
352       return Error::success();
353 
354     uint64_t NumberOfBundles = Read8byteIntegerFromBuffer(FC, ReadChars);
355     ReadChars += 8;
356 
357     // Read bundle offsets, sizes and triples.
358     for (uint64_t i = 0; i < NumberOfBundles; ++i) {
359 
360       // Read offset.
361       if (ReadChars + 8 > FC.size())
362         return Error::success();
363 
364       uint64_t Offset = Read8byteIntegerFromBuffer(FC, ReadChars);
365       ReadChars += 8;
366 
367       // Read size.
368       if (ReadChars + 8 > FC.size())
369         return Error::success();
370 
371       uint64_t Size = Read8byteIntegerFromBuffer(FC, ReadChars);
372       ReadChars += 8;
373 
374       // Read triple size.
375       if (ReadChars + 8 > FC.size())
376         return Error::success();
377 
378       uint64_t TripleSize = Read8byteIntegerFromBuffer(FC, ReadChars);
379       ReadChars += 8;
380 
381       // Read triple.
382       if (ReadChars + TripleSize > FC.size())
383         return Error::success();
384 
385       StringRef Triple(&FC.data()[ReadChars], TripleSize);
386       ReadChars += TripleSize;
387 
388       // Check if the offset and size make sense.
389       if (!Offset || Offset + Size > FC.size())
390         return Error::success();
391 
392       assert(!BundlesInfo.contains(Triple) && "Triple is duplicated??");
393       BundlesInfo[Triple] = BinaryBundleInfo(Size, Offset);
394     }
395     // Set the iterator to where we will start to read.
396     CurBundleInfo = BundlesInfo.end();
397     NextBundleInfo = BundlesInfo.begin();
398     return Error::success();
399   }
400 
401   Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer & Input)402   ReadBundleStart(MemoryBuffer &Input) final {
403     if (NextBundleInfo == BundlesInfo.end())
404       return std::nullopt;
405     CurBundleInfo = NextBundleInfo++;
406     return CurBundleInfo->first();
407   }
408 
ReadBundleEnd(MemoryBuffer & Input)409   Error ReadBundleEnd(MemoryBuffer &Input) final {
410     assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
411     return Error::success();
412   }
413 
ReadBundle(raw_ostream & OS,MemoryBuffer & Input)414   Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
415     assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
416     StringRef FC = Input.getBuffer();
417     OS.write(FC.data() + CurBundleInfo->second.Offset,
418              CurBundleInfo->second.Size);
419     return Error::success();
420   }
421 
WriteHeader(raw_ostream & OS,ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs)422   Error WriteHeader(raw_ostream &OS,
423                     ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
424 
425     // Compute size of the header.
426     uint64_t HeaderSize = 0;
427 
428     HeaderSize += sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1;
429     HeaderSize += 8; // Number of Bundles
430 
431     for (auto &T : BundlerConfig.TargetNames) {
432       HeaderSize += 3 * 8; // Bundle offset, Size of bundle and size of triple.
433       HeaderSize += T.size(); // The triple.
434     }
435 
436     // Write to the buffer the header.
437     OS << OFFLOAD_BUNDLER_MAGIC_STR;
438 
439     Write8byteIntegerToBuffer(OS, BundlerConfig.TargetNames.size());
440 
441     unsigned Idx = 0;
442     for (auto &T : BundlerConfig.TargetNames) {
443       MemoryBuffer &MB = *Inputs[Idx++];
444       HeaderSize = alignTo(HeaderSize, BundlerConfig.BundleAlignment);
445       // Bundle offset.
446       Write8byteIntegerToBuffer(OS, HeaderSize);
447       // Size of the bundle (adds to the next bundle's offset)
448       Write8byteIntegerToBuffer(OS, MB.getBufferSize());
449       BundlesInfo[T] = BinaryBundleInfo(MB.getBufferSize(), HeaderSize);
450       HeaderSize += MB.getBufferSize();
451       // Size of the triple
452       Write8byteIntegerToBuffer(OS, T.size());
453       // Triple
454       OS << T;
455     }
456     return Error::success();
457   }
458 
WriteBundleStart(raw_ostream & OS,StringRef TargetTriple)459   Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
460     CurWriteBundleTarget = TargetTriple.str();
461     return Error::success();
462   }
463 
WriteBundleEnd(raw_ostream & OS,StringRef TargetTriple)464   Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
465     return Error::success();
466   }
467 
WriteBundle(raw_ostream & OS,MemoryBuffer & Input)468   Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
469     auto BI = BundlesInfo[CurWriteBundleTarget];
470 
471     // Pad with 0 to reach specified offset.
472     size_t CurrentPos = OS.tell();
473     size_t PaddingSize = BI.Offset > CurrentPos ? BI.Offset - CurrentPos : 0;
474     for (size_t I = 0; I < PaddingSize; ++I)
475       OS.write('\0');
476     assert(OS.tell() == BI.Offset);
477 
478     OS.write(Input.getBufferStart(), Input.getBufferSize());
479 
480     return Error::success();
481   }
482 };
483 
484 // This class implements a list of temporary files that are removed upon
485 // object destruction.
486 class TempFileHandlerRAII {
487 public:
~TempFileHandlerRAII()488   ~TempFileHandlerRAII() {
489     for (const auto &File : Files)
490       sys::fs::remove(File);
491   }
492 
493   // Creates temporary file with given contents.
Create(std::optional<ArrayRef<char>> Contents)494   Expected<StringRef> Create(std::optional<ArrayRef<char>> Contents) {
495     SmallString<128u> File;
496     if (std::error_code EC =
497             sys::fs::createTemporaryFile("clang-offload-bundler", "tmp", File))
498       return createFileError(File, EC);
499     Files.push_front(File);
500 
501     if (Contents) {
502       std::error_code EC;
503       raw_fd_ostream OS(File, EC);
504       if (EC)
505         return createFileError(File, EC);
506       OS.write(Contents->data(), Contents->size());
507     }
508     return Files.front().str();
509   }
510 
511 private:
512   std::forward_list<SmallString<128u>> Files;
513 };
514 
515 /// Handler for object files. The bundles are organized by sections with a
516 /// designated name.
517 ///
518 /// To unbundle, we just copy the contents of the designated section.
519 class ObjectFileHandler final : public FileHandler {
520 
521   /// The object file we are currently dealing with.
522   std::unique_ptr<ObjectFile> Obj;
523 
524   /// Return the input file contents.
getInputFileContents() const525   StringRef getInputFileContents() const { return Obj->getData(); }
526 
527   /// Return bundle name (<kind>-<triple>) if the provided section is an offload
528   /// section.
529   static Expected<std::optional<StringRef>>
IsOffloadSection(SectionRef CurSection)530   IsOffloadSection(SectionRef CurSection) {
531     Expected<StringRef> NameOrErr = CurSection.getName();
532     if (!NameOrErr)
533       return NameOrErr.takeError();
534 
535     // If it does not start with the reserved suffix, just skip this section.
536     if (llvm::identify_magic(*NameOrErr) != llvm::file_magic::offload_bundle)
537       return std::nullopt;
538 
539     // Return the triple that is right after the reserved prefix.
540     return NameOrErr->substr(sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1);
541   }
542 
543   /// Total number of inputs.
544   unsigned NumberOfInputs = 0;
545 
546   /// Total number of processed inputs, i.e, inputs that were already
547   /// read from the buffers.
548   unsigned NumberOfProcessedInputs = 0;
549 
550   /// Iterator of the current and next section.
551   section_iterator CurrentSection;
552   section_iterator NextSection;
553 
554   /// Configuration options and arrays for this bundler job
555   const OffloadBundlerConfig &BundlerConfig;
556 
557 public:
558   // TODO: Add error checking from ClangOffloadBundler.cpp
ObjectFileHandler(std::unique_ptr<ObjectFile> ObjIn,const OffloadBundlerConfig & BC)559   ObjectFileHandler(std::unique_ptr<ObjectFile> ObjIn,
560                     const OffloadBundlerConfig &BC)
561       : Obj(std::move(ObjIn)), CurrentSection(Obj->section_begin()),
562         NextSection(Obj->section_begin()), BundlerConfig(BC) {}
563 
~ObjectFileHandler()564   ~ObjectFileHandler() final {}
565 
ReadHeader(MemoryBuffer & Input)566   Error ReadHeader(MemoryBuffer &Input) final { return Error::success(); }
567 
568   Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer & Input)569   ReadBundleStart(MemoryBuffer &Input) final {
570     while (NextSection != Obj->section_end()) {
571       CurrentSection = NextSection;
572       ++NextSection;
573 
574       // Check if the current section name starts with the reserved prefix. If
575       // so, return the triple.
576       Expected<std::optional<StringRef>> TripleOrErr =
577           IsOffloadSection(*CurrentSection);
578       if (!TripleOrErr)
579         return TripleOrErr.takeError();
580       if (*TripleOrErr)
581         return **TripleOrErr;
582     }
583     return std::nullopt;
584   }
585 
ReadBundleEnd(MemoryBuffer & Input)586   Error ReadBundleEnd(MemoryBuffer &Input) final { return Error::success(); }
587 
ReadBundle(raw_ostream & OS,MemoryBuffer & Input)588   Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
589     Expected<StringRef> ContentOrErr = CurrentSection->getContents();
590     if (!ContentOrErr)
591       return ContentOrErr.takeError();
592     StringRef Content = *ContentOrErr;
593 
594     // Copy fat object contents to the output when extracting host bundle.
595     std::string ModifiedContent;
596     if (Content.size() == 1u && Content.front() == 0) {
597       auto HostBundleOrErr = getHostBundle(
598           StringRef(Input.getBufferStart(), Input.getBufferSize()));
599       if (!HostBundleOrErr)
600         return HostBundleOrErr.takeError();
601 
602       ModifiedContent = std::move(*HostBundleOrErr);
603       Content = ModifiedContent;
604     }
605 
606     OS.write(Content.data(), Content.size());
607     return Error::success();
608   }
609 
WriteHeader(raw_ostream & OS,ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs)610   Error WriteHeader(raw_ostream &OS,
611                     ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
612     assert(BundlerConfig.HostInputIndex != ~0u &&
613            "Host input index not defined.");
614 
615     // Record number of inputs.
616     NumberOfInputs = Inputs.size();
617     return Error::success();
618   }
619 
WriteBundleStart(raw_ostream & OS,StringRef TargetTriple)620   Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
621     ++NumberOfProcessedInputs;
622     return Error::success();
623   }
624 
WriteBundleEnd(raw_ostream & OS,StringRef TargetTriple)625   Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
626     return Error::success();
627   }
628 
finalizeOutputFile()629   Error finalizeOutputFile() final {
630     assert(NumberOfProcessedInputs <= NumberOfInputs &&
631            "Processing more inputs that actually exist!");
632     assert(BundlerConfig.HostInputIndex != ~0u &&
633            "Host input index not defined.");
634 
635     // If this is not the last output, we don't have to do anything.
636     if (NumberOfProcessedInputs != NumberOfInputs)
637       return Error::success();
638 
639     // We will use llvm-objcopy to add target objects sections to the output
640     // fat object. These sections should have 'exclude' flag set which tells
641     // link editor to remove them from linker inputs when linking executable or
642     // shared library.
643 
644     assert(BundlerConfig.ObjcopyPath != "" &&
645            "llvm-objcopy path not specified");
646 
647     // Temporary files that need to be removed.
648     TempFileHandlerRAII TempFiles;
649 
650     // Compose llvm-objcopy command line for add target objects' sections with
651     // appropriate flags.
652     BumpPtrAllocator Alloc;
653     StringSaver SS{Alloc};
654     SmallVector<StringRef, 8u> ObjcopyArgs{"llvm-objcopy"};
655 
656     for (unsigned I = 0; I < NumberOfInputs; ++I) {
657       StringRef InputFile = BundlerConfig.InputFileNames[I];
658       if (I == BundlerConfig.HostInputIndex) {
659         // Special handling for the host bundle. We do not need to add a
660         // standard bundle for the host object since we are going to use fat
661         // object as a host object. Therefore use dummy contents (one zero byte)
662         // when creating section for the host bundle.
663         Expected<StringRef> TempFileOrErr = TempFiles.Create(ArrayRef<char>(0));
664         if (!TempFileOrErr)
665           return TempFileOrErr.takeError();
666         InputFile = *TempFileOrErr;
667       }
668 
669       ObjcopyArgs.push_back(
670           SS.save(Twine("--add-section=") + OFFLOAD_BUNDLER_MAGIC_STR +
671                   BundlerConfig.TargetNames[I] + "=" + InputFile));
672       ObjcopyArgs.push_back(
673           SS.save(Twine("--set-section-flags=") + OFFLOAD_BUNDLER_MAGIC_STR +
674                   BundlerConfig.TargetNames[I] + "=readonly,exclude"));
675     }
676     ObjcopyArgs.push_back("--");
677     ObjcopyArgs.push_back(
678         BundlerConfig.InputFileNames[BundlerConfig.HostInputIndex]);
679     ObjcopyArgs.push_back(BundlerConfig.OutputFileNames.front());
680 
681     if (Error Err = executeObjcopy(BundlerConfig.ObjcopyPath, ObjcopyArgs))
682       return Err;
683 
684     return Error::success();
685   }
686 
WriteBundle(raw_ostream & OS,MemoryBuffer & Input)687   Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
688     return Error::success();
689   }
690 
691 private:
executeObjcopy(StringRef Objcopy,ArrayRef<StringRef> Args)692   Error executeObjcopy(StringRef Objcopy, ArrayRef<StringRef> Args) {
693     // If the user asked for the commands to be printed out, we do that
694     // instead of executing it.
695     if (BundlerConfig.PrintExternalCommands) {
696       errs() << "\"" << Objcopy << "\"";
697       for (StringRef Arg : drop_begin(Args, 1))
698         errs() << " \"" << Arg << "\"";
699       errs() << "\n";
700     } else {
701       if (sys::ExecuteAndWait(Objcopy, Args))
702         return createStringError(inconvertibleErrorCode(),
703                                  "'llvm-objcopy' tool failed");
704     }
705     return Error::success();
706   }
707 
getHostBundle(StringRef Input)708   Expected<std::string> getHostBundle(StringRef Input) {
709     TempFileHandlerRAII TempFiles;
710 
711     auto ModifiedObjPathOrErr = TempFiles.Create(std::nullopt);
712     if (!ModifiedObjPathOrErr)
713       return ModifiedObjPathOrErr.takeError();
714     StringRef ModifiedObjPath = *ModifiedObjPathOrErr;
715 
716     BumpPtrAllocator Alloc;
717     StringSaver SS{Alloc};
718     SmallVector<StringRef, 16> ObjcopyArgs{"llvm-objcopy"};
719 
720     ObjcopyArgs.push_back("--regex");
721     ObjcopyArgs.push_back("--remove-section=__CLANG_OFFLOAD_BUNDLE__.*");
722     ObjcopyArgs.push_back("--");
723 
724     StringRef ObjcopyInputFileName;
725     // When unbundling an archive, the content of each object file in the
726     // archive is passed to this function by parameter Input, which is different
727     // from the content of the original input archive file, therefore it needs
728     // to be saved to a temporary file before passed to llvm-objcopy. Otherwise,
729     // Input is the same as the content of the original input file, therefore
730     // temporary file is not needed.
731     if (StringRef(BundlerConfig.FilesType).starts_with("a")) {
732       auto InputFileOrErr =
733           TempFiles.Create(ArrayRef<char>(Input.data(), Input.size()));
734       if (!InputFileOrErr)
735         return InputFileOrErr.takeError();
736       ObjcopyInputFileName = *InputFileOrErr;
737     } else
738       ObjcopyInputFileName = BundlerConfig.InputFileNames.front();
739 
740     ObjcopyArgs.push_back(ObjcopyInputFileName);
741     ObjcopyArgs.push_back(ModifiedObjPath);
742 
743     if (Error Err = executeObjcopy(BundlerConfig.ObjcopyPath, ObjcopyArgs))
744       return std::move(Err);
745 
746     auto BufOrErr = MemoryBuffer::getFile(ModifiedObjPath);
747     if (!BufOrErr)
748       return createStringError(BufOrErr.getError(),
749                                "Failed to read back the modified object file");
750 
751     return BufOrErr->get()->getBuffer().str();
752   }
753 };
754 
755 /// Handler for text files. The bundled file will have the following format.
756 ///
757 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__START__ triple"
758 /// Bundle 1
759 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__END__ triple"
760 /// ...
761 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__START__ triple"
762 /// Bundle N
763 /// "Comment OFFLOAD_BUNDLER_MAGIC_STR__END__ triple"
764 class TextFileHandler final : public FileHandler {
765   /// String that begins a line comment.
766   StringRef Comment;
767 
768   /// String that initiates a bundle.
769   std::string BundleStartString;
770 
771   /// String that closes a bundle.
772   std::string BundleEndString;
773 
774   /// Number of chars read from input.
775   size_t ReadChars = 0u;
776 
777 protected:
ReadHeader(MemoryBuffer & Input)778   Error ReadHeader(MemoryBuffer &Input) final { return Error::success(); }
779 
780   Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer & Input)781   ReadBundleStart(MemoryBuffer &Input) final {
782     StringRef FC = Input.getBuffer();
783 
784     // Find start of the bundle.
785     ReadChars = FC.find(BundleStartString, ReadChars);
786     if (ReadChars == FC.npos)
787       return std::nullopt;
788 
789     // Get position of the triple.
790     size_t TripleStart = ReadChars = ReadChars + BundleStartString.size();
791 
792     // Get position that closes the triple.
793     size_t TripleEnd = ReadChars = FC.find("\n", ReadChars);
794     if (TripleEnd == FC.npos)
795       return std::nullopt;
796 
797     // Next time we read after the new line.
798     ++ReadChars;
799 
800     return StringRef(&FC.data()[TripleStart], TripleEnd - TripleStart);
801   }
802 
ReadBundleEnd(MemoryBuffer & Input)803   Error ReadBundleEnd(MemoryBuffer &Input) final {
804     StringRef FC = Input.getBuffer();
805 
806     // Read up to the next new line.
807     assert(FC[ReadChars] == '\n' && "The bundle should end with a new line.");
808 
809     size_t TripleEnd = ReadChars = FC.find("\n", ReadChars + 1);
810     if (TripleEnd != FC.npos)
811       // Next time we read after the new line.
812       ++ReadChars;
813 
814     return Error::success();
815   }
816 
ReadBundle(raw_ostream & OS,MemoryBuffer & Input)817   Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
818     StringRef FC = Input.getBuffer();
819     size_t BundleStart = ReadChars;
820 
821     // Find end of the bundle.
822     size_t BundleEnd = ReadChars = FC.find(BundleEndString, ReadChars);
823 
824     StringRef Bundle(&FC.data()[BundleStart], BundleEnd - BundleStart);
825     OS << Bundle;
826 
827     return Error::success();
828   }
829 
WriteHeader(raw_ostream & OS,ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs)830   Error WriteHeader(raw_ostream &OS,
831                     ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
832     return Error::success();
833   }
834 
WriteBundleStart(raw_ostream & OS,StringRef TargetTriple)835   Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
836     OS << BundleStartString << TargetTriple << "\n";
837     return Error::success();
838   }
839 
WriteBundleEnd(raw_ostream & OS,StringRef TargetTriple)840   Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
841     OS << BundleEndString << TargetTriple << "\n";
842     return Error::success();
843   }
844 
WriteBundle(raw_ostream & OS,MemoryBuffer & Input)845   Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
846     OS << Input.getBuffer();
847     return Error::success();
848   }
849 
850 public:
TextFileHandler(StringRef Comment)851   TextFileHandler(StringRef Comment) : Comment(Comment), ReadChars(0) {
852     BundleStartString =
853         "\n" + Comment.str() + " " OFFLOAD_BUNDLER_MAGIC_STR "__START__ ";
854     BundleEndString =
855         "\n" + Comment.str() + " " OFFLOAD_BUNDLER_MAGIC_STR "__END__ ";
856   }
857 
listBundleIDsCallback(MemoryBuffer & Input,const BundleInfo & Info)858   Error listBundleIDsCallback(MemoryBuffer &Input,
859                               const BundleInfo &Info) final {
860     // TODO: To list bundle IDs in a bundled text file we need to go through
861     // all bundles. The format of bundled text file may need to include a
862     // header if the performance of listing bundle IDs of bundled text file is
863     // important.
864     ReadChars = Input.getBuffer().find(BundleEndString, ReadChars);
865     if (Error Err = ReadBundleEnd(Input))
866       return Err;
867     return Error::success();
868   }
869 };
870 } // namespace
871 
872 /// Return an appropriate object file handler. We use the specific object
873 /// handler if we know how to deal with that format, otherwise we use a default
874 /// binary file handler.
875 static std::unique_ptr<FileHandler>
CreateObjectFileHandler(MemoryBuffer & FirstInput,const OffloadBundlerConfig & BundlerConfig)876 CreateObjectFileHandler(MemoryBuffer &FirstInput,
877                         const OffloadBundlerConfig &BundlerConfig) {
878   // Check if the input file format is one that we know how to deal with.
879   Expected<std::unique_ptr<Binary>> BinaryOrErr = createBinary(FirstInput);
880 
881   // We only support regular object files. If failed to open the input as a
882   // known binary or this is not an object file use the default binary handler.
883   if (errorToBool(BinaryOrErr.takeError()) || !isa<ObjectFile>(*BinaryOrErr))
884     return std::make_unique<BinaryFileHandler>(BundlerConfig);
885 
886   // Otherwise create an object file handler. The handler will be owned by the
887   // client of this function.
888   return std::make_unique<ObjectFileHandler>(
889       std::unique_ptr<ObjectFile>(cast<ObjectFile>(BinaryOrErr->release())),
890       BundlerConfig);
891 }
892 
893 /// Return an appropriate handler given the input files and options.
894 static Expected<std::unique_ptr<FileHandler>>
CreateFileHandler(MemoryBuffer & FirstInput,const OffloadBundlerConfig & BundlerConfig)895 CreateFileHandler(MemoryBuffer &FirstInput,
896                   const OffloadBundlerConfig &BundlerConfig) {
897   std::string FilesType = BundlerConfig.FilesType;
898 
899   if (FilesType == "i")
900     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
901   if (FilesType == "ii")
902     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
903   if (FilesType == "cui")
904     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
905   if (FilesType == "hipi")
906     return std::make_unique<TextFileHandler>(/*Comment=*/"//");
907   // TODO: `.d` should be eventually removed once `-M` and its variants are
908   // handled properly in offload compilation.
909   if (FilesType == "d")
910     return std::make_unique<TextFileHandler>(/*Comment=*/"#");
911   if (FilesType == "ll")
912     return std::make_unique<TextFileHandler>(/*Comment=*/";");
913   if (FilesType == "bc")
914     return std::make_unique<BinaryFileHandler>(BundlerConfig);
915   if (FilesType == "s")
916     return std::make_unique<TextFileHandler>(/*Comment=*/"#");
917   if (FilesType == "o")
918     return CreateObjectFileHandler(FirstInput, BundlerConfig);
919   if (FilesType == "a")
920     return CreateObjectFileHandler(FirstInput, BundlerConfig);
921   if (FilesType == "gch")
922     return std::make_unique<BinaryFileHandler>(BundlerConfig);
923   if (FilesType == "ast")
924     return std::make_unique<BinaryFileHandler>(BundlerConfig);
925 
926   return createStringError(errc::invalid_argument,
927                            "'" + FilesType + "': invalid file type specified");
928 }
929 
OffloadBundlerConfig()930 OffloadBundlerConfig::OffloadBundlerConfig() {
931   if (llvm::compression::zstd::isAvailable()) {
932     CompressionFormat = llvm::compression::Format::Zstd;
933     // Compression level 3 is usually sufficient for zstd since long distance
934     // matching is enabled.
935     CompressionLevel = 3;
936   } else if (llvm::compression::zlib::isAvailable()) {
937     CompressionFormat = llvm::compression::Format::Zlib;
938     // Use default level for zlib since higher level does not have significant
939     // improvement.
940     CompressionLevel = llvm::compression::zlib::DefaultCompression;
941   }
942   auto IgnoreEnvVarOpt =
943       llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_IGNORE_ENV_VAR");
944   if (IgnoreEnvVarOpt.has_value() && IgnoreEnvVarOpt.value() == "1")
945     return;
946 
947   auto VerboseEnvVarOpt = llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_VERBOSE");
948   if (VerboseEnvVarOpt.has_value())
949     Verbose = VerboseEnvVarOpt.value() == "1";
950 
951   auto CompressEnvVarOpt =
952       llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_COMPRESS");
953   if (CompressEnvVarOpt.has_value())
954     Compress = CompressEnvVarOpt.value() == "1";
955 
956   auto CompressionLevelEnvVarOpt =
957       llvm::sys::Process::GetEnv("OFFLOAD_BUNDLER_COMPRESSION_LEVEL");
958   if (CompressionLevelEnvVarOpt.has_value()) {
959     llvm::StringRef CompressionLevelStr = CompressionLevelEnvVarOpt.value();
960     int Level;
961     if (!CompressionLevelStr.getAsInteger(10, Level))
962       CompressionLevel = Level;
963     else
964       llvm::errs()
965           << "Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
966           << CompressionLevelStr.str() << ". Ignoring it.\n";
967   }
968 }
969 
// Render Value in decimal with a comma between every group of three digits,
// counted from the right (e.g. 1234567 -> "1,234,567").
static std::string formatWithCommas(unsigned long long Value) {
  const std::string Digits = std::to_string(Value);
  const size_t Count = Digits.size();

  std::string Grouped;
  Grouped.reserve(Count + Count / 3);
  for (size_t Idx = 0; Idx != Count; ++Idx) {
    // A separator goes before every digit that starts a full group of three,
    // except at the very beginning of the number.
    if (Idx != 0 && (Count - Idx) % 3 == 0)
      Grouped.push_back(',');
    Grouped.push_back(Digits[Idx]);
  }
  return Grouped;
}
980 
981 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
compress(llvm::compression::Params P,const llvm::MemoryBuffer & Input,bool Verbose)982 CompressedOffloadBundle::compress(llvm::compression::Params P,
983                                   const llvm::MemoryBuffer &Input,
984                                   bool Verbose) {
985   if (!llvm::compression::zstd::isAvailable() &&
986       !llvm::compression::zlib::isAvailable())
987     return createStringError(llvm::inconvertibleErrorCode(),
988                              "Compression not supported");
989 
990   llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
991                         ClangOffloadBundlerTimerGroup);
992   if (Verbose)
993     HashTimer.startTimer();
994   llvm::MD5 Hash;
995   llvm::MD5::MD5Result Result;
996   Hash.update(Input.getBuffer());
997   Hash.final(Result);
998   uint64_t TruncatedHash = Result.low();
999   if (Verbose)
1000     HashTimer.stopTimer();
1001 
1002   SmallVector<uint8_t, 0> CompressedBuffer;
1003   auto BufferUint8 = llvm::ArrayRef<uint8_t>(
1004       reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
1005       Input.getBuffer().size());
1006 
1007   llvm::Timer CompressTimer("Compression Timer", "Compression time",
1008                             ClangOffloadBundlerTimerGroup);
1009   if (Verbose)
1010     CompressTimer.startTimer();
1011   llvm::compression::compress(P, BufferUint8, CompressedBuffer);
1012   if (Verbose)
1013     CompressTimer.stopTimer();
1014 
1015   uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
1016   uint32_t UncompressedSize = Input.getBuffer().size();
1017   uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) +
1018                            sizeof(Version) + sizeof(CompressionMethod) +
1019                            sizeof(UncompressedSize) + sizeof(TruncatedHash) +
1020                            CompressedBuffer.size();
1021 
1022   SmallVector<char, 0> FinalBuffer;
1023   llvm::raw_svector_ostream OS(FinalBuffer);
1024   OS << MagicNumber;
1025   OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
1026   OS.write(reinterpret_cast<const char *>(&CompressionMethod),
1027            sizeof(CompressionMethod));
1028   OS.write(reinterpret_cast<const char *>(&TotalFileSize),
1029            sizeof(TotalFileSize));
1030   OS.write(reinterpret_cast<const char *>(&UncompressedSize),
1031            sizeof(UncompressedSize));
1032   OS.write(reinterpret_cast<const char *>(&TruncatedHash),
1033            sizeof(TruncatedHash));
1034   OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
1035            CompressedBuffer.size());
1036 
1037   if (Verbose) {
1038     auto MethodUsed =
1039         P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
1040     double CompressionRate =
1041         static_cast<double>(UncompressedSize) / CompressedBuffer.size();
1042     double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
1043     double CompressionSpeedMBs =
1044         (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;
1045 
1046     llvm::errs() << "Compressed bundle format version: " << Version << "\n"
1047                  << "Total file size (including headers): "
1048                  << formatWithCommas(TotalFileSize) << " bytes\n"
1049                  << "Compression method used: " << MethodUsed << "\n"
1050                  << "Compression level: " << P.level << "\n"
1051                  << "Binary size before compression: "
1052                  << formatWithCommas(UncompressedSize) << " bytes\n"
1053                  << "Binary size after compression: "
1054                  << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
1055                  << "Compression rate: "
1056                  << llvm::format("%.2lf", CompressionRate) << "\n"
1057                  << "Compression ratio: "
1058                  << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
1059                  << "Compression speed: "
1060                  << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
1061                  << "Truncated MD5 hash: "
1062                  << llvm::format_hex(TruncatedHash, 16) << "\n";
1063   }
1064   return llvm::MemoryBuffer::getMemBufferCopy(
1065       llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
1066 }
1067 
1068 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
decompress(const llvm::MemoryBuffer & Input,bool Verbose)1069 CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
1070                                     bool Verbose) {
1071 
1072   StringRef Blob = Input.getBuffer();
1073 
1074   if (Blob.size() < V1HeaderSize)
1075     return llvm::MemoryBuffer::getMemBufferCopy(Blob);
1076 
1077   if (llvm::identify_magic(Blob) !=
1078       llvm::file_magic::offload_bundle_compressed) {
1079     if (Verbose)
1080       llvm::errs() << "Uncompressed bundle.\n";
1081     return llvm::MemoryBuffer::getMemBufferCopy(Blob);
1082   }
1083 
1084   size_t CurrentOffset = MagicSize;
1085 
1086   uint16_t ThisVersion;
1087   memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));
1088   CurrentOffset += VersionFieldSize;
1089 
1090   uint16_t CompressionMethod;
1091   memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t));
1092   CurrentOffset += MethodFieldSize;
1093 
1094   uint32_t TotalFileSize;
1095   if (ThisVersion >= 2) {
1096     if (Blob.size() < V2HeaderSize)
1097       return createStringError(inconvertibleErrorCode(),
1098                                "Compressed bundle header size too small");
1099     memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
1100     CurrentOffset += FileSizeFieldSize;
1101   }
1102 
1103   uint32_t UncompressedSize;
1104   memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
1105   CurrentOffset += UncompressedSizeFieldSize;
1106 
1107   uint64_t StoredHash;
1108   memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t));
1109   CurrentOffset += HashFieldSize;
1110 
1111   llvm::compression::Format CompressionFormat;
1112   if (CompressionMethod ==
1113       static_cast<uint16_t>(llvm::compression::Format::Zlib))
1114     CompressionFormat = llvm::compression::Format::Zlib;
1115   else if (CompressionMethod ==
1116            static_cast<uint16_t>(llvm::compression::Format::Zstd))
1117     CompressionFormat = llvm::compression::Format::Zstd;
1118   else
1119     return createStringError(inconvertibleErrorCode(),
1120                              "Unknown compressing method");
1121 
1122   llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
1123                               ClangOffloadBundlerTimerGroup);
1124   if (Verbose)
1125     DecompressTimer.startTimer();
1126 
1127   SmallVector<uint8_t, 0> DecompressedData;
1128   StringRef CompressedData = Blob.substr(CurrentOffset);
1129   if (llvm::Error DecompressionError = llvm::compression::decompress(
1130           CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
1131           DecompressedData, UncompressedSize))
1132     return createStringError(inconvertibleErrorCode(),
1133                              "Could not decompress embedded file contents: " +
1134                                  llvm::toString(std::move(DecompressionError)));
1135 
1136   if (Verbose) {
1137     DecompressTimer.stopTimer();
1138 
1139     double DecompressionTimeSeconds =
1140         DecompressTimer.getTotalTime().getWallTime();
1141 
1142     // Recalculate MD5 hash for integrity check
1143     llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
1144                                 "Hash recalculation time",
1145                                 ClangOffloadBundlerTimerGroup);
1146     HashRecalcTimer.startTimer();
1147     llvm::MD5 Hash;
1148     llvm::MD5::MD5Result Result;
1149     Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData.data(),
1150                                         DecompressedData.size()));
1151     Hash.final(Result);
1152     uint64_t RecalculatedHash = Result.low();
1153     HashRecalcTimer.stopTimer();
1154     bool HashMatch = (StoredHash == RecalculatedHash);
1155 
1156     double CompressionRate =
1157         static_cast<double>(UncompressedSize) / CompressedData.size();
1158     double DecompressionSpeedMBs =
1159         (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
1160 
1161     llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
1162     if (ThisVersion >= 2)
1163       llvm::errs() << "Total file size (from header): "
1164                    << formatWithCommas(TotalFileSize) << " bytes\n";
1165     llvm::errs() << "Decompression method: "
1166                  << (CompressionFormat == llvm::compression::Format::Zlib
1167                          ? "zlib"
1168                          : "zstd")
1169                  << "\n"
1170                  << "Size before decompression: "
1171                  << formatWithCommas(CompressedData.size()) << " bytes\n"
1172                  << "Size after decompression: "
1173                  << formatWithCommas(UncompressedSize) << " bytes\n"
1174                  << "Compression rate: "
1175                  << llvm::format("%.2lf", CompressionRate) << "\n"
1176                  << "Compression ratio: "
1177                  << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
1178                  << "Decompression speed: "
1179                  << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
1180                  << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"
1181                  << "Recalculated hash: "
1182                  << llvm::format_hex(RecalculatedHash, 16) << "\n"
1183                  << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
1184   }
1185 
1186   return llvm::MemoryBuffer::getMemBufferCopy(
1187       llvm::toStringRef(DecompressedData));
1188 }
1189 
1190 // List bundle IDs. Return true if an error was found.
ListBundleIDsInFile(StringRef InputFileName,const OffloadBundlerConfig & BundlerConfig)1191 Error OffloadBundler::ListBundleIDsInFile(
1192     StringRef InputFileName, const OffloadBundlerConfig &BundlerConfig) {
1193   // Open Input file.
1194   ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
1195       MemoryBuffer::getFileOrSTDIN(InputFileName);
1196   if (std::error_code EC = CodeOrErr.getError())
1197     return createFileError(InputFileName, EC);
1198 
1199   // Decompress the input if necessary.
1200   Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
1201       CompressedOffloadBundle::decompress(**CodeOrErr, BundlerConfig.Verbose);
1202   if (!DecompressedBufferOrErr)
1203     return createStringError(
1204         inconvertibleErrorCode(),
1205         "Failed to decompress input: " +
1206             llvm::toString(DecompressedBufferOrErr.takeError()));
1207 
1208   MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
1209 
1210   // Select the right files handler.
1211   Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1212       CreateFileHandler(DecompressedInput, BundlerConfig);
1213   if (!FileHandlerOrErr)
1214     return FileHandlerOrErr.takeError();
1215 
1216   std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
1217   assert(FH);
1218   return FH->listBundleIDs(DecompressedInput);
1219 }
1220 
1221 /// @brief Checks if a code object \p CodeObjectInfo is compatible with a given
1222 /// target \p TargetInfo.
1223 /// @link https://clang.llvm.org/docs/ClangOffloadBundler.html#bundle-entry-id
isCodeObjectCompatible(const OffloadTargetInfo & CodeObjectInfo,const OffloadTargetInfo & TargetInfo)1224 bool isCodeObjectCompatible(const OffloadTargetInfo &CodeObjectInfo,
1225                             const OffloadTargetInfo &TargetInfo) {
1226 
1227   // Compatible in case of exact match.
1228   if (CodeObjectInfo == TargetInfo) {
1229     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1230                     dbgs() << "Compatible: Exact match: \t[CodeObject: "
1231                            << CodeObjectInfo.str()
1232                            << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1233     return true;
1234   }
1235 
1236   // Incompatible if Kinds or Triples mismatch.
1237   if (!CodeObjectInfo.isOffloadKindCompatible(TargetInfo.OffloadKind) ||
1238       !CodeObjectInfo.Triple.isCompatibleWith(TargetInfo.Triple)) {
1239     DEBUG_WITH_TYPE(
1240         "CodeObjectCompatibility",
1241         dbgs() << "Incompatible: Kind/Triple mismatch \t[CodeObject: "
1242                << CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
1243                << "]\n");
1244     return false;
1245   }
1246 
1247   // Incompatible if Processors mismatch.
1248   llvm::StringMap<bool> CodeObjectFeatureMap, TargetFeatureMap;
1249   std::optional<StringRef> CodeObjectProc = clang::parseTargetID(
1250       CodeObjectInfo.Triple, CodeObjectInfo.TargetID, &CodeObjectFeatureMap);
1251   std::optional<StringRef> TargetProc = clang::parseTargetID(
1252       TargetInfo.Triple, TargetInfo.TargetID, &TargetFeatureMap);
1253 
1254   // Both TargetProc and CodeObjectProc can't be empty here.
1255   if (!TargetProc || !CodeObjectProc ||
1256       CodeObjectProc.value() != TargetProc.value()) {
1257     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1258                     dbgs() << "Incompatible: Processor mismatch \t[CodeObject: "
1259                            << CodeObjectInfo.str()
1260                            << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1261     return false;
1262   }
1263 
1264   // Incompatible if CodeObject has more features than Target, irrespective of
1265   // type or sign of features.
1266   if (CodeObjectFeatureMap.getNumItems() > TargetFeatureMap.getNumItems()) {
1267     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1268                     dbgs() << "Incompatible: CodeObject has more features "
1269                               "than target \t[CodeObject: "
1270                            << CodeObjectInfo.str()
1271                            << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1272     return false;
1273   }
1274 
1275   // Compatible if each target feature specified by target is compatible with
1276   // target feature of code object. The target feature is compatible if the
1277   // code object does not specify it (meaning Any), or if it specifies it
1278   // with the same value (meaning On or Off).
1279   for (const auto &CodeObjectFeature : CodeObjectFeatureMap) {
1280     auto TargetFeature = TargetFeatureMap.find(CodeObjectFeature.getKey());
1281     if (TargetFeature == TargetFeatureMap.end()) {
1282       DEBUG_WITH_TYPE(
1283           "CodeObjectCompatibility",
1284           dbgs()
1285               << "Incompatible: Value of CodeObject's non-ANY feature is "
1286                  "not matching with Target feature's ANY value \t[CodeObject: "
1287               << CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
1288               << "]\n");
1289       return false;
1290     } else if (TargetFeature->getValue() != CodeObjectFeature.getValue()) {
1291       DEBUG_WITH_TYPE(
1292           "CodeObjectCompatibility",
1293           dbgs() << "Incompatible: Value of CodeObject's non-ANY feature is "
1294                     "not matching with Target feature's non-ANY value "
1295                     "\t[CodeObject: "
1296                  << CodeObjectInfo.str()
1297                  << "]\t:\t[Target: " << TargetInfo.str() << "]\n");
1298       return false;
1299     }
1300   }
1301 
1302   // CodeObject is compatible if all features of Target are:
1303   //   - either, present in the Code Object's features map with the same sign,
1304   //   - or, the feature is missing from CodeObjects's features map i.e. it is
1305   //   set to ANY
1306   DEBUG_WITH_TYPE(
1307       "CodeObjectCompatibility",
1308       dbgs() << "Compatible: Target IDs are compatible \t[CodeObject: "
1309              << CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
1310              << "]\n");
1311   return true;
1312 }
1313 
1314 /// Bundle the files. Return true if an error was found.
BundleFiles()1315 Error OffloadBundler::BundleFiles() {
1316   std::error_code EC;
1317 
1318   // Create a buffer to hold the content before compressing.
1319   SmallVector<char, 0> Buffer;
1320   llvm::raw_svector_ostream BufferStream(Buffer);
1321 
1322   // Open input files.
1323   SmallVector<std::unique_ptr<MemoryBuffer>, 8u> InputBuffers;
1324   InputBuffers.reserve(BundlerConfig.InputFileNames.size());
1325   for (auto &I : BundlerConfig.InputFileNames) {
1326     ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
1327         MemoryBuffer::getFileOrSTDIN(I);
1328     if (std::error_code EC = CodeOrErr.getError())
1329       return createFileError(I, EC);
1330     InputBuffers.emplace_back(std::move(*CodeOrErr));
1331   }
1332 
1333   // Get the file handler. We use the host buffer as reference.
1334   assert((BundlerConfig.HostInputIndex != ~0u || BundlerConfig.AllowNoHost) &&
1335          "Host input index undefined??");
1336   Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr = CreateFileHandler(
1337       *InputBuffers[BundlerConfig.AllowNoHost ? 0
1338                                               : BundlerConfig.HostInputIndex],
1339       BundlerConfig);
1340   if (!FileHandlerOrErr)
1341     return FileHandlerOrErr.takeError();
1342 
1343   std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
1344   assert(FH);
1345 
1346   // Write header.
1347   if (Error Err = FH->WriteHeader(BufferStream, InputBuffers))
1348     return Err;
1349 
1350   // Write all bundles along with the start/end markers. If an error was found
1351   // writing the end of the bundle component, abort the bundle writing.
1352   auto Input = InputBuffers.begin();
1353   for (auto &Triple : BundlerConfig.TargetNames) {
1354     if (Error Err = FH->WriteBundleStart(BufferStream, Triple))
1355       return Err;
1356     if (Error Err = FH->WriteBundle(BufferStream, **Input))
1357       return Err;
1358     if (Error Err = FH->WriteBundleEnd(BufferStream, Triple))
1359       return Err;
1360     ++Input;
1361   }
1362 
1363   raw_fd_ostream OutputFile(BundlerConfig.OutputFileNames.front(), EC,
1364                             sys::fs::OF_None);
1365   if (EC)
1366     return createFileError(BundlerConfig.OutputFileNames.front(), EC);
1367 
1368   SmallVector<char, 0> CompressedBuffer;
1369   if (BundlerConfig.Compress) {
1370     std::unique_ptr<llvm::MemoryBuffer> BufferMemory =
1371         llvm::MemoryBuffer::getMemBufferCopy(
1372             llvm::StringRef(Buffer.data(), Buffer.size()));
1373     auto CompressionResult = CompressedOffloadBundle::compress(
1374         {BundlerConfig.CompressionFormat, BundlerConfig.CompressionLevel,
1375          /*zstdEnableLdm=*/true},
1376         *BufferMemory, BundlerConfig.Verbose);
1377     if (auto Error = CompressionResult.takeError())
1378       return Error;
1379 
1380     auto CompressedMemBuffer = std::move(CompressionResult.get());
1381     CompressedBuffer.assign(CompressedMemBuffer->getBufferStart(),
1382                             CompressedMemBuffer->getBufferEnd());
1383   } else
1384     CompressedBuffer = Buffer;
1385 
1386   OutputFile.write(CompressedBuffer.data(), CompressedBuffer.size());
1387 
1388   return FH->finalizeOutputFile();
1389 }
1390 
1391 // Unbundle the files. Return true if an error was found.
UnbundleFiles()1392 Error OffloadBundler::UnbundleFiles() {
1393   // Open Input file.
1394   ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
1395       MemoryBuffer::getFileOrSTDIN(BundlerConfig.InputFileNames.front());
1396   if (std::error_code EC = CodeOrErr.getError())
1397     return createFileError(BundlerConfig.InputFileNames.front(), EC);
1398 
1399   // Decompress the input if necessary.
1400   Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
1401       CompressedOffloadBundle::decompress(**CodeOrErr, BundlerConfig.Verbose);
1402   if (!DecompressedBufferOrErr)
1403     return createStringError(
1404         inconvertibleErrorCode(),
1405         "Failed to decompress input: " +
1406             llvm::toString(DecompressedBufferOrErr.takeError()));
1407 
1408   MemoryBuffer &Input = **DecompressedBufferOrErr;
1409 
1410   // Select the right files handler.
1411   Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1412       CreateFileHandler(Input, BundlerConfig);
1413   if (!FileHandlerOrErr)
1414     return FileHandlerOrErr.takeError();
1415 
1416   std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
1417   assert(FH);
1418 
1419   // Read the header of the bundled file.
1420   if (Error Err = FH->ReadHeader(Input))
1421     return Err;
1422 
1423   // Create a work list that consist of the map triple/output file.
1424   StringMap<StringRef> Worklist;
1425   auto Output = BundlerConfig.OutputFileNames.begin();
1426   for (auto &Triple : BundlerConfig.TargetNames) {
1427     Worklist[Triple] = *Output;
1428     ++Output;
1429   }
1430 
1431   // Read all the bundles that are in the work list. If we find no bundles we
1432   // assume the file is meant for the host target.
1433   bool FoundHostBundle = false;
1434   while (!Worklist.empty()) {
1435     Expected<std::optional<StringRef>> CurTripleOrErr =
1436         FH->ReadBundleStart(Input);
1437     if (!CurTripleOrErr)
1438       return CurTripleOrErr.takeError();
1439 
1440     // We don't have more bundles.
1441     if (!*CurTripleOrErr)
1442       break;
1443 
1444     StringRef CurTriple = **CurTripleOrErr;
1445     assert(!CurTriple.empty());
1446 
1447     auto Output = Worklist.begin();
1448     for (auto E = Worklist.end(); Output != E; Output++) {
1449       if (isCodeObjectCompatible(
1450               OffloadTargetInfo(CurTriple, BundlerConfig),
1451               OffloadTargetInfo((*Output).first(), BundlerConfig))) {
1452         break;
1453       }
1454     }
1455 
1456     if (Output == Worklist.end())
1457       continue;
1458     // Check if the output file can be opened and copy the bundle to it.
1459     std::error_code EC;
1460     raw_fd_ostream OutputFile((*Output).second, EC, sys::fs::OF_None);
1461     if (EC)
1462       return createFileError((*Output).second, EC);
1463     if (Error Err = FH->ReadBundle(OutputFile, Input))
1464       return Err;
1465     if (Error Err = FH->ReadBundleEnd(Input))
1466       return Err;
1467     Worklist.erase(Output);
1468 
1469     // Record if we found the host bundle.
1470     auto OffloadInfo = OffloadTargetInfo(CurTriple, BundlerConfig);
1471     if (OffloadInfo.hasHostKind())
1472       FoundHostBundle = true;
1473   }
1474 
1475   if (!BundlerConfig.AllowMissingBundles && !Worklist.empty()) {
1476     std::string ErrMsg = "Can't find bundles for";
1477     std::set<StringRef> Sorted;
1478     for (auto &E : Worklist)
1479       Sorted.insert(E.first());
1480     unsigned I = 0;
1481     unsigned Last = Sorted.size() - 1;
1482     for (auto &E : Sorted) {
1483       if (I != 0 && Last > 1)
1484         ErrMsg += ",";
1485       ErrMsg += " ";
1486       if (I == Last && I != 0)
1487         ErrMsg += "and ";
1488       ErrMsg += E.str();
1489       ++I;
1490     }
1491     return createStringError(inconvertibleErrorCode(), ErrMsg);
1492   }
1493 
1494   // If no bundles were found, assume the input file is the host bundle and
1495   // create empty files for the remaining targets.
1496   if (Worklist.size() == BundlerConfig.TargetNames.size()) {
1497     for (auto &E : Worklist) {
1498       std::error_code EC;
1499       raw_fd_ostream OutputFile(E.second, EC, sys::fs::OF_None);
1500       if (EC)
1501         return createFileError(E.second, EC);
1502 
1503       // If this entry has a host kind, copy the input file to the output file.
1504       auto OffloadInfo = OffloadTargetInfo(E.getKey(), BundlerConfig);
1505       if (OffloadInfo.hasHostKind())
1506         OutputFile.write(Input.getBufferStart(), Input.getBufferSize());
1507     }
1508     return Error::success();
1509   }
1510 
1511   // If we found elements, we emit an error if none of those were for the host
1512   // in case host bundle name was provided in command line.
1513   if (!(FoundHostBundle || BundlerConfig.HostInputIndex == ~0u ||
1514         BundlerConfig.AllowMissingBundles))
1515     return createStringError(inconvertibleErrorCode(),
1516                              "Can't find bundle for the host target");
1517 
1518   // If we still have any elements in the worklist, create empty files for them.
1519   for (auto &E : Worklist) {
1520     std::error_code EC;
1521     raw_fd_ostream OutputFile(E.second, EC, sys::fs::OF_None);
1522     if (EC)
1523       return createFileError(E.second, EC);
1524   }
1525 
1526   return Error::success();
1527 }
1528 
getDefaultArchiveKindForHost()1529 static Archive::Kind getDefaultArchiveKindForHost() {
1530   return Triple(sys::getDefaultTargetTriple()).isOSDarwin() ? Archive::K_DARWIN
1531                                                             : Archive::K_GNU;
1532 }
1533 
1534 /// @brief Computes a list of targets among all given targets which are
1535 /// compatible with this code object
1536 /// @param [in] CodeObjectInfo Code Object
1537 /// @param [out] CompatibleTargets List of all compatible targets among all
1538 /// given targets
1539 /// @return false, if no compatible target is found.
1540 static bool
getCompatibleOffloadTargets(OffloadTargetInfo & CodeObjectInfo,SmallVectorImpl<StringRef> & CompatibleTargets,const OffloadBundlerConfig & BundlerConfig)1541 getCompatibleOffloadTargets(OffloadTargetInfo &CodeObjectInfo,
1542                             SmallVectorImpl<StringRef> &CompatibleTargets,
1543                             const OffloadBundlerConfig &BundlerConfig) {
1544   if (!CompatibleTargets.empty()) {
1545     DEBUG_WITH_TYPE("CodeObjectCompatibility",
1546                     dbgs() << "CompatibleTargets list should be empty\n");
1547     return false;
1548   }
1549   for (auto &Target : BundlerConfig.TargetNames) {
1550     auto TargetInfo = OffloadTargetInfo(Target, BundlerConfig);
1551     if (isCodeObjectCompatible(CodeObjectInfo, TargetInfo))
1552       CompatibleTargets.push_back(Target);
1553   }
1554   return !CompatibleTargets.empty();
1555 }
1556 
1557 // Check that each code object file in the input archive conforms to following
1558 // rule: for a specific processor, a feature either shows up in all target IDs,
1559 // or does not show up in any target IDs. Otherwise the target ID combination is
1560 // invalid.
1561 static Error
CheckHeterogeneousArchive(StringRef ArchiveName,const OffloadBundlerConfig & BundlerConfig)1562 CheckHeterogeneousArchive(StringRef ArchiveName,
1563                           const OffloadBundlerConfig &BundlerConfig) {
1564   std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
1565   ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
1566       MemoryBuffer::getFileOrSTDIN(ArchiveName, true, false);
1567   if (std::error_code EC = BufOrErr.getError())
1568     return createFileError(ArchiveName, EC);
1569 
1570   ArchiveBuffers.push_back(std::move(*BufOrErr));
1571   Expected<std::unique_ptr<llvm::object::Archive>> LibOrErr =
1572       Archive::create(ArchiveBuffers.back()->getMemBufferRef());
1573   if (!LibOrErr)
1574     return LibOrErr.takeError();
1575 
1576   auto Archive = std::move(*LibOrErr);
1577 
1578   Error ArchiveErr = Error::success();
1579   auto ChildEnd = Archive->child_end();
1580 
1581   /// Iterate over all bundled code object files in the input archive.
1582   for (auto ArchiveIter = Archive->child_begin(ArchiveErr);
1583        ArchiveIter != ChildEnd; ++ArchiveIter) {
1584     if (ArchiveErr)
1585       return ArchiveErr;
1586     auto ArchiveChildNameOrErr = (*ArchiveIter).getName();
1587     if (!ArchiveChildNameOrErr)
1588       return ArchiveChildNameOrErr.takeError();
1589 
1590     auto CodeObjectBufferRefOrErr = (*ArchiveIter).getMemoryBufferRef();
1591     if (!CodeObjectBufferRefOrErr)
1592       return CodeObjectBufferRefOrErr.takeError();
1593 
1594     auto CodeObjectBuffer =
1595         MemoryBuffer::getMemBuffer(*CodeObjectBufferRefOrErr, false);
1596 
1597     Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1598         CreateFileHandler(*CodeObjectBuffer, BundlerConfig);
1599     if (!FileHandlerOrErr)
1600       return FileHandlerOrErr.takeError();
1601 
1602     std::unique_ptr<FileHandler> &FileHandler = *FileHandlerOrErr;
1603     assert(FileHandler);
1604 
1605     std::set<StringRef> BundleIds;
1606     auto CodeObjectFileError =
1607         FileHandler->getBundleIDs(*CodeObjectBuffer, BundleIds);
1608     if (CodeObjectFileError)
1609       return CodeObjectFileError;
1610 
1611     auto &&ConflictingArchs = clang::getConflictTargetIDCombination(BundleIds);
1612     if (ConflictingArchs) {
1613       std::string ErrMsg =
1614           Twine("conflicting TargetIDs [" + ConflictingArchs.value().first +
1615                 ", " + ConflictingArchs.value().second + "] found in " +
1616                 ArchiveChildNameOrErr.get() + " of " + ArchiveName)
1617               .str();
1618       return createStringError(inconvertibleErrorCode(), ErrMsg);
1619     }
1620   }
1621 
1622   return ArchiveErr;
1623 }
1624 
1625 /// UnbundleArchive takes an archive file (".a") as input containing bundled
1626 /// code object files, and a list of offload targets (not host), and extracts
1627 /// the code objects into a new archive file for each offload target. Each
1628 /// resulting archive file contains all code object files corresponding to that
1629 /// particular offload target. The created archive file does not
1630 /// contain an index of the symbols and code object files are named as
1631 /// <<Parent Bundle Name>-<CodeObject's TargetID>>, with ':' replaced with '_'.
UnbundleArchive()1632 Error OffloadBundler::UnbundleArchive() {
1633   std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
1634 
1635   /// Map of target names with list of object files that will form the device
1636   /// specific archive for that target
1637   StringMap<std::vector<NewArchiveMember>> OutputArchivesMap;
1638 
1639   // Map of target names and output archive filenames
1640   StringMap<StringRef> TargetOutputFileNameMap;
1641 
1642   auto Output = BundlerConfig.OutputFileNames.begin();
1643   for (auto &Target : BundlerConfig.TargetNames) {
1644     TargetOutputFileNameMap[Target] = *Output;
1645     ++Output;
1646   }
1647 
1648   StringRef IFName = BundlerConfig.InputFileNames.front();
1649 
1650   if (BundlerConfig.CheckInputArchive) {
1651     // For a specific processor, a feature either shows up in all target IDs, or
1652     // does not show up in any target IDs. Otherwise the target ID combination
1653     // is invalid.
1654     auto ArchiveError = CheckHeterogeneousArchive(IFName, BundlerConfig);
1655     if (ArchiveError) {
1656       return ArchiveError;
1657     }
1658   }
1659 
1660   ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
1661       MemoryBuffer::getFileOrSTDIN(IFName, true, false);
1662   if (std::error_code EC = BufOrErr.getError())
1663     return createFileError(BundlerConfig.InputFileNames.front(), EC);
1664 
1665   ArchiveBuffers.push_back(std::move(*BufOrErr));
1666   Expected<std::unique_ptr<llvm::object::Archive>> LibOrErr =
1667       Archive::create(ArchiveBuffers.back()->getMemBufferRef());
1668   if (!LibOrErr)
1669     return LibOrErr.takeError();
1670 
1671   auto Archive = std::move(*LibOrErr);
1672 
1673   Error ArchiveErr = Error::success();
1674   auto ChildEnd = Archive->child_end();
1675 
1676   /// Iterate over all bundled code object files in the input archive.
1677   for (auto ArchiveIter = Archive->child_begin(ArchiveErr);
1678        ArchiveIter != ChildEnd; ++ArchiveIter) {
1679     if (ArchiveErr)
1680       return ArchiveErr;
1681     auto ArchiveChildNameOrErr = (*ArchiveIter).getName();
1682     if (!ArchiveChildNameOrErr)
1683       return ArchiveChildNameOrErr.takeError();
1684 
1685     StringRef BundledObjectFile = sys::path::filename(*ArchiveChildNameOrErr);
1686 
1687     auto CodeObjectBufferRefOrErr = (*ArchiveIter).getMemoryBufferRef();
1688     if (!CodeObjectBufferRefOrErr)
1689       return CodeObjectBufferRefOrErr.takeError();
1690 
1691     auto TempCodeObjectBuffer =
1692         MemoryBuffer::getMemBuffer(*CodeObjectBufferRefOrErr, false);
1693 
1694     // Decompress the buffer if necessary.
1695     Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
1696         CompressedOffloadBundle::decompress(*TempCodeObjectBuffer,
1697                                             BundlerConfig.Verbose);
1698     if (!DecompressedBufferOrErr)
1699       return createStringError(
1700           inconvertibleErrorCode(),
1701           "Failed to decompress code object: " +
1702               llvm::toString(DecompressedBufferOrErr.takeError()));
1703 
1704     MemoryBuffer &CodeObjectBuffer = **DecompressedBufferOrErr;
1705 
1706     Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
1707         CreateFileHandler(CodeObjectBuffer, BundlerConfig);
1708     if (!FileHandlerOrErr)
1709       return FileHandlerOrErr.takeError();
1710 
1711     std::unique_ptr<FileHandler> &FileHandler = *FileHandlerOrErr;
1712     assert(FileHandler &&
1713            "FileHandle creation failed for file in the archive!");
1714 
1715     if (Error ReadErr = FileHandler->ReadHeader(CodeObjectBuffer))
1716       return ReadErr;
1717 
1718     Expected<std::optional<StringRef>> CurBundleIDOrErr =
1719         FileHandler->ReadBundleStart(CodeObjectBuffer);
1720     if (!CurBundleIDOrErr)
1721       return CurBundleIDOrErr.takeError();
1722 
1723     std::optional<StringRef> OptionalCurBundleID = *CurBundleIDOrErr;
1724     // No device code in this child, skip.
1725     if (!OptionalCurBundleID)
1726       continue;
1727     StringRef CodeObject = *OptionalCurBundleID;
1728 
1729     // Process all bundle entries (CodeObjects) found in this child of input
1730     // archive.
1731     while (!CodeObject.empty()) {
1732       SmallVector<StringRef> CompatibleTargets;
1733       auto CodeObjectInfo = OffloadTargetInfo(CodeObject, BundlerConfig);
1734       if (getCompatibleOffloadTargets(CodeObjectInfo, CompatibleTargets,
1735                                       BundlerConfig)) {
1736         std::string BundleData;
1737         raw_string_ostream DataStream(BundleData);
1738         if (Error Err = FileHandler->ReadBundle(DataStream, CodeObjectBuffer))
1739           return Err;
1740 
1741         for (auto &CompatibleTarget : CompatibleTargets) {
1742           SmallString<128> BundledObjectFileName;
1743           BundledObjectFileName.assign(BundledObjectFile);
1744           auto OutputBundleName =
1745               Twine(llvm::sys::path::stem(BundledObjectFileName) + "-" +
1746                     CodeObject +
1747                     getDeviceLibraryFileName(BundledObjectFileName,
1748                                              CodeObjectInfo.TargetID))
1749                   .str();
1750           // Replace ':' in optional target feature list with '_' to ensure
1751           // cross-platform validity.
1752           std::replace(OutputBundleName.begin(), OutputBundleName.end(), ':',
1753                        '_');
1754 
1755           std::unique_ptr<MemoryBuffer> MemBuf = MemoryBuffer::getMemBufferCopy(
1756               DataStream.str(), OutputBundleName);
1757           ArchiveBuffers.push_back(std::move(MemBuf));
1758           llvm::MemoryBufferRef MemBufRef =
1759               MemoryBufferRef(*(ArchiveBuffers.back()));
1760 
1761           // For inserting <CompatibleTarget, list<CodeObject>> entry in
1762           // OutputArchivesMap.
1763           if (!OutputArchivesMap.contains(CompatibleTarget)) {
1764 
1765             std::vector<NewArchiveMember> ArchiveMembers;
1766             ArchiveMembers.push_back(NewArchiveMember(MemBufRef));
1767             OutputArchivesMap.insert_or_assign(CompatibleTarget,
1768                                                std::move(ArchiveMembers));
1769           } else {
1770             OutputArchivesMap[CompatibleTarget].push_back(
1771                 NewArchiveMember(MemBufRef));
1772           }
1773         }
1774       }
1775 
1776       if (Error Err = FileHandler->ReadBundleEnd(CodeObjectBuffer))
1777         return Err;
1778 
1779       Expected<std::optional<StringRef>> NextTripleOrErr =
1780           FileHandler->ReadBundleStart(CodeObjectBuffer);
1781       if (!NextTripleOrErr)
1782         return NextTripleOrErr.takeError();
1783 
1784       CodeObject = ((*NextTripleOrErr).has_value()) ? **NextTripleOrErr : "";
1785     } // End of processing of all bundle entries of this child of input archive.
1786   }   // End of while over children of input archive.
1787 
1788   assert(!ArchiveErr && "Error occurred while reading archive!");
1789 
1790   /// Write out an archive for each target
1791   for (auto &Target : BundlerConfig.TargetNames) {
1792     StringRef FileName = TargetOutputFileNameMap[Target];
1793     StringMapIterator<std::vector<llvm::NewArchiveMember>> CurArchiveMembers =
1794         OutputArchivesMap.find(Target);
1795     if (CurArchiveMembers != OutputArchivesMap.end()) {
1796       if (Error WriteErr = writeArchive(FileName, CurArchiveMembers->getValue(),
1797                                         SymtabWritingMode::NormalSymtab,
1798                                         getDefaultArchiveKindForHost(), true,
1799                                         false, nullptr))
1800         return WriteErr;
1801     } else if (!BundlerConfig.AllowMissingBundles) {
1802       std::string ErrMsg =
1803           Twine("no compatible code object found for the target '" + Target +
1804                 "' in heterogeneous archive library: " + IFName)
1805               .str();
1806       return createStringError(inconvertibleErrorCode(), ErrMsg);
1807     } else { // Create an empty archive file if no compatible code object is
1808              // found and "allow-missing-bundles" is enabled. It ensures that
1809              // the linker using output of this step doesn't complain about
1810              // the missing input file.
1811       std::vector<llvm::NewArchiveMember> EmptyArchive;
1812       EmptyArchive.clear();
1813       if (Error WriteErr = writeArchive(
1814               FileName, EmptyArchive, SymtabWritingMode::NormalSymtab,
1815               getDefaultArchiveKindForHost(), true, false, nullptr))
1816         return WriteErr;
1817     }
1818   }
1819 
1820   return Error::success();
1821 }
1822