xref: /freebsd/contrib/llvm-project/llvm/lib/Object/OffloadBundle.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------===//
8 
9 #include "llvm/Object/OffloadBundle.h"
10 #include "llvm/BinaryFormat/Magic.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IRReader/IRReader.h"
13 #include "llvm/MC/StringTableBuilder.h"
14 #include "llvm/Object/Archive.h"
15 #include "llvm/Object/Binary.h"
16 #include "llvm/Object/COFF.h"
17 #include "llvm/Object/ELFObjectFile.h"
18 #include "llvm/Object/Error.h"
19 #include "llvm/Object/IRObjectFile.h"
20 #include "llvm/Object/ObjectFile.h"
21 #include "llvm/Support/BinaryStreamReader.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/Support/Timer.h"
24 
25 using namespace llvm;
26 using namespace llvm::object;
27 
28 static llvm::TimerGroup
29     OffloadBundlerTimerGroup("Offload Bundler Timer Group",
30                              "Timer group for offload bundler");
31 
32 // Extract an Offload bundle (usually a Offload Bundle) from a fat_bin
33 // section
34 Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,
35                            StringRef FileName,
36                            SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
37 
38   size_t Offset = 0;
39   size_t NextbundleStart = 0;
40 
41   // There could be multiple offloading bundles stored at this section.
42   while (NextbundleStart != StringRef::npos) {
43     std::unique_ptr<MemoryBuffer> Buffer =
44         MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",
45                                    /*RequiresNullTerminator=*/false);
46 
47     // Create the FatBinBindle object. This will also create the Bundle Entry
48     // list info.
49     auto FatBundleOrErr =
50         OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName);
51     if (!FatBundleOrErr)
52       return FatBundleOrErr.takeError();
53 
54     // Add current Bundle to list.
55     Bundles.emplace_back(std::move(**FatBundleOrErr));
56 
57     // Find the next bundle by searching for the magic string
58     StringRef Str = Buffer->getBuffer();
59     NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24);
60 
61     if (NextbundleStart != StringRef::npos)
62       Offset += NextbundleStart;
63   }
64 
65   return Error::success();
66 }
67 
68 Error OffloadBundleFatBin::readEntries(StringRef Buffer,
69                                        uint64_t SectionOffset) {
70   uint64_t NumOfEntries = 0;
71 
72   BinaryStreamReader Reader(Buffer, llvm::endianness::little);
73 
74   // Read the Magic String first.
75   StringRef Magic;
76   if (auto EC = Reader.readFixedString(Magic, 24))
77     return errorCodeToError(object_error::parse_failed);
78 
79   // Read the number of Code Objects (Entries) in the current Bundle.
80   if (auto EC = Reader.readInteger(NumOfEntries))
81     return errorCodeToError(object_error::parse_failed);
82 
83   NumberOfEntries = NumOfEntries;
84 
85   // For each Bundle Entry (code object)
86   for (uint64_t I = 0; I < NumOfEntries; I++) {
87     uint64_t EntrySize;
88     uint64_t EntryOffset;
89     uint64_t EntryIDSize;
90     StringRef EntryID;
91 
92     if (auto EC = Reader.readInteger(EntryOffset))
93       return errorCodeToError(object_error::parse_failed);
94 
95     if (auto EC = Reader.readInteger(EntrySize))
96       return errorCodeToError(object_error::parse_failed);
97 
98     if (auto EC = Reader.readInteger(EntryIDSize))
99       return errorCodeToError(object_error::parse_failed);
100 
101     if (auto EC = Reader.readFixedString(EntryID, EntryIDSize))
102       return errorCodeToError(object_error::parse_failed);
103 
104     auto Entry = std::make_unique<OffloadBundleEntry>(
105         EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID);
106 
107     Entries.push_back(*Entry);
108   }
109 
110   return Error::success();
111 }
112 
113 Expected<std::unique_ptr<OffloadBundleFatBin>>
114 OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,
115                             StringRef FileName) {
116   if (Buf.getBufferSize() < 24)
117     return errorCodeToError(object_error::parse_failed);
118 
119   // Check for magic bytes.
120   if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle)
121     return errorCodeToError(object_error::parse_failed);
122 
123   OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName);
124 
125   // Read the Bundle Entries
126   Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset);
127   if (Err)
128     return errorCodeToError(object_error::parse_failed);
129 
130   return std::unique_ptr<OffloadBundleFatBin>(TheBundle);
131 }
132 
133 Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) {
134   // This will extract all entries in the Bundle
135   for (OffloadBundleEntry &Entry : Entries) {
136 
137     if (Entry.Size == 0)
138       continue;
139 
140     // create output file name. Which should be
141     // <fileName>-offset<Offset>-size<Size>.co"
142     std::string Str = getFileName().str() + "-offset" + itostr(Entry.Offset) +
143                       "-size" + itostr(Entry.Size) + ".co";
144     if (Error Err = object::extractCodeObject(Source, Entry.Offset, Entry.Size,
145                                               StringRef(Str)))
146       return Err;
147   }
148 
149   return Error::success();
150 }
151 
152 Error object::extractOffloadBundleFatBinary(
153     const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
154   assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type");
155 
156   // Iterate through Sections until we find an offload_bundle section.
157   for (SectionRef Sec : Obj.sections()) {
158     Expected<StringRef> Buffer = Sec.getContents();
159     if (!Buffer)
160       return Buffer.takeError();
161 
162     // If it does not start with the reserved suffix, just skip this section.
163     if ((llvm::identify_magic(*Buffer) == llvm::file_magic::offload_bundle) ||
164         (llvm::identify_magic(*Buffer) ==
165          llvm::file_magic::offload_bundle_compressed)) {
166 
167       uint64_t SectionOffset = 0;
168       if (Obj.isELF()) {
169         SectionOffset = ELFSectionRef(Sec).getOffset();
170       } else if (Obj.isCOFF()) // TODO: add COFF Support
171         return createStringError(object_error::parse_failed,
172                                  "COFF object files not supported.\n");
173 
174       MemoryBufferRef Contents(*Buffer, Obj.getFileName());
175 
176       if (llvm::identify_magic(*Buffer) ==
177           llvm::file_magic::offload_bundle_compressed) {
178         // Decompress the input if necessary.
179         Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
180             CompressedOffloadBundle::decompress(Contents, false);
181 
182         if (!DecompressedBufferOrErr)
183           return createStringError(
184               inconvertibleErrorCode(),
185               "Failed to decompress input: " +
186                   llvm::toString(DecompressedBufferOrErr.takeError()));
187 
188         MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
189         if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset,
190                                              Obj.getFileName(), Bundles))
191           return Err;
192       } else {
193         if (Error Err = extractOffloadBundle(Contents, SectionOffset,
194                                              Obj.getFileName(), Bundles))
195           return Err;
196       }
197     }
198   }
199   return Error::success();
200 }
201 
202 Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,
203                                 int64_t Size, StringRef OutputFileName) {
204   Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr =
205       FileOutputBuffer::create(OutputFileName, Size);
206 
207   if (!BufferOrErr)
208     return BufferOrErr.takeError();
209 
210   Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef();
211   if (Error Err = InputBuffOrErr.takeError())
212     return Err;
213 
214   std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr);
215   std::copy(InputBuffOrErr->getBufferStart() + Offset,
216             InputBuffOrErr->getBufferStart() + Offset + Size,
217             Buf->getBufferStart());
218   if (Error E = Buf->commit())
219     return E;
220 
221   return Error::success();
222 }
223 
224 // given a file name, offset, and size, extract data into a code object file,
225 // into file <SourceFile>-offset<Offset>-size<Size>.co
226 Error object::extractOffloadBundleByURI(StringRef URIstr) {
227   // create a URI object
228   Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr(
229       OffloadBundleURI::createOffloadBundleURI(URIstr, FILE_URI));
230   if (!UriOrErr)
231     return UriOrErr.takeError();
232 
233   OffloadBundleURI &Uri = **UriOrErr;
234   std::string OutputFile = Uri.FileName.str();
235   OutputFile +=
236       "-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co";
237 
238   // Create an ObjectFile object from uri.file_uri
239   auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName);
240   if (!ObjOrErr)
241     return ObjOrErr.takeError();
242 
243   auto Obj = ObjOrErr->getBinary();
244   if (Error Err =
245           object::extractCodeObject(*Obj, Uri.Offset, Uri.Size, OutputFile))
246     return Err;
247 
248   return Error::success();
249 }
250 
// Render \p Value in decimal with a comma separating each group of three
// digits, counting from the right (e.g. 1234567 -> "1,234,567").
static std::string formatWithCommas(unsigned long long Value) {
  const std::string Digits = std::to_string(Value);
  std::string Result;
  Result.reserve(Digits.size() + Digits.size() / 3);
  for (size_t I = 0, E = Digits.size(); I != E; ++I) {
    // A separator goes before every digit (other than the first) whose
    // distance from the end is a multiple of three.
    if (I != 0 && (E - I) % 3 == 0)
      Result.push_back(',');
    Result.push_back(Digits[I]);
  }
  return Result;
}
261 
262 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
263 CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
264                                     bool Verbose) {
265   StringRef Blob = Input.getBuffer();
266 
267   if (Blob.size() < V1HeaderSize)
268     return llvm::MemoryBuffer::getMemBufferCopy(Blob);
269 
270   if (llvm::identify_magic(Blob) !=
271       llvm::file_magic::offload_bundle_compressed) {
272     if (Verbose)
273       llvm::errs() << "Uncompressed bundle.\n";
274     return llvm::MemoryBuffer::getMemBufferCopy(Blob);
275   }
276 
277   size_t CurrentOffset = MagicSize;
278 
279   uint16_t ThisVersion;
280   memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));
281   CurrentOffset += VersionFieldSize;
282 
283   uint16_t CompressionMethod;
284   memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t));
285   CurrentOffset += MethodFieldSize;
286 
287   uint32_t TotalFileSize;
288   if (ThisVersion >= 2) {
289     if (Blob.size() < V2HeaderSize)
290       return createStringError(inconvertibleErrorCode(),
291                                "Compressed bundle header size too small");
292     memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
293     CurrentOffset += FileSizeFieldSize;
294   }
295 
296   uint32_t UncompressedSize;
297   memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
298   CurrentOffset += UncompressedSizeFieldSize;
299 
300   uint64_t StoredHash;
301   memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t));
302   CurrentOffset += HashFieldSize;
303 
304   llvm::compression::Format CompressionFormat;
305   if (CompressionMethod ==
306       static_cast<uint16_t>(llvm::compression::Format::Zlib))
307     CompressionFormat = llvm::compression::Format::Zlib;
308   else if (CompressionMethod ==
309            static_cast<uint16_t>(llvm::compression::Format::Zstd))
310     CompressionFormat = llvm::compression::Format::Zstd;
311   else
312     return createStringError(inconvertibleErrorCode(),
313                              "Unknown compressing method");
314 
315   llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
316                               OffloadBundlerTimerGroup);
317   if (Verbose)
318     DecompressTimer.startTimer();
319 
320   SmallVector<uint8_t, 0> DecompressedData;
321   StringRef CompressedData = Blob.substr(CurrentOffset);
322   if (llvm::Error DecompressionError = llvm::compression::decompress(
323           CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
324           DecompressedData, UncompressedSize))
325     return createStringError(inconvertibleErrorCode(),
326                              "Could not decompress embedded file contents: " +
327                                  llvm::toString(std::move(DecompressionError)));
328 
329   if (Verbose) {
330     DecompressTimer.stopTimer();
331 
332     double DecompressionTimeSeconds =
333         DecompressTimer.getTotalTime().getWallTime();
334 
335     // Recalculate MD5 hash for integrity check.
336     llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
337                                 "Hash recalculation time",
338                                 OffloadBundlerTimerGroup);
339     HashRecalcTimer.startTimer();
340     llvm::MD5 Hash;
341     llvm::MD5::MD5Result Result;
342     Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData));
343     Hash.final(Result);
344     uint64_t RecalculatedHash = Result.low();
345     HashRecalcTimer.stopTimer();
346     bool HashMatch = (StoredHash == RecalculatedHash);
347 
348     double CompressionRate =
349         static_cast<double>(UncompressedSize) / CompressedData.size();
350     double DecompressionSpeedMBs =
351         (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
352 
353     llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
354     if (ThisVersion >= 2)
355       llvm::errs() << "Total file size (from header): "
356                    << formatWithCommas(TotalFileSize) << " bytes\n";
357     llvm::errs() << "Decompression method: "
358                  << (CompressionFormat == llvm::compression::Format::Zlib
359                          ? "zlib"
360                          : "zstd")
361                  << "\n"
362                  << "Size before decompression: "
363                  << formatWithCommas(CompressedData.size()) << " bytes\n"
364                  << "Size after decompression: "
365                  << formatWithCommas(UncompressedSize) << " bytes\n"
366                  << "Compression rate: "
367                  << llvm::format("%.2lf", CompressionRate) << "\n"
368                  << "Compression ratio: "
369                  << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
370                  << "Decompression speed: "
371                  << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
372                  << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"
373                  << "Recalculated hash: "
374                  << llvm::format_hex(RecalculatedHash, 16) << "\n"
375                  << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
376   }
377 
378   return llvm::MemoryBuffer::getMemBufferCopy(
379       llvm::toStringRef(DecompressedData));
380 }
381 
382 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
383 CompressedOffloadBundle::compress(llvm::compression::Params P,
384                                   const llvm::MemoryBuffer &Input,
385                                   bool Verbose) {
386   if (!llvm::compression::zstd::isAvailable() &&
387       !llvm::compression::zlib::isAvailable())
388     return createStringError(llvm::inconvertibleErrorCode(),
389                              "Compression not supported");
390 
391   llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
392                         OffloadBundlerTimerGroup);
393   if (Verbose)
394     HashTimer.startTimer();
395   llvm::MD5 Hash;
396   llvm::MD5::MD5Result Result;
397   Hash.update(Input.getBuffer());
398   Hash.final(Result);
399   uint64_t TruncatedHash = Result.low();
400   if (Verbose)
401     HashTimer.stopTimer();
402 
403   SmallVector<uint8_t, 0> CompressedBuffer;
404   auto BufferUint8 = llvm::ArrayRef<uint8_t>(
405       reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
406       Input.getBuffer().size());
407 
408   llvm::Timer CompressTimer("Compression Timer", "Compression time",
409                             OffloadBundlerTimerGroup);
410   if (Verbose)
411     CompressTimer.startTimer();
412   llvm::compression::compress(P, BufferUint8, CompressedBuffer);
413   if (Verbose)
414     CompressTimer.stopTimer();
415 
416   uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
417   uint32_t UncompressedSize = Input.getBuffer().size();
418   uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) +
419                            sizeof(Version) + sizeof(CompressionMethod) +
420                            sizeof(UncompressedSize) + sizeof(TruncatedHash) +
421                            CompressedBuffer.size();
422 
423   SmallVector<char, 0> FinalBuffer;
424   llvm::raw_svector_ostream OS(FinalBuffer);
425   OS << MagicNumber;
426   OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
427   OS.write(reinterpret_cast<const char *>(&CompressionMethod),
428            sizeof(CompressionMethod));
429   OS.write(reinterpret_cast<const char *>(&TotalFileSize),
430            sizeof(TotalFileSize));
431   OS.write(reinterpret_cast<const char *>(&UncompressedSize),
432            sizeof(UncompressedSize));
433   OS.write(reinterpret_cast<const char *>(&TruncatedHash),
434            sizeof(TruncatedHash));
435   OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
436            CompressedBuffer.size());
437 
438   if (Verbose) {
439     auto MethodUsed =
440         P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
441     double CompressionRate =
442         static_cast<double>(UncompressedSize) / CompressedBuffer.size();
443     double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
444     double CompressionSpeedMBs =
445         (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;
446 
447     llvm::errs() << "Compressed bundle format version: " << Version << "\n"
448                  << "Total file size (including headers): "
449                  << formatWithCommas(TotalFileSize) << " bytes\n"
450                  << "Compression method used: " << MethodUsed << "\n"
451                  << "Compression level: " << P.level << "\n"
452                  << "Binary size before compression: "
453                  << formatWithCommas(UncompressedSize) << " bytes\n"
454                  << "Binary size after compression: "
455                  << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
456                  << "Compression rate: "
457                  << llvm::format("%.2lf", CompressionRate) << "\n"
458                  << "Compression ratio: "
459                  << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
460                  << "Compression speed: "
461                  << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
462                  << "Truncated MD5 hash: "
463                  << llvm::format_hex(TruncatedHash, 16) << "\n";
464   }
465   return llvm::MemoryBuffer::getMemBufferCopy(
466       llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
467 }
468