1 //===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------===// 8 9 #include "llvm/Object/OffloadBundle.h" 10 #include "llvm/BinaryFormat/Magic.h" 11 #include "llvm/IR/Module.h" 12 #include "llvm/IRReader/IRReader.h" 13 #include "llvm/MC/StringTableBuilder.h" 14 #include "llvm/Object/Archive.h" 15 #include "llvm/Object/Binary.h" 16 #include "llvm/Object/COFF.h" 17 #include "llvm/Object/ELFObjectFile.h" 18 #include "llvm/Object/Error.h" 19 #include "llvm/Object/IRObjectFile.h" 20 #include "llvm/Object/ObjectFile.h" 21 #include "llvm/Support/BinaryStreamReader.h" 22 #include "llvm/Support/SourceMgr.h" 23 #include "llvm/Support/Timer.h" 24 25 using namespace llvm; 26 using namespace llvm::object; 27 28 static llvm::TimerGroup 29 OffloadBundlerTimerGroup("Offload Bundler Timer Group", 30 "Timer group for offload bundler"); 31 32 // Extract an Offload bundle (usually a Offload Bundle) from a fat_bin 33 // section 34 Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset, 35 StringRef FileName, 36 SmallVectorImpl<OffloadBundleFatBin> &Bundles) { 37 38 size_t Offset = 0; 39 size_t NextbundleStart = 0; 40 41 // There could be multiple offloading bundles stored at this section. 42 while (NextbundleStart != StringRef::npos) { 43 std::unique_ptr<MemoryBuffer> Buffer = 44 MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "", 45 /*RequiresNullTerminator=*/false); 46 47 // Create the FatBinBindle object. This will also create the Bundle Entry 48 // list info. 49 auto FatBundleOrErr = 50 OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName); 51 if (!FatBundleOrErr) 52 return FatBundleOrErr.takeError(); 53 54 // Add current Bundle to list. 55 Bundles.emplace_back(std::move(**FatBundleOrErr)); 56 57 // Find the next bundle by searching for the magic string 58 StringRef Str = Buffer->getBuffer(); 59 NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24); 60 61 if (NextbundleStart != StringRef::npos) 62 Offset += NextbundleStart; 63 } 64 65 return Error::success(); 66 } 67 68 Error OffloadBundleFatBin::readEntries(StringRef Buffer, 69 uint64_t SectionOffset) { 70 uint64_t NumOfEntries = 0; 71 72 BinaryStreamReader Reader(Buffer, llvm::endianness::little); 73 74 // Read the Magic String first. 75 StringRef Magic; 76 if (auto EC = Reader.readFixedString(Magic, 24)) 77 return errorCodeToError(object_error::parse_failed); 78 79 // Read the number of Code Objects (Entries) in the current Bundle. 80 if (auto EC = Reader.readInteger(NumOfEntries)) 81 return errorCodeToError(object_error::parse_failed); 82 83 NumberOfEntries = NumOfEntries; 84 85 // For each Bundle Entry (code object) 86 for (uint64_t I = 0; I < NumOfEntries; I++) { 87 uint64_t EntrySize; 88 uint64_t EntryOffset; 89 uint64_t EntryIDSize; 90 StringRef EntryID; 91 92 if (auto EC = Reader.readInteger(EntryOffset)) 93 return errorCodeToError(object_error::parse_failed); 94 95 if (auto EC = Reader.readInteger(EntrySize)) 96 return errorCodeToError(object_error::parse_failed); 97 98 if (auto EC = Reader.readInteger(EntryIDSize)) 99 return errorCodeToError(object_error::parse_failed); 100 101 if (auto EC = Reader.readFixedString(EntryID, EntryIDSize)) 102 return errorCodeToError(object_error::parse_failed); 103 104 auto Entry = std::make_unique<OffloadBundleEntry>( 105 EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID); 106 107 Entries.push_back(*Entry); 108 } 109 110 return Error::success(); 111 } 112 113 Expected<std::unique_ptr<OffloadBundleFatBin>> 114 OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, 115 StringRef FileName) { 116 if (Buf.getBufferSize() < 24) 117 return errorCodeToError(object_error::parse_failed); 118 119 // Check for magic bytes. 120 if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) 121 return errorCodeToError(object_error::parse_failed); 122 123 OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName); 124 125 // Read the Bundle Entries 126 Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset); 127 if (Err) 128 return errorCodeToError(object_error::parse_failed); 129 130 return std::unique_ptr<OffloadBundleFatBin>(TheBundle); 131 } 132 133 Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) { 134 // This will extract all entries in the Bundle 135 for (OffloadBundleEntry &Entry : Entries) { 136 137 if (Entry.Size == 0) 138 continue; 139 140 // create output file name. Which should be 141 // <fileName>-offset<Offset>-size<Size>.co" 142 std::string Str = getFileName().str() + "-offset" + itostr(Entry.Offset) + 143 "-size" + itostr(Entry.Size) + ".co"; 144 if (Error Err = object::extractCodeObject(Source, Entry.Offset, Entry.Size, 145 StringRef(Str))) 146 return Err; 147 } 148 149 return Error::success(); 150 } 151 152 Error object::extractOffloadBundleFatBinary( 153 const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) { 154 assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type"); 155 156 // Iterate through Sections until we find an offload_bundle section. 157 for (SectionRef Sec : Obj.sections()) { 158 Expected<StringRef> Buffer = Sec.getContents(); 159 if (!Buffer) 160 return Buffer.takeError(); 161 162 // If it does not start with the reserved suffix, just skip this section. 163 if ((llvm::identify_magic(*Buffer) == llvm::file_magic::offload_bundle) || 164 (llvm::identify_magic(*Buffer) == 165 llvm::file_magic::offload_bundle_compressed)) { 166 167 uint64_t SectionOffset = 0; 168 if (Obj.isELF()) { 169 SectionOffset = ELFSectionRef(Sec).getOffset(); 170 } else if (Obj.isCOFF()) // TODO: add COFF Support 171 return createStringError(object_error::parse_failed, 172 "COFF object files not supported.\n"); 173 174 MemoryBufferRef Contents(*Buffer, Obj.getFileName()); 175 176 if (llvm::identify_magic(*Buffer) == 177 llvm::file_magic::offload_bundle_compressed) { 178 // Decompress the input if necessary. 179 Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr = 180 CompressedOffloadBundle::decompress(Contents, false); 181 182 if (!DecompressedBufferOrErr) 183 return createStringError( 184 inconvertibleErrorCode(), 185 "Failed to decompress input: " + 186 llvm::toString(DecompressedBufferOrErr.takeError())); 187 188 MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr; 189 if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset, 190 Obj.getFileName(), Bundles)) 191 return Err; 192 } else { 193 if (Error Err = extractOffloadBundle(Contents, SectionOffset, 194 Obj.getFileName(), Bundles)) 195 return Err; 196 } 197 } 198 } 199 return Error::success(); 200 } 201 202 Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset, 203 int64_t Size, StringRef OutputFileName) { 204 Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr = 205 FileOutputBuffer::create(OutputFileName, Size); 206 207 if (!BufferOrErr) 208 return BufferOrErr.takeError(); 209 210 Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef(); 211 if (Error Err = InputBuffOrErr.takeError()) 212 return Err; 213 214 std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr); 215 std::copy(InputBuffOrErr->getBufferStart() + Offset, 216 InputBuffOrErr->getBufferStart() + Offset + Size, 217 Buf->getBufferStart()); 218 if (Error E = Buf->commit()) 219 return E; 220 221 return Error::success(); 222 } 223 224 // given a file name, offset, and size, extract data into a code object file, 225 // into file <SourceFile>-offset<Offset>-size<Size>.co 226 Error object::extractOffloadBundleByURI(StringRef URIstr) { 227 // create a URI object 228 Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr( 229 OffloadBundleURI::createOffloadBundleURI(URIstr, FILE_URI)); 230 if (!UriOrErr) 231 return UriOrErr.takeError(); 232 233 OffloadBundleURI &Uri = **UriOrErr; 234 std::string OutputFile = Uri.FileName.str(); 235 OutputFile += 236 "-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co"; 237 238 // Create an ObjectFile object from uri.file_uri 239 auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName); 240 if (!ObjOrErr) 241 return ObjOrErr.takeError(); 242 243 auto Obj = ObjOrErr->getBinary(); 244 if (Error Err = 245 object::extractCodeObject(*Obj, Uri.Offset, Uri.Size, OutputFile)) 246 return Err; 247 248 return Error::success(); 249 } 250 251 // Utility function to format numbers with commas 252 static std::string formatWithCommas(unsigned long long Value) { 253 std::string Num = std::to_string(Value); 254 int InsertPosition = Num.length() - 3; 255 while (InsertPosition > 0) { 256 Num.insert(InsertPosition, ","); 257 InsertPosition -= 3; 258 } 259 return Num; 260 } 261 262 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> 263 CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, 264 bool Verbose) { 265 StringRef Blob = Input.getBuffer(); 266 267 if (Blob.size() < V1HeaderSize) 268 return llvm::MemoryBuffer::getMemBufferCopy(Blob); 269 270 if (llvm::identify_magic(Blob) != 271 llvm::file_magic::offload_bundle_compressed) { 272 if (Verbose) 273 llvm::errs() << "Uncompressed bundle.\n"; 274 return llvm::MemoryBuffer::getMemBufferCopy(Blob); 275 } 276 277 size_t CurrentOffset = MagicSize; 278 279 uint16_t ThisVersion; 280 memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t)); 281 CurrentOffset += VersionFieldSize; 282 283 uint16_t CompressionMethod; 284 memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t)); 285 CurrentOffset += MethodFieldSize; 286 287 uint32_t TotalFileSize; 288 if (ThisVersion >= 2) { 289 if (Blob.size() < V2HeaderSize) 290 return createStringError(inconvertibleErrorCode(), 291 "Compressed bundle header size too small"); 292 memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); 293 CurrentOffset += FileSizeFieldSize; 294 } 295 296 uint32_t UncompressedSize; 297 memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); 298 CurrentOffset += UncompressedSizeFieldSize; 299 300 uint64_t StoredHash; 301 memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t)); 302 CurrentOffset += HashFieldSize; 303 304 llvm::compression::Format CompressionFormat; 305 if (CompressionMethod == 306 static_cast<uint16_t>(llvm::compression::Format::Zlib)) 307 CompressionFormat = llvm::compression::Format::Zlib; 308 else if (CompressionMethod == 309 static_cast<uint16_t>(llvm::compression::Format::Zstd)) 310 CompressionFormat = llvm::compression::Format::Zstd; 311 else 312 return createStringError(inconvertibleErrorCode(), 313 "Unknown compressing method"); 314 315 llvm::Timer DecompressTimer("Decompression Timer", "Decompression time", 316 OffloadBundlerTimerGroup); 317 if (Verbose) 318 DecompressTimer.startTimer(); 319 320 SmallVector<uint8_t, 0> DecompressedData; 321 StringRef CompressedData = Blob.substr(CurrentOffset); 322 if (llvm::Error DecompressionError = llvm::compression::decompress( 323 CompressionFormat, llvm::arrayRefFromStringRef(CompressedData), 324 DecompressedData, UncompressedSize)) 325 return createStringError(inconvertibleErrorCode(), 326 "Could not decompress embedded file contents: " + 327 llvm::toString(std::move(DecompressionError))); 328 329 if (Verbose) { 330 DecompressTimer.stopTimer(); 331 332 double DecompressionTimeSeconds = 333 DecompressTimer.getTotalTime().getWallTime(); 334 335 // Recalculate MD5 hash for integrity check. 336 llvm::Timer HashRecalcTimer("Hash Recalculation Timer", 337 "Hash recalculation time", 338 OffloadBundlerTimerGroup); 339 HashRecalcTimer.startTimer(); 340 llvm::MD5 Hash; 341 llvm::MD5::MD5Result Result; 342 Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData)); 343 Hash.final(Result); 344 uint64_t RecalculatedHash = Result.low(); 345 HashRecalcTimer.stopTimer(); 346 bool HashMatch = (StoredHash == RecalculatedHash); 347 348 double CompressionRate = 349 static_cast<double>(UncompressedSize) / CompressedData.size(); 350 double DecompressionSpeedMBs = 351 (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds; 352 353 llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n"; 354 if (ThisVersion >= 2) 355 llvm::errs() << "Total file size (from header): " 356 << formatWithCommas(TotalFileSize) << " bytes\n"; 357 llvm::errs() << "Decompression method: " 358 << (CompressionFormat == llvm::compression::Format::Zlib 359 ? "zlib" 360 : "zstd") 361 << "\n" 362 << "Size before decompression: " 363 << formatWithCommas(CompressedData.size()) << " bytes\n" 364 << "Size after decompression: " 365 << formatWithCommas(UncompressedSize) << " bytes\n" 366 << "Compression rate: " 367 << llvm::format("%.2lf", CompressionRate) << "\n" 368 << "Compression ratio: " 369 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" 370 << "Decompression speed: " 371 << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n" 372 << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n" 373 << "Recalculated hash: " 374 << llvm::format_hex(RecalculatedHash, 16) << "\n" 375 << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n"; 376 } 377 378 return llvm::MemoryBuffer::getMemBufferCopy( 379 llvm::toStringRef(DecompressedData)); 380 } 381 382 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> 383 CompressedOffloadBundle::compress(llvm::compression::Params P, 384 const llvm::MemoryBuffer &Input, 385 bool Verbose) { 386 if (!llvm::compression::zstd::isAvailable() && 387 !llvm::compression::zlib::isAvailable()) 388 return createStringError(llvm::inconvertibleErrorCode(), 389 "Compression not supported"); 390 391 llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time", 392 OffloadBundlerTimerGroup); 393 if (Verbose) 394 HashTimer.startTimer(); 395 llvm::MD5 Hash; 396 llvm::MD5::MD5Result Result; 397 Hash.update(Input.getBuffer()); 398 Hash.final(Result); 399 uint64_t TruncatedHash = Result.low(); 400 if (Verbose) 401 HashTimer.stopTimer(); 402 403 SmallVector<uint8_t, 0> CompressedBuffer; 404 auto BufferUint8 = llvm::ArrayRef<uint8_t>( 405 reinterpret_cast<const uint8_t *>(Input.getBuffer().data()), 406 Input.getBuffer().size()); 407 408 llvm::Timer CompressTimer("Compression Timer", "Compression time", 409 OffloadBundlerTimerGroup); 410 if (Verbose) 411 CompressTimer.startTimer(); 412 llvm::compression::compress(P, BufferUint8, CompressedBuffer); 413 if (Verbose) 414 CompressTimer.stopTimer(); 415 416 uint16_t CompressionMethod = static_cast<uint16_t>(P.format); 417 uint32_t UncompressedSize = Input.getBuffer().size(); 418 uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) + 419 sizeof(Version) + sizeof(CompressionMethod) + 420 sizeof(UncompressedSize) + sizeof(TruncatedHash) + 421 CompressedBuffer.size(); 422 423 SmallVector<char, 0> FinalBuffer; 424 llvm::raw_svector_ostream OS(FinalBuffer); 425 OS << MagicNumber; 426 OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version)); 427 OS.write(reinterpret_cast<const char *>(&CompressionMethod), 428 sizeof(CompressionMethod)); 429 OS.write(reinterpret_cast<const char *>(&TotalFileSize), 430 sizeof(TotalFileSize)); 431 OS.write(reinterpret_cast<const char *>(&UncompressedSize), 432 sizeof(UncompressedSize)); 433 OS.write(reinterpret_cast<const char *>(&TruncatedHash), 434 sizeof(TruncatedHash)); 435 OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()), 436 CompressedBuffer.size()); 437 438 if (Verbose) { 439 auto MethodUsed = 440 P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib"; 441 double CompressionRate = 442 static_cast<double>(UncompressedSize) / CompressedBuffer.size(); 443 double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); 444 double CompressionSpeedMBs = 445 (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds; 446 447 llvm::errs() << "Compressed bundle format version: " << Version << "\n" 448 << "Total file size (including headers): " 449 << formatWithCommas(TotalFileSize) << " bytes\n" 450 << "Compression method used: " << MethodUsed << "\n" 451 << "Compression level: " << P.level << "\n" 452 << "Binary size before compression: " 453 << formatWithCommas(UncompressedSize) << " bytes\n" 454 << "Binary size after compression: " 455 << formatWithCommas(CompressedBuffer.size()) << " bytes\n" 456 << "Compression rate: " 457 << llvm::format("%.2lf", CompressionRate) << "\n" 458 << "Compression ratio: " 459 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" 460 << "Compression speed: " 461 << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n" 462 << "Truncated MD5 hash: " 463 << llvm::format_hex(TruncatedHash, 16) << "\n"; 464 } 465 return llvm::MemoryBuffer::getMemBufferCopy( 466 llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); 467 } 468