1 /* 2 * Copyright (c) 2017-2020, Facebook, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under both the BSD-style license (found in the 6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7 * in the COPYING file in the root directory of this source tree). 8 * You may select, at your option, one of the above-listed licenses. 9 */ 10 11 /// Zstandard educational decoder implementation 12 /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md 13 14 #include <stdint.h> // uint8_t, etc. 15 #include <stdlib.h> // malloc, free, exit 16 #include <stdio.h> // fprintf 17 #include <string.h> // memset, memcpy 18 #include "zstd_decompress.h" 19 20 21 /******* IMPORTANT CONSTANTS *********************************************/ 22 23 // Zstandard frame 24 // "Magic_Number 25 // 4 Bytes, little-endian format. Value : 0xFD2FB528" 26 #define ZSTD_MAGIC_NUMBER 0xFD2FB528U 27 28 // The size of `Block_Content` is limited by `Block_Maximum_Size`, 29 #define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024) 30 31 // literal blocks can't be larger than their block 32 #define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX 33 34 35 /******* UTILITY MACROS AND TYPES *********************************************/ 36 #define MAX(a, b) ((a) > (b) ? (a) : (b)) 37 #define MIN(a, b) ((a) < (b) ? (a) : (b)) 38 39 #if defined(ZDEC_NO_MESSAGE) 40 #define MESSAGE(...) 41 #else 42 #define MESSAGE(...) fprintf(stderr, "" __VA_ARGS__) 43 #endif 44 45 /// This decoder calls exit(1) when it encounters an error, however a production 46 /// library should propagate error codes 47 #define ERROR(s) \ 48 do { \ 49 MESSAGE("Error: %s\n", s); \ 50 exit(1); \ 51 } while (0) 52 #define INP_SIZE() \ 53 ERROR("Input buffer smaller than it should be or input is " \ 54 "corrupted") 55 #define OUT_SIZE() ERROR("Output buffer too small for output") 56 #define CORRUPTION() ERROR("Corruption detected while decompressing") 57 #define BAD_ALLOC() ERROR("Memory allocation error") 58 #define IMPOSSIBLE() ERROR("An impossibility has occurred") 59 60 typedef uint8_t u8; 61 typedef uint16_t u16; 62 typedef uint32_t u32; 63 typedef uint64_t u64; 64 65 typedef int8_t i8; 66 typedef int16_t i16; 67 typedef int32_t i32; 68 typedef int64_t i64; 69 /******* END UTILITY MACROS AND TYPES *****************************************/ 70 71 /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/ 72 /// The implementations for these functions can be found at the bottom of this 73 /// file. They implement low-level functionality needed for the higher level 74 /// decompression functions. 75 76 /*** IO STREAM OPERATIONS *************/ 77 78 /// ostream_t/istream_t are used to wrap the pointers/length data passed into 79 /// ZSTD_decompress, so that all IO operations are safely bounds checked 80 /// They are written/read forward, and reads are treated as little-endian 81 /// They should be used opaquely to ensure safety 82 typedef struct { 83 u8 *ptr; 84 size_t len; 85 } ostream_t; 86 87 typedef struct { 88 const u8 *ptr; 89 size_t len; 90 91 // Input often reads a few bits at a time, so maintain an internal offset 92 int bit_offset; 93 } istream_t; 94 95 /// The following two functions are the only ones that allow the istream to be 96 /// non-byte aligned 97 98 /// Reads `num` bits from a bitstream, and updates the internal offset 99 static inline u64 IO_read_bits(istream_t *const in, const int num_bits); 100 /// Backs-up the stream by `num` bits so they can be read again 101 static inline void IO_rewind_bits(istream_t *const in, const int num_bits); 102 /// If the remaining bits in a byte will be unused, advance to the end of the 103 /// byte 104 static inline void IO_align_stream(istream_t *const in); 105 106 /// Write the given byte into the output stream 107 static inline void IO_write_byte(ostream_t *const out, u8 symb); 108 109 /// Returns the number of bytes left to be read in this stream. The stream must 110 /// be byte aligned. 111 static inline size_t IO_istream_len(const istream_t *const in); 112 113 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that 114 /// was skipped. The stream must be byte aligned. 115 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len); 116 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that 117 /// was skipped so it can be written to. 118 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len); 119 120 /// Advance the inner state by `len` bytes. The stream must be byte aligned. 121 static inline void IO_advance_input(istream_t *const in, size_t len); 122 123 /// Returns an `ostream_t` constructed from the given pointer and length. 124 static inline ostream_t IO_make_ostream(u8 *out, size_t len); 125 /// Returns an `istream_t` constructed from the given pointer and length. 126 static inline istream_t IO_make_istream(const u8 *in, size_t len); 127 128 /// Returns an `istream_t` with the same base as `in`, and length `len`. 129 /// Then, advance `in` to account for the consumed bytes. 130 /// `in` must be byte aligned. 131 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len); 132 /*** END IO STREAM OPERATIONS *********/ 133 134 /*** BITSTREAM OPERATIONS *************/ 135 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits, 136 /// and return them interpreted as a little-endian unsigned integer. 137 static inline u64 read_bits_LE(const u8 *src, const int num_bits, 138 const size_t offset); 139 140 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so 141 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from 142 /// `src + offset`. If the offset becomes negative, the extra bits at the 143 /// bottom are filled in with `0` bits instead of reading from before `src`. 144 static inline u64 STREAM_read_bits(const u8 *src, const int bits, 145 i64 *const offset); 146 /*** END BITSTREAM OPERATIONS *********/ 147 148 /*** BIT COUNTING OPERATIONS **********/ 149 /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0` 150 static inline int highest_set_bit(const u64 num); 151 /*** END BIT COUNTING OPERATIONS ******/ 152 153 /*** HUFFMAN PRIMITIVES ***************/ 154 // Table decode method uses exponential memory, so we need to limit depth 155 #define HUF_MAX_BITS (16) 156 157 // Limit the maximum number of symbols to 256 so we can store a symbol in a byte 158 #define HUF_MAX_SYMBS (256) 159 160 /// Structure containing all tables necessary for efficient Huffman decoding 161 typedef struct { 162 u8 *symbols; 163 u8 *num_bits; 164 int max_bits; 165 } HUF_dtable; 166 167 /// Decode a single symbol and read in enough bits to refresh the state 168 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, 169 u16 *const state, const u8 *const src, 170 i64 *const offset); 171 /// Read in a full state's worth of bits to initialize it 172 static inline void HUF_init_state(const HUF_dtable *const dtable, 173 u16 *const state, const u8 *const src, 174 i64 *const offset); 175 176 /// Decompresses a single Huffman stream, returns the number of bytes decoded. 177 /// `src_len` must be the exact length of the Huffman-coded block. 178 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, 179 ostream_t *const out, istream_t *const in); 180 /// Same as previous but decodes 4 streams, formatted as in the Zstandard 181 /// specification. 182 /// `src_len` must be the exact length of the Huffman-coded block. 183 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, 184 ostream_t *const out, istream_t *const in); 185 186 /// Initialize a Huffman decoding table using the table of bit counts provided 187 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, 188 const int num_symbs); 189 /// Initialize a Huffman decoding table using the table of weights provided 190 /// Weights follow the definition provided in the Zstandard specification 191 static void HUF_init_dtable_usingweights(HUF_dtable *const table, 192 const u8 *const weights, 193 const int num_symbs); 194 195 /// Free the malloc'ed parts of a decoding table 196 static void HUF_free_dtable(HUF_dtable *const dtable); 197 /*** END HUFFMAN PRIMITIVES ***********/ 198 199 /*** FSE PRIMITIVES *******************/ 200 /// For more description of FSE see 201 /// https://github.com/Cyan4973/FiniteStateEntropy/ 202 203 // FSE table decoding uses exponential memory, so limit the maximum accuracy 204 #define FSE_MAX_ACCURACY_LOG (15) 205 // Limit the maximum number of symbols so they can be stored in a single byte 206 #define FSE_MAX_SYMBS (256) 207 208 /// The tables needed to decode FSE encoded streams 209 typedef struct { 210 u8 *symbols; 211 u8 *num_bits; 212 u16 *new_state_base; 213 int accuracy_log; 214 } FSE_dtable; 215 216 /// Return the symbol for the current state 217 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, 218 const u16 state); 219 /// Read the number of bits necessary to update state, update, and shift offset 220 /// back to reflect the bits read 221 static inline void FSE_update_state(const FSE_dtable *const dtable, 222 u16 *const state, const u8 *const src, 223 i64 *const offset); 224 225 /// Combine peek and update: decode a symbol and update the state 226 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, 227 u16 *const state, const u8 *const src, 228 i64 *const offset); 229 230 /// Read bits from the stream to initialize the state and shift offset back 231 static inline void FSE_init_state(const FSE_dtable *const dtable, 232 u16 *const state, const u8 *const src, 233 i64 *const offset); 234 235 /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights) 236 /// using an FSE decoding table. `src_len` must be the exact length of the 237 /// block. 238 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, 239 ostream_t *const out, 240 istream_t *const in); 241 242 /// Initialize a decoding table using normalized frequencies. 243 static void FSE_init_dtable(FSE_dtable *const dtable, 244 const i16 *const norm_freqs, const int num_symbs, 245 const int accuracy_log); 246 247 /// Decode an FSE header as defined in the Zstandard format specification and 248 /// use the decoded frequencies to initialize a decoding table. 249 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, 250 const int max_accuracy_log); 251 252 /// Initialize an FSE table that will always return the same symbol and consume 253 /// 0 bits per symbol, to be used for RLE mode in sequence commands 254 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb); 255 256 /// Free the malloc'ed parts of a decoding table 257 static void FSE_free_dtable(FSE_dtable *const dtable); 258 /*** END FSE PRIMITIVES ***************/ 259 260 /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/ 261 262 /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/ 263 264 /// A small structure that can be reused in various places that need to access 265 /// frame header information 266 typedef struct { 267 // The size of window that we need to be able to contiguously store for 268 // references 269 size_t window_size; 270 // The total output size of this compressed frame 271 size_t frame_content_size; 272 273 // The dictionary id if this frame uses one 274 u32 dictionary_id; 275 276 // Whether or not the content of this frame has a checksum 277 int content_checksum_flag; 278 // Whether or not the output for this frame is in a single segment 279 int single_segment_flag; 280 } frame_header_t; 281 282 /// The context needed to decode blocks in a frame 283 typedef struct { 284 frame_header_t header; 285 286 // The total amount of data available for backreferences, to determine if an 287 // offset too large to be correct 288 size_t current_total_output; 289 290 const u8 *dict_content; 291 size_t dict_content_len; 292 293 // Entropy encoding tables so they can be repeated by future blocks instead 294 // of retransmitting 295 HUF_dtable literals_dtable; 296 FSE_dtable ll_dtable; 297 FSE_dtable ml_dtable; 298 FSE_dtable of_dtable; 299 300 // The last 3 offsets for the special "repeat offsets". 301 u64 previous_offsets[3]; 302 } frame_context_t; 303 304 /// The decoded contents of a dictionary so that it doesn't have to be repeated 305 /// for each frame that uses it 306 struct dictionary_s { 307 // Entropy tables 308 HUF_dtable literals_dtable; 309 FSE_dtable ll_dtable; 310 FSE_dtable ml_dtable; 311 FSE_dtable of_dtable; 312 313 // Raw content for backreferences 314 u8 *content; 315 size_t content_size; 316 317 // Offset history to prepopulate the frame's history 318 u64 previous_offsets[3]; 319 320 u32 dictionary_id; 321 }; 322 323 /// A tuple containing the parts necessary to decode and execute a ZSTD sequence 324 /// command 325 typedef struct { 326 u32 literal_length; 327 u32 match_length; 328 u32 offset; 329 } sequence_command_t; 330 331 /// The decoder works top-down, starting at the high level like Zstd frames, and 332 /// working down to lower more technical levels such as blocks, literals, and 333 /// sequences. The high-level functions roughly follow the outline of the 334 /// format specification: 335 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md 336 337 /// Before the implementation of each high-level function declared here, the 338 /// prototypes for their helper functions are defined and explained 339 340 /// Decode a single Zstd frame, or error if the input is not a valid frame. 341 /// Accepts a dict argument, which may be NULL indicating no dictionary. 342 /// See 343 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation 344 static void decode_frame(ostream_t *const out, istream_t *const in, 345 const dictionary_t *const dict); 346 347 // Decode data in a compressed block 348 static void decompress_block(frame_context_t *const ctx, ostream_t *const out, 349 istream_t *const in); 350 351 // Decode the literals section of a block 352 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, 353 u8 **const literals); 354 355 // Decode the sequences part of a block 356 static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in, 357 sequence_command_t **const sequences); 358 359 // Execute the decoded sequences on the literals block 360 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, 361 const u8 *const literals, 362 const size_t literals_len, 363 const sequence_command_t *const sequences, 364 const size_t num_sequences); 365 366 // Copies literals and returns the total literal length that was copied 367 static u32 copy_literals(const size_t seq, istream_t *litstream, 368 ostream_t *const out); 369 370 // Given an offset code from a sequence command (either an actual offset value 371 // or an index for previous offset), computes the correct offset and updates 372 // the offset history 373 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist); 374 375 // Given an offset, match length, and total output, as well as the frame 376 // context for the dictionary, determines if the dictionary is used and 377 // executes the copy operation 378 static void execute_match_copy(frame_context_t *const ctx, size_t offset, 379 size_t match_length, size_t total_output, 380 ostream_t *const out); 381 382 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ 383 384 size_t ZSTD_decompress(void *const dst, const size_t dst_len, 385 const void *const src, const size_t src_len) { 386 dictionary_t* const uninit_dict = create_dictionary(); 387 size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src, 388 src_len, uninit_dict); 389 free_dictionary(uninit_dict); 390 return decomp_size; 391 } 392 393 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, 394 const void *const src, const size_t src_len, 395 dictionary_t* parsed_dict) { 396 397 istream_t in = IO_make_istream(src, src_len); 398 ostream_t out = IO_make_ostream(dst, dst_len); 399 400 // "A content compressed by Zstandard is transformed into a Zstandard frame. 401 // Multiple frames can be appended into a single file or stream. A frame is 402 // totally independent, has a defined beginning and end, and a set of 403 // parameters which tells the decoder how to decompress it." 404 405 /* this decoder assumes decompression of a single frame */ 406 decode_frame(&out, &in, parsed_dict); 407 408 return (size_t)(out.ptr - (u8 *)dst); 409 } 410 411 /******* FRAME DECODING ******************************************************/ 412 413 static void decode_data_frame(ostream_t *const out, istream_t *const in, 414 const dictionary_t *const dict); 415 static void init_frame_context(frame_context_t *const context, 416 istream_t *const in, 417 const dictionary_t *const dict); 418 static void free_frame_context(frame_context_t *const context); 419 static void parse_frame_header(frame_header_t *const header, 420 istream_t *const in); 421 static void frame_context_apply_dict(frame_context_t *const ctx, 422 const dictionary_t *const dict); 423 424 static void decompress_data(frame_context_t *const ctx, ostream_t *const out, 425 istream_t *const in); 426 427 static void decode_frame(ostream_t *const out, istream_t *const in, 428 const dictionary_t *const dict) { 429 const u32 magic_number = (u32)IO_read_bits(in, 32); 430 if (magic_number == ZSTD_MAGIC_NUMBER) { 431 // ZSTD frame 432 decode_data_frame(out, in, dict); 433 434 return; 435 } 436 437 // not a real frame or a skippable frame 438 ERROR("Tried to decode non-ZSTD frame"); 439 } 440 441 /// Decode a frame that contains compressed data. Not all frames do as there 442 /// are skippable frames. 443 /// See 444 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format 445 static void decode_data_frame(ostream_t *const out, istream_t *const in, 446 const dictionary_t *const dict) { 447 frame_context_t ctx; 448 449 // Initialize the context that needs to be carried from block to block 450 init_frame_context(&ctx, in, dict); 451 452 if (ctx.header.frame_content_size != 0 && 453 ctx.header.frame_content_size > out->len) { 454 OUT_SIZE(); 455 } 456 457 decompress_data(&ctx, out, in); 458 459 free_frame_context(&ctx); 460 } 461 462 /// Takes the information provided in the header and dictionary, and initializes 463 /// the context for this frame 464 static void init_frame_context(frame_context_t *const context, 465 istream_t *const in, 466 const dictionary_t *const dict) { 467 // Most fields in context are correct when initialized to 0 468 memset(context, 0, sizeof(frame_context_t)); 469 470 // Parse data from the frame header 471 parse_frame_header(&context->header, in); 472 473 // Set up the offset history for the repeat offset commands 474 context->previous_offsets[0] = 1; 475 context->previous_offsets[1] = 4; 476 context->previous_offsets[2] = 8; 477 478 // Apply details from the dict if it exists 479 frame_context_apply_dict(context, dict); 480 } 481 482 static void free_frame_context(frame_context_t *const context) { 483 HUF_free_dtable(&context->literals_dtable); 484 485 FSE_free_dtable(&context->ll_dtable); 486 FSE_free_dtable(&context->ml_dtable); 487 FSE_free_dtable(&context->of_dtable); 488 489 memset(context, 0, sizeof(frame_context_t)); 490 } 491 492 static void parse_frame_header(frame_header_t *const header, 493 istream_t *const in) { 494 // "The first header's byte is called the Frame_Header_Descriptor. It tells 495 // which other fields are present. Decoding this byte is enough to tell the 496 // size of Frame_Header. 497 // 498 // Bit number Field name 499 // 7-6 Frame_Content_Size_flag 500 // 5 Single_Segment_flag 501 // 4 Unused_bit 502 // 3 Reserved_bit 503 // 2 Content_Checksum_flag 504 // 1-0 Dictionary_ID_flag" 505 const u8 descriptor = (u8)IO_read_bits(in, 8); 506 507 // decode frame header descriptor into flags 508 const u8 frame_content_size_flag = descriptor >> 6; 509 const u8 single_segment_flag = (descriptor >> 5) & 1; 510 const u8 reserved_bit = (descriptor >> 3) & 1; 511 const u8 content_checksum_flag = (descriptor >> 2) & 1; 512 const u8 dictionary_id_flag = descriptor & 3; 513 514 if (reserved_bit != 0) { 515 CORRUPTION(); 516 } 517 518 header->single_segment_flag = single_segment_flag; 519 header->content_checksum_flag = content_checksum_flag; 520 521 // decode window size 522 if (!single_segment_flag) { 523 // "Provides guarantees on maximum back-reference distance that will be 524 // used within compressed data. This information is important for 525 // decoders to allocate enough memory. 526 // 527 // Bit numbers 7-3 2-0 528 // Field name Exponent Mantissa" 529 u8 window_descriptor = (u8)IO_read_bits(in, 8); 530 u8 exponent = window_descriptor >> 3; 531 u8 mantissa = window_descriptor & 7; 532 533 // Use the algorithm from the specification to compute window size 534 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor 535 size_t window_base = (size_t)1 << (10 + exponent); 536 size_t window_add = (window_base / 8) * mantissa; 537 header->window_size = window_base + window_add; 538 } 539 540 // decode dictionary id if it exists 541 if (dictionary_id_flag) { 542 // "This is a variable size field, which contains the ID of the 543 // dictionary required to properly decode the frame. Note that this 544 // field is optional. When it's not present, it's up to the caller to 545 // make sure it uses the correct dictionary. Format is little-endian." 546 const int bytes_array[] = {0, 1, 2, 4}; 547 const int bytes = bytes_array[dictionary_id_flag]; 548 549 header->dictionary_id = (u32)IO_read_bits(in, bytes * 8); 550 } else { 551 header->dictionary_id = 0; 552 } 553 554 // decode frame content size if it exists 555 if (single_segment_flag || frame_content_size_flag) { 556 // "This is the original (uncompressed) size. This information is 557 // optional. The Field_Size is provided according to value of 558 // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not 559 // present), 1, 2, 4 or 8 bytes. Format is little-endian." 560 // 561 // if frame_content_size_flag == 0 but single_segment_flag is set, we 562 // still have a 1 byte field 563 const int bytes_array[] = {1, 2, 4, 8}; 564 const int bytes = bytes_array[frame_content_size_flag]; 565 566 header->frame_content_size = IO_read_bits(in, bytes * 8); 567 if (bytes == 2) { 568 // "When Field_Size is 2, the offset of 256 is added." 569 header->frame_content_size += 256; 570 } 571 } else { 572 header->frame_content_size = 0; 573 } 574 575 if (single_segment_flag) { 576 // "The Window_Descriptor byte is optional. It is absent when 577 // Single_Segment_flag is set. In this case, the maximum back-reference 578 // distance is the content size itself, which can be any value from 1 to 579 // 2^64-1 bytes (16 EB)." 580 header->window_size = header->frame_content_size; 581 } 582 } 583 584 /// Decompress the data from a frame block by block 585 static void decompress_data(frame_context_t *const ctx, ostream_t *const out, 586 istream_t *const in) { 587 // "A frame encapsulates one or multiple blocks. Each block can be 588 // compressed or not, and has a guaranteed maximum content size, which 589 // depends on frame parameters. Unlike frames, each block depends on 590 // previous blocks for proper decoding. However, each block can be 591 // decompressed without waiting for its successor, allowing streaming 592 // operations." 593 int last_block = 0; 594 do { 595 // "Last_Block 596 // 597 // The lowest bit signals if this block is the last one. Frame ends 598 // right after this block. 599 // 600 // Block_Type and Block_Size 601 // 602 // The next 2 bits represent the Block_Type, while the remaining 21 bits 603 // represent the Block_Size. Format is little-endian." 604 last_block = (int)IO_read_bits(in, 1); 605 const int block_type = (int)IO_read_bits(in, 2); 606 const size_t block_len = IO_read_bits(in, 21); 607 608 switch (block_type) { 609 case 0: { 610 // "Raw_Block - this is an uncompressed block. Block_Size is the 611 // number of bytes to read and copy." 612 const u8 *const read_ptr = IO_get_read_ptr(in, block_len); 613 u8 *const write_ptr = IO_get_write_ptr(out, block_len); 614 615 // Copy the raw data into the output 616 memcpy(write_ptr, read_ptr, block_len); 617 618 ctx->current_total_output += block_len; 619 break; 620 } 621 case 1: { 622 // "RLE_Block - this is a single byte, repeated N times. In which 623 // case, Block_Size is the size to regenerate, while the 624 // "compressed" block is just 1 byte (the byte to repeat)." 625 const u8 *const read_ptr = IO_get_read_ptr(in, 1); 626 u8 *const write_ptr = IO_get_write_ptr(out, block_len); 627 628 // Copy `block_len` copies of `read_ptr[0]` to the output 629 memset(write_ptr, read_ptr[0], block_len); 630 631 ctx->current_total_output += block_len; 632 break; 633 } 634 case 2: { 635 // "Compressed_Block - this is a Zstandard compressed block, 636 // detailed in another section of this specification. Block_Size is 637 // the compressed size. 638 639 // Create a sub-stream for the block 640 istream_t block_stream = IO_make_sub_istream(in, block_len); 641 decompress_block(ctx, out, &block_stream); 642 break; 643 } 644 case 3: 645 // "Reserved - this is not a block. This value cannot be used with 646 // current version of this specification." 647 CORRUPTION(); 648 break; 649 default: 650 IMPOSSIBLE(); 651 } 652 } while (!last_block); 653 654 if (ctx->header.content_checksum_flag) { 655 // This program does not support checking the checksum, so skip over it 656 // if it's present 657 IO_advance_input(in, 4); 658 } 659 } 660 /******* END FRAME DECODING ***************************************************/ 661 662 /******* BLOCK DECOMPRESSION **************************************************/ 663 static void decompress_block(frame_context_t *const ctx, ostream_t *const out, 664 istream_t *const in) { 665 // "A compressed block consists of 2 sections : 666 // 667 // Literals_Section 668 // Sequences_Section" 669 670 671 // Part 1: decode the literals block 672 u8 *literals = NULL; 673 const size_t literals_size = decode_literals(ctx, in, &literals); 674 675 // Part 2: decode the sequences block 676 sequence_command_t *sequences = NULL; 677 const size_t num_sequences = 678 decode_sequences(ctx, in, &sequences); 679 680 // Part 3: combine literals and sequence commands to generate output 681 execute_sequences(ctx, out, literals, literals_size, sequences, 682 num_sequences); 683 free(literals); 684 free(sequences); 685 } 686 /******* END BLOCK DECOMPRESSION **********************************************/ 687 688 /******* LITERALS DECODING ****************************************************/ 689 static size_t decode_literals_simple(istream_t *const in, u8 **const literals, 690 const int block_type, 691 const int size_format); 692 static size_t decode_literals_compressed(frame_context_t *const ctx, 693 istream_t *const in, 694 u8 **const literals, 695 const int block_type, 696 const int size_format); 697 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in); 698 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, 699 int *const num_symbs); 700 701 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, 702 u8 **const literals) { 703 // "Literals can be stored uncompressed or compressed using Huffman prefix 704 // codes. When compressed, an optional tree description can be present, 705 // followed by 1 or 4 streams." 706 // 707 // "Literals_Section_Header 708 // 709 // Header is in charge of describing how literals are packed. It's a 710 // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using 711 // little-endian convention." 712 // 713 // "Literals_Block_Type 714 // 715 // This field uses 2 lowest bits of first byte, describing 4 different block 716 // types" 717 // 718 // size_format takes between 1 and 2 bits 719 int block_type = (int)IO_read_bits(in, 2); 720 int size_format = (int)IO_read_bits(in, 2); 721 722 if (block_type <= 1) { 723 // Raw or RLE literals block 724 return decode_literals_simple(in, literals, block_type, 725 size_format); 726 } else { 727 // Huffman compressed literals 728 return decode_literals_compressed(ctx, in, literals, block_type, 729 size_format); 730 } 731 } 732 733 /// Decodes literals blocks in raw or RLE form 734 static size_t decode_literals_simple(istream_t *const in, u8 **const literals, 735 const int block_type, 736 const int size_format) { 737 size_t size; 738 switch (size_format) { 739 // These cases are in the form ?0 740 // In this case, the ? bit is actually part of the size field 741 case 0: 742 case 2: 743 // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)." 744 IO_rewind_bits(in, 1); 745 size = IO_read_bits(in, 5); 746 break; 747 case 1: 748 // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)." 749 size = IO_read_bits(in, 12); 750 break; 751 case 3: 752 // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)." 753 size = IO_read_bits(in, 20); 754 break; 755 default: 756 // Size format is in range 0-3 757 IMPOSSIBLE(); 758 } 759 760 if (size > MAX_LITERALS_SIZE) { 761 CORRUPTION(); 762 } 763 764 *literals = malloc(size); 765 if (!*literals) { 766 BAD_ALLOC(); 767 } 768 769 switch (block_type) { 770 case 0: { 771 // "Raw_Literals_Block - Literals are stored uncompressed." 772 const u8 *const read_ptr = IO_get_read_ptr(in, size); 773 memcpy(*literals, read_ptr, size); 774 break; 775 } 776 case 1: { 777 // "RLE_Literals_Block - Literals consist of a single byte value repeated N times." 778 const u8 *const read_ptr = IO_get_read_ptr(in, 1); 779 memset(*literals, read_ptr[0], size); 780 break; 781 } 782 default: 783 IMPOSSIBLE(); 784 } 785 786 return size; 787 } 788 789 /// Decodes Huffman compressed literals 790 static size_t decode_literals_compressed(frame_context_t *const ctx, 791 istream_t *const in, 792 u8 **const literals, 793 const int block_type, 794 const int size_format) { 795 size_t regenerated_size, compressed_size; 796 // Only size_format=0 has 1 stream, so default to 4 797 int num_streams = 4; 798 switch (size_format) { 799 case 0: 800 // "A single stream. Both Compressed_Size and Regenerated_Size use 10 801 // bits (0-1023)." 802 num_streams = 1; 803 // Fall through as it has the same size format 804 /* fallthrough */ 805 case 1: 806 // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits 807 // (0-1023)." 808 regenerated_size = IO_read_bits(in, 10); 809 compressed_size = IO_read_bits(in, 10); 810 break; 811 case 2: 812 // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits 813 // (0-16383)." 814 regenerated_size = IO_read_bits(in, 14); 815 compressed_size = IO_read_bits(in, 14); 816 break; 817 case 3: 818 // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits 819 // (0-262143)." 820 regenerated_size = IO_read_bits(in, 18); 821 compressed_size = IO_read_bits(in, 18); 822 break; 823 default: 824 // Impossible 825 IMPOSSIBLE(); 826 } 827 if (regenerated_size > MAX_LITERALS_SIZE) { 828 CORRUPTION(); 829 } 830 831 *literals = malloc(regenerated_size); 832 if (!*literals) { 833 BAD_ALLOC(); 834 } 835 836 ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size); 837 istream_t huf_stream = IO_make_sub_istream(in, compressed_size); 838 839 if (block_type == 2) { 840 // Decode the provided Huffman table 841 // "This section is only present when Literals_Block_Type type is 842 // Compressed_Literals_Block (2)." 843 844 HUF_free_dtable(&ctx->literals_dtable); 845 decode_huf_table(&ctx->literals_dtable, &huf_stream); 846 } else { 847 // If the previous Huffman table is being repeated, ensure it exists 848 if (!ctx->literals_dtable.symbols) { 849 CORRUPTION(); 850 } 851 } 852 853 size_t symbols_decoded; 854 if (num_streams == 1) { 855 symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream); 856 } else { 857 symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream); 858 } 859 860 if (symbols_decoded != regenerated_size) { 861 CORRUPTION(); 862 } 863 864 return regenerated_size; 865 } 866 867 // Decode the Huffman table description 868 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) { 869 // "All literal values from zero (included) to last present one (excluded) 870 // are represented by Weight with values from 0 to Max_Number_of_Bits." 871 872 // "This is a single byte value (0-255), which describes how to decode the list of weights." 873 const u8 header = IO_read_bits(in, 8); 874 875 u8 weights[HUF_MAX_SYMBS]; 876 memset(weights, 0, sizeof(weights)); 877 878 int num_symbs; 879 880 if (header >= 128) { 881 // "This is a direct representation, where each Weight is written 882 // directly as a 4 bits field (0-15). The full representation occupies 883 // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte 884 // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte - 885 // 127" 886 num_symbs = header - 127; 887 const size_t bytes = (num_symbs + 1) / 2; 888 889 const u8 *const weight_src = IO_get_read_ptr(in, bytes); 890 891 for (int i = 0; i < num_symbs; i++) { 892 // "They are encoded forward, 2 893 // weights to a byte with the first weight taking the top four bits 894 // and the second taking the bottom four (e.g. the following 895 // operations could be used to read the weights: Weight[0] = 896 // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)." 897 if (i % 2 == 0) { 898 weights[i] = weight_src[i / 2] >> 4; 899 } else { 900 weights[i] = weight_src[i / 2] & 0xf; 901 } 902 } 903 } else { 904 // The weights are FSE encoded, decode them before we can construct the 905 // table 906 istream_t fse_stream = IO_make_sub_istream(in, header); 907 ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS); 908 fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs); 909 } 910 911 // Construct the table using the decoded weights 912 HUF_init_dtable_usingweights(dtable, weights, num_symbs); 913 } 914 915 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, 916 int *const num_symbs) { 917 const int MAX_ACCURACY_LOG = 7; 918 919 FSE_dtable dtable; 920 921 // "An FSE bitstream starts by a header, describing probabilities 922 // distribution. It will create a Decoding Table. For a list of Huffman 923 // weights, maximum accuracy is 7 bits." 924 FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG); 925 926 // Decode the weights 927 *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in); 928 929 FSE_free_dtable(&dtable); 930 } 931 /******* END LITERALS DECODING ************************************************/ 932 933 /******* SEQUENCE DECODING ****************************************************/ 934 /// The combination of FSE states needed to decode sequences 935 typedef struct { 936 FSE_dtable ll_table; 937 FSE_dtable of_table; 938 FSE_dtable ml_table; 939 940 u16 ll_state; 941 u16 of_state; 942 u16 ml_state; 943 } sequence_states_t; 944 945 /// Different modes to signal to decode_seq_tables what to do 946 typedef enum { 947 seq_literal_length = 0, 948 seq_offset = 1, 949 seq_match_length = 2, 950 } seq_part_t; 951 952 typedef enum { 953 seq_predefined = 0, 954 seq_rle = 1, 955 seq_fse = 2, 956 seq_repeat = 3, 957 } seq_mode_t; 958 959 /// The predefined FSE distribution tables for `seq_predefined` mode 960 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = { 961 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 962 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1}; 963 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = { 964 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 965 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1}; 966 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = { 967 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 968 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 969 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1}; 970 971 /// The sequence decoding baseline and number of additional bits to read/add 972 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets 973 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = { 974 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 975 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 976 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}; 977 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = { 978 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 979 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 980 981 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = { 982 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 983 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 984 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 985 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539}; 986 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { 987 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 988 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 989 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 990 991 /// Offset decoding is simpler so we just need a maximum code value 992 static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52}; 993 994 static void decompress_sequences(frame_context_t *const ctx, 995 istream_t *const in, 996 sequence_command_t *const sequences, 997 const size_t num_sequences); 998 static sequence_command_t decode_sequence(sequence_states_t *const state, 999 const u8 *const src, 1000 i64 *const offset); 1001 static void decode_seq_table(FSE_dtable *const table, istream_t *const in, 1002 const seq_part_t type, const seq_mode_t mode); 1003 1004 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in, 1005 sequence_command_t **const sequences) { 1006 // "A compressed block is a succession of sequences . A sequence is a 1007 // literal copy command, followed by a match copy command. A literal copy 1008 // command specifies a length. It is the number of bytes to be copied (or 1009 // extracted) from the literal section. A match copy command specifies an 1010 // offset and a length. The offset gives the position to copy from, which 1011 // can be within a previous block." 1012 1013 size_t num_sequences; 1014 1015 // "Number_of_Sequences 1016 // 1017 // This is a variable size field using between 1 and 3 bytes. Let's call its 1018 // first byte byte0." 1019 u8 header = IO_read_bits(in, 8); 1020 if (header == 0) { 1021 // "There are no sequences. The sequence section stops there. 1022 // Regenerated content is defined entirely by literals section." 1023 *sequences = NULL; 1024 return 0; 1025 } else if (header < 128) { 1026 // "Number_of_Sequences = byte0 . Uses 1 byte." 1027 num_sequences = header; 1028 } else if (header < 255) { 1029 // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes." 1030 num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8); 1031 } else { 1032 // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes." 1033 num_sequences = IO_read_bits(in, 16) + 0x7F00; 1034 } 1035 1036 *sequences = malloc(num_sequences * sizeof(sequence_command_t)); 1037 if (!*sequences) { 1038 BAD_ALLOC(); 1039 } 1040 1041 decompress_sequences(ctx, in, *sequences, num_sequences); 1042 return num_sequences; 1043 } 1044 1045 /// Decompress the FSE encoded sequence commands 1046 static void decompress_sequences(frame_context_t *const ctx, istream_t *in, 1047 sequence_command_t *const sequences, 1048 const size_t num_sequences) { 1049 // "The Sequences_Section regroup all symbols required to decode commands. 1050 // There are 3 symbol types : literals lengths, offsets and match lengths. 1051 // They are encoded together, interleaved, in a single bitstream." 1052 1053 // "Symbol compression modes 1054 // 1055 // This is a single byte, defining the compression mode of each symbol 1056 // type." 1057 // 1058 // Bit number : Field name 1059 // 7-6 : Literals_Lengths_Mode 1060 // 5-4 : Offsets_Mode 1061 // 3-2 : Match_Lengths_Mode 1062 // 1-0 : Reserved 1063 u8 compression_modes = IO_read_bits(in, 8); 1064 1065 if ((compression_modes & 3) != 0) { 1066 // Reserved bits set 1067 CORRUPTION(); 1068 } 1069 1070 // "Following the header, up to 3 distribution tables can be described. When 1071 // present, they are in this order : 1072 // 1073 // Literals lengths 1074 // Offsets 1075 // Match Lengths" 1076 // Update the tables we have stored in the context 1077 decode_seq_table(&ctx->ll_dtable, in, seq_literal_length, 1078 (compression_modes >> 6) & 3); 1079 1080 decode_seq_table(&ctx->of_dtable, in, seq_offset, 1081 (compression_modes >> 4) & 3); 1082 1083 decode_seq_table(&ctx->ml_dtable, in, seq_match_length, 1084 (compression_modes >> 2) & 3); 1085 1086 1087 sequence_states_t states; 1088 1089 // Initialize the decoding tables 1090 { 1091 states.ll_table = ctx->ll_dtable; 1092 states.of_table = ctx->of_dtable; 1093 states.ml_table = ctx->ml_dtable; 1094 } 1095 1096 const size_t len = IO_istream_len(in); 1097 const u8 *const src = IO_get_read_ptr(in, len); 1098 1099 // "After writing the last bit containing information, the compressor writes 1100 // a single 1-bit and then fills the byte with 0-7 0 bits of padding." 1101 const int padding = 8 - highest_set_bit(src[len - 1]); 1102 // The offset starts at the end because FSE streams are read backwards 1103 i64 bit_offset = (i64)(len * 8 - (size_t)padding); 1104 1105 // "The bitstream starts with initial state values, each using the required 1106 // number of bits in their respective accuracy, decoded previously from 1107 // their normalized distribution. 1108 // 1109 // It starts by Literals_Length_State, followed by Offset_State, and finally 1110 // Match_Length_State." 1111 FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset); 1112 FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset); 1113 FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset); 1114 1115 for (size_t i = 0; i < num_sequences; i++) { 1116 // Decode sequences one by one 1117 sequences[i] = decode_sequence(&states, src, &bit_offset); 1118 } 1119 1120 if (bit_offset != 0) { 1121 CORRUPTION(); 1122 } 1123 } 1124 1125 // Decode a single sequence and update the state 1126 static sequence_command_t decode_sequence(sequence_states_t *const states, 1127 const u8 *const src, 1128 i64 *const offset) { 1129 // "Each symbol is a code in its own context, which specifies Baseline and 1130 // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw 1131 // additional bits in the same bitstream." 1132 1133 // Decode symbols, but don't update states 1134 const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state); 1135 const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state); 1136 const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state); 1137 1138 // Offset doesn't need a max value as it's not decoded using a table 1139 if (ll_code > SEQ_MAX_CODES[seq_literal_length] || 1140 ml_code > SEQ_MAX_CODES[seq_match_length]) { 1141 CORRUPTION(); 1142 } 1143 1144 // Read the interleaved bits 1145 sequence_command_t seq; 1146 // "Decoding starts by reading the Number_of_Bits required to decode Offset. 1147 // It then does the same for Match_Length, and then for Literals_Length." 1148 seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset); 1149 1150 seq.match_length = 1151 SEQ_MATCH_LENGTH_BASELINES[ml_code] + 1152 STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset); 1153 1154 seq.literal_length = 1155 SEQ_LITERAL_LENGTH_BASELINES[ll_code] + 1156 STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset); 1157 1158 // "If it is not the last sequence in the block, the next operation is to 1159 // update states. Using the rules pre-calculated in the decoding tables, 1160 // Literals_Length_State is updated, followed by Match_Length_State, and 1161 // then Offset_State." 1162 // If the stream is complete don't read bits to update state 1163 if (*offset != 0) { 1164 FSE_update_state(&states->ll_table, &states->ll_state, src, offset); 1165 FSE_update_state(&states->ml_table, &states->ml_state, src, offset); 1166 FSE_update_state(&states->of_table, &states->of_state, src, offset); 1167 } 1168 1169 return seq; 1170 } 1171 1172 /// Given a sequence part and table mode, decode the FSE distribution 1173 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table` 1174 static void decode_seq_table(FSE_dtable *const table, istream_t *const in, 1175 const seq_part_t type, const seq_mode_t mode) { 1176 // Constant arrays indexed by seq_part_t 1177 const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST, 1178 SEQ_OFFSET_DEFAULT_DIST, 1179 SEQ_MATCH_LENGTH_DEFAULT_DIST}; 1180 const size_t default_distribution_lengths[] = {36, 29, 53}; 1181 const size_t default_distribution_accuracies[] = {6, 5, 6}; 1182 1183 const size_t max_accuracies[] = {9, 8, 9}; 1184 1185 if (mode != seq_repeat) { 1186 // Free old one before overwriting 1187 FSE_free_dtable(table); 1188 } 1189 1190 switch (mode) { 1191 case seq_predefined: { 1192 // "Predefined_Mode : uses a predefined distribution table." 1193 const i16 *distribution = default_distributions[type]; 1194 const size_t symbs = default_distribution_lengths[type]; 1195 const size_t accuracy_log = default_distribution_accuracies[type]; 1196 1197 FSE_init_dtable(table, distribution, symbs, accuracy_log); 1198 break; 1199 } 1200 case seq_rle: { 1201 // "RLE_Mode : it's a single code, repeated Number_of_Sequences times." 1202 const u8 symb = IO_get_read_ptr(in, 1)[0]; 1203 FSE_init_dtable_rle(table, symb); 1204 break; 1205 } 1206 case seq_fse: { 1207 // "FSE_Compressed_Mode : standard FSE compression. A distribution table 1208 // will be present " 1209 FSE_decode_header(table, in, max_accuracies[type]); 1210 break; 1211 } 1212 case seq_repeat: 1213 // "Repeat_Mode : re-use distribution table from previous compressed 1214 // block." 1215 // Nothing to do here, table will be unchanged 1216 if (!table->symbols) { 1217 // This mode is invalid if we don't already have a table 1218 CORRUPTION(); 1219 } 1220 break; 1221 default: 1222 // Impossible, as mode is from 0-3 1223 IMPOSSIBLE(); 1224 break; 1225 } 1226 1227 } 1228 /******* END SEQUENCE DECODING ************************************************/ 1229 1230 /******* SEQUENCE EXECUTION ***************************************************/ 1231 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, 1232 const u8 *const literals, 1233 const size_t literals_len, 1234 const sequence_command_t *const sequences, 1235 const size_t num_sequences) { 1236 istream_t litstream = IO_make_istream(literals, literals_len); 1237 1238 u64 *const offset_hist = ctx->previous_offsets; 1239 size_t total_output = ctx->current_total_output; 1240 1241 for (size_t i = 0; i < num_sequences; i++) { 1242 const sequence_command_t seq = sequences[i]; 1243 { 1244 const u32 literals_size = copy_literals(seq.literal_length, &litstream, out); 1245 total_output += literals_size; 1246 } 1247 1248 size_t const offset = compute_offset(seq, offset_hist); 1249 1250 size_t const match_length = seq.match_length; 1251 1252 execute_match_copy(ctx, offset, match_length, total_output, out); 1253 1254 total_output += match_length; 1255 } 1256 1257 // Copy any leftover literals 1258 { 1259 size_t len = IO_istream_len(&litstream); 1260 copy_literals(len, &litstream, out); 1261 total_output += len; 1262 } 1263 1264 ctx->current_total_output = total_output; 1265 } 1266 1267 static u32 copy_literals(const size_t literal_length, istream_t *litstream, 1268 ostream_t *const out) { 1269 // If the sequence asks for more literals than are left, the 1270 // sequence must be corrupted 1271 if (literal_length > IO_istream_len(litstream)) { 1272 CORRUPTION(); 1273 } 1274 1275 u8 *const write_ptr = IO_get_write_ptr(out, literal_length); 1276 const u8 *const read_ptr = 1277 IO_get_read_ptr(litstream, literal_length); 1278 // Copy literals to output 1279 memcpy(write_ptr, read_ptr, literal_length); 1280 1281 return literal_length; 1282 } 1283 1284 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) { 1285 size_t offset; 1286 // Offsets are special, we need to handle the repeat offsets 1287 if (seq.offset <= 3) { 1288 // "The first 3 values define a repeated offset and we will call 1289 // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. 1290 // They are sorted in recency order, with Repeated_Offset1 meaning 1291 // 'most recent one'". 1292 1293 // Use 0 indexing for the array 1294 u32 idx = seq.offset - 1; 1295 if (seq.literal_length == 0) { 1296 // "There is an exception though, when current sequence's 1297 // literals length is 0. In this case, repeated offsets are 1298 // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, 1299 // Repeated_Offset2 becomes Repeated_Offset3, and 1300 // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." 1301 idx++; 1302 } 1303 1304 if (idx == 0) { 1305 offset = offset_hist[0]; 1306 } else { 1307 // If idx == 3 then literal length was 0 and the offset was 3, 1308 // as per the exception listed above 1309 offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1; 1310 1311 // If idx == 1 we don't need to modify offset_hist[2], since 1312 // we're using the second-most recent code 1313 if (idx > 1) { 1314 offset_hist[2] = offset_hist[1]; 1315 } 1316 offset_hist[1] = offset_hist[0]; 1317 offset_hist[0] = offset; 1318 } 1319 } else { 1320 // When it's not a repeat offset: 1321 // "if (Offset_Value > 3) offset = Offset_Value - 3;" 1322 offset = seq.offset - 3; 1323 1324 // Shift back history 1325 offset_hist[2] = offset_hist[1]; 1326 offset_hist[1] = offset_hist[0]; 1327 offset_hist[0] = offset; 1328 } 1329 return offset; 1330 } 1331 1332 static void execute_match_copy(frame_context_t *const ctx, size_t offset, 1333 size_t match_length, size_t total_output, 1334 ostream_t *const out) { 1335 u8 *write_ptr = IO_get_write_ptr(out, match_length); 1336 if (total_output <= ctx->header.window_size) { 1337 // In this case offset might go back into the dictionary 1338 if (offset > total_output + ctx->dict_content_len) { 1339 // The offset goes beyond even the dictionary 1340 CORRUPTION(); 1341 } 1342 1343 if (offset > total_output) { 1344 // "The rest of the dictionary is its content. The content act 1345 // as a "past" in front of data to compress or decompress, so it 1346 // can be referenced in sequence commands." 1347 const size_t dict_copy = 1348 MIN(offset - total_output, match_length); 1349 const size_t dict_offset = 1350 ctx->dict_content_len - (offset - total_output); 1351 1352 memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); 1353 write_ptr += dict_copy; 1354 match_length -= dict_copy; 1355 } 1356 } else if (offset > ctx->header.window_size) { 1357 CORRUPTION(); 1358 } 1359 1360 // We must copy byte by byte because the match length might be larger 1361 // than the offset 1362 // ex: if the output so far was "abc", a command with offset=3 and 1363 // match_length=6 would produce "abcabcabc" as the new output 1364 for (size_t j = 0; j < match_length; j++) { 1365 *write_ptr = *(write_ptr - offset); 1366 write_ptr++; 1367 } 1368 } 1369 /******* END SEQUENCE EXECUTION ***********************************************/ 1370 1371 /******* OUTPUT SIZE COUNTING *************************************************/ 1372 /// Get the decompressed size of an input stream so memory can be allocated in 1373 /// advance. 1374 /// This implementation assumes `src` points to a single ZSTD-compressed frame 1375 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { 1376 istream_t in = IO_make_istream(src, src_len); 1377 1378 // get decompressed size from ZSTD frame header 1379 { 1380 const u32 magic_number = (u32)IO_read_bits(&in, 32); 1381 1382 if (magic_number == ZSTD_MAGIC_NUMBER) { 1383 // ZSTD frame 1384 frame_header_t header; 1385 parse_frame_header(&header, &in); 1386 1387 if (header.frame_content_size == 0 && !header.single_segment_flag) { 1388 // Content size not provided, we can't tell 1389 return (size_t)-1; 1390 } 1391 1392 return header.frame_content_size; 1393 } else { 1394 // not a real frame or skippable frame 1395 ERROR("ZSTD frame magic number did not match"); 1396 } 1397 } 1398 } 1399 /******* END OUTPUT SIZE COUNTING *********************************************/ 1400 1401 /******* DICTIONARY PARSING ***************************************************/ 1402 dictionary_t* create_dictionary() { 1403 dictionary_t* const dict = calloc(1, sizeof(dictionary_t)); 1404 if (!dict) { 1405 BAD_ALLOC(); 1406 } 1407 return dict; 1408 } 1409 1410 /// Free an allocated dictionary 1411 void free_dictionary(dictionary_t *const dict) { 1412 HUF_free_dtable(&dict->literals_dtable); 1413 FSE_free_dtable(&dict->ll_dtable); 1414 FSE_free_dtable(&dict->of_dtable); 1415 FSE_free_dtable(&dict->ml_dtable); 1416 1417 free(dict->content); 1418 1419 memset(dict, 0, sizeof(dictionary_t)); 1420 1421 free(dict); 1422 } 1423 1424 1425 #if !defined(ZDEC_NO_DICTIONARY) 1426 #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes") 1427 #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src"); 1428 1429 static void init_dictionary_content(dictionary_t *const dict, 1430 istream_t *const in); 1431 1432 void parse_dictionary(dictionary_t *const dict, const void *src, 1433 size_t src_len) { 1434 const u8 *byte_src = (const u8 *)src; 1435 memset(dict, 0, sizeof(dictionary_t)); 1436 if (src == NULL) { /* cannot initialize dictionary with null src */ 1437 NULL_SRC(); 1438 } 1439 if (src_len < 8) { 1440 DICT_SIZE_ERROR(); 1441 } 1442 1443 istream_t in = IO_make_istream(byte_src, src_len); 1444 1445 const u32 magic_number = IO_read_bits(&in, 32); 1446 if (magic_number != 0xEC30A437) { 1447 // raw content dict 1448 IO_rewind_bits(&in, 32); 1449 init_dictionary_content(dict, &in); 1450 return; 1451 } 1452 1453 dict->dictionary_id = IO_read_bits(&in, 32); 1454 1455 // "Entropy_Tables : following the same format as the tables in compressed 1456 // blocks. They are stored in following order : Huffman tables for literals, 1457 // FSE table for offsets, FSE table for match lengths, and FSE table for 1458 // literals lengths. It's finally followed by 3 offset values, populating 1459 // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes 1460 // little-endian each, for a total of 12 bytes. Each recent offset must have 1461 // a value < dictionary size." 1462 decode_huf_table(&dict->literals_dtable, &in); 1463 decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse); 1464 decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse); 1465 decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse); 1466 1467 // Read in the previous offset history 1468 dict->previous_offsets[0] = IO_read_bits(&in, 32); 1469 dict->previous_offsets[1] = IO_read_bits(&in, 32); 1470 dict->previous_offsets[2] = IO_read_bits(&in, 32); 1471 1472 // Ensure the provided offsets aren't too large 1473 // "Each recent offset must have a value < dictionary size." 1474 for (int i = 0; i < 3; i++) { 1475 if (dict->previous_offsets[i] > src_len) { 1476 ERROR("Dictionary corrupted"); 1477 } 1478 } 1479 1480 // "Content : The rest of the dictionary is its content. The content act as 1481 // a "past" in front of data to compress or decompress, so it can be 1482 // referenced in sequence commands." 1483 init_dictionary_content(dict, &in); 1484 } 1485 1486 static void init_dictionary_content(dictionary_t *const dict, 1487 istream_t *const in) { 1488 // Copy in the content 1489 dict->content_size = IO_istream_len(in); 1490 dict->content = malloc(dict->content_size); 1491 if (!dict->content) { 1492 BAD_ALLOC(); 1493 } 1494 1495 const u8 *const content = IO_get_read_ptr(in, dict->content_size); 1496 1497 memcpy(dict->content, content, dict->content_size); 1498 } 1499 1500 static void HUF_copy_dtable(HUF_dtable *const dst, 1501 const HUF_dtable *const src) { 1502 if (src->max_bits == 0) { 1503 memset(dst, 0, sizeof(HUF_dtable)); 1504 return; 1505 } 1506 1507 const size_t size = (size_t)1 << src->max_bits; 1508 dst->max_bits = src->max_bits; 1509 1510 dst->symbols = malloc(size); 1511 dst->num_bits = malloc(size); 1512 if (!dst->symbols || !dst->num_bits) { 1513 BAD_ALLOC(); 1514 } 1515 1516 memcpy(dst->symbols, src->symbols, size); 1517 memcpy(dst->num_bits, src->num_bits, size); 1518 } 1519 1520 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) { 1521 if (src->accuracy_log == 0) { 1522 memset(dst, 0, sizeof(FSE_dtable)); 1523 return; 1524 } 1525 1526 size_t size = (size_t)1 << src->accuracy_log; 1527 dst->accuracy_log = src->accuracy_log; 1528 1529 dst->symbols = malloc(size); 1530 dst->num_bits = malloc(size); 1531 dst->new_state_base = malloc(size * sizeof(u16)); 1532 if (!dst->symbols || !dst->num_bits || !dst->new_state_base) { 1533 BAD_ALLOC(); 1534 } 1535 1536 memcpy(dst->symbols, src->symbols, size); 1537 memcpy(dst->num_bits, src->num_bits, size); 1538 memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); 1539 } 1540 1541 /// A dictionary acts as initializing values for the frame context before 1542 /// decompression, so we implement it by applying it's predetermined 1543 /// tables and content to the context before beginning decompression 1544 static void frame_context_apply_dict(frame_context_t *const ctx, 1545 const dictionary_t *const dict) { 1546 // If the content pointer is NULL then it must be an empty dict 1547 if (!dict || !dict->content) 1548 return; 1549 1550 // If the requested dictionary_id is non-zero, the correct dictionary must 1551 // be present 1552 if (ctx->header.dictionary_id != 0 && 1553 ctx->header.dictionary_id != dict->dictionary_id) { 1554 ERROR("Wrong dictionary provided"); 1555 } 1556 1557 // Copy the dict content to the context for references during sequence 1558 // execution 1559 ctx->dict_content = dict->content; 1560 ctx->dict_content_len = dict->content_size; 1561 1562 // If it's a formatted dict copy the precomputed tables in so they can 1563 // be used in the table repeat modes 1564 if (dict->dictionary_id != 0) { 1565 // Deep copy the entropy tables so they can be freed independently of 1566 // the dictionary struct 1567 HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable); 1568 FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable); 1569 FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable); 1570 FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable); 1571 1572 // Copy the repeated offsets 1573 memcpy(ctx->previous_offsets, dict->previous_offsets, 1574 sizeof(ctx->previous_offsets)); 1575 } 1576 } 1577 1578 #else // ZDEC_NO_DICTIONARY is defined 1579 1580 static void frame_context_apply_dict(frame_context_t *const ctx, 1581 const dictionary_t *const dict) { 1582 (void)ctx; 1583 if (dict && dict->content) ERROR("dictionary not supported"); 1584 } 1585 1586 #endif 1587 /******* END DICTIONARY PARSING ***********************************************/ 1588 1589 /******* IO STREAM OPERATIONS *************************************************/ 1590 1591 /// Reads `num` bits from a bitstream, and updates the internal offset 1592 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) { 1593 if (num_bits > 64 || num_bits <= 0) { 1594 ERROR("Attempt to read an invalid number of bits"); 1595 } 1596 1597 const size_t bytes = (num_bits + in->bit_offset + 7) / 8; 1598 const size_t full_bytes = (num_bits + in->bit_offset) / 8; 1599 if (bytes > in->len) { 1600 INP_SIZE(); 1601 } 1602 1603 const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset); 1604 1605 in->bit_offset = (num_bits + in->bit_offset) % 8; 1606 in->ptr += full_bytes; 1607 in->len -= full_bytes; 1608 1609 return result; 1610 } 1611 1612 /// If a non-zero number of bits have been read from the current byte, advance 1613 /// the offset to the next byte 1614 static inline void IO_rewind_bits(istream_t *const in, int num_bits) { 1615 if (num_bits < 0) { 1616 ERROR("Attempting to rewind stream by a negative number of bits"); 1617 } 1618 1619 // move the offset back by `num_bits` bits 1620 const int new_offset = in->bit_offset - num_bits; 1621 // determine the number of whole bytes we have to rewind, rounding up to an 1622 // integer number (e.g. if `new_offset == -5`, `bytes == 1`) 1623 const i64 bytes = -(new_offset - 7) / 8; 1624 1625 in->ptr -= bytes; 1626 in->len += bytes; 1627 // make sure the resulting `bit_offset` is positive, as mod in C does not 1628 // convert numbers from negative to positive (e.g. -22 % 8 == -6) 1629 in->bit_offset = ((new_offset % 8) + 8) % 8; 1630 } 1631 1632 /// If the remaining bits in a byte will be unused, advance to the end of the 1633 /// byte 1634 static inline void IO_align_stream(istream_t *const in) { 1635 if (in->bit_offset != 0) { 1636 if (in->len == 0) { 1637 INP_SIZE(); 1638 } 1639 in->ptr++; 1640 in->len--; 1641 in->bit_offset = 0; 1642 } 1643 } 1644 1645 /// Write the given byte into the output stream 1646 static inline void IO_write_byte(ostream_t *const out, u8 symb) { 1647 if (out->len == 0) { 1648 OUT_SIZE(); 1649 } 1650 1651 out->ptr[0] = symb; 1652 out->ptr++; 1653 out->len--; 1654 } 1655 1656 /// Returns the number of bytes left to be read in this stream. The stream must 1657 /// be byte aligned. 1658 static inline size_t IO_istream_len(const istream_t *const in) { 1659 return in->len; 1660 } 1661 1662 /// Returns a pointer where `len` bytes can be read, and advances the internal 1663 /// state. The stream must be byte aligned. 1664 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) { 1665 if (len > in->len) { 1666 INP_SIZE(); 1667 } 1668 if (in->bit_offset != 0) { 1669 ERROR("Attempting to operate on a non-byte aligned stream"); 1670 } 1671 const u8 *const ptr = in->ptr; 1672 in->ptr += len; 1673 in->len -= len; 1674 1675 return ptr; 1676 } 1677 /// Returns a pointer to write `len` bytes to, and advances the internal state 1678 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) { 1679 if (len > out->len) { 1680 OUT_SIZE(); 1681 } 1682 u8 *const ptr = out->ptr; 1683 out->ptr += len; 1684 out->len -= len; 1685 1686 return ptr; 1687 } 1688 1689 /// Advance the inner state by `len` bytes 1690 static inline void IO_advance_input(istream_t *const in, size_t len) { 1691 if (len > in->len) { 1692 INP_SIZE(); 1693 } 1694 if (in->bit_offset != 0) { 1695 ERROR("Attempting to operate on a non-byte aligned stream"); 1696 } 1697 1698 in->ptr += len; 1699 in->len -= len; 1700 } 1701 1702 /// Returns an `ostream_t` constructed from the given pointer and length 1703 static inline ostream_t IO_make_ostream(u8 *out, size_t len) { 1704 return (ostream_t) { out, len }; 1705 } 1706 1707 /// Returns an `istream_t` constructed from the given pointer and length 1708 static inline istream_t IO_make_istream(const u8 *in, size_t len) { 1709 return (istream_t) { in, len, 0 }; 1710 } 1711 1712 /// Returns an `istream_t` with the same base as `in`, and length `len` 1713 /// Then, advance `in` to account for the consumed bytes 1714 /// `in` must be byte aligned 1715 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) { 1716 // Consume `len` bytes of the parent stream 1717 const u8 *const ptr = IO_get_read_ptr(in, len); 1718 1719 // Make a substream using the pointer to those `len` bytes 1720 return IO_make_istream(ptr, len); 1721 } 1722 /******* END IO STREAM OPERATIONS *********************************************/ 1723 1724 /******* BITSTREAM OPERATIONS *************************************************/ 1725 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits 1726 static inline u64 read_bits_LE(const u8 *src, const int num_bits, 1727 const size_t offset) { 1728 if (num_bits > 64) { 1729 ERROR("Attempt to read an invalid number of bits"); 1730 } 1731 1732 // Skip over bytes that aren't in range 1733 src += offset / 8; 1734 size_t bit_offset = offset % 8; 1735 u64 res = 0; 1736 1737 int shift = 0; 1738 int left = num_bits; 1739 while (left > 0) { 1740 u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1); 1741 // Read the next byte, shift it to account for the offset, and then mask 1742 // out the top part if we don't need all the bits 1743 res += (((u64)*src++ >> bit_offset) & mask) << shift; 1744 shift += 8 - bit_offset; 1745 left -= 8 - bit_offset; 1746 bit_offset = 0; 1747 } 1748 1749 return res; 1750 } 1751 1752 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so 1753 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from 1754 /// `src + offset`. If the offset becomes negative, the extra bits at the 1755 /// bottom are filled in with `0` bits instead of reading from before `src`. 1756 static inline u64 STREAM_read_bits(const u8 *const src, const int bits, 1757 i64 *const offset) { 1758 *offset = *offset - bits; 1759 size_t actual_off = *offset; 1760 size_t actual_bits = bits; 1761 // Don't actually read bits from before the start of src, so if `*offset < 1762 // 0` fix actual_off and actual_bits to reflect the quantity to read 1763 if (*offset < 0) { 1764 actual_bits += *offset; 1765 actual_off = 0; 1766 } 1767 u64 res = read_bits_LE(src, actual_bits, actual_off); 1768 1769 if (*offset < 0) { 1770 // Fill in the bottom "overflowed" bits with 0's 1771 res = -*offset >= 64 ? 0 : (res << -*offset); 1772 } 1773 return res; 1774 } 1775 /******* END BITSTREAM OPERATIONS *********************************************/ 1776 1777 /******* BIT COUNTING OPERATIONS **********************************************/ 1778 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to 1779 /// `num`, or `-1` if `num == 0`. 1780 static inline int highest_set_bit(const u64 num) { 1781 for (int i = 63; i >= 0; i--) { 1782 if (((u64)1 << i) <= num) { 1783 return i; 1784 } 1785 } 1786 return -1; 1787 } 1788 /******* END BIT COUNTING OPERATIONS ******************************************/ 1789 1790 /******* HUFFMAN PRIMITIVES ***************************************************/ 1791 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, 1792 u16 *const state, const u8 *const src, 1793 i64 *const offset) { 1794 // Look up the symbol and number of bits to read 1795 const u8 symb = dtable->symbols[*state]; 1796 const u8 bits = dtable->num_bits[*state]; 1797 const u16 rest = STREAM_read_bits(src, bits, offset); 1798 // Shift `bits` bits out of the state, keeping the low order bits that 1799 // weren't necessary to determine this symbol. Then add in the new bits 1800 // read from the stream. 1801 *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1); 1802 1803 return symb; 1804 } 1805 1806 static inline void HUF_init_state(const HUF_dtable *const dtable, 1807 u16 *const state, const u8 *const src, 1808 i64 *const offset) { 1809 // Read in a full `dtable->max_bits` bits to initialize the state 1810 const u8 bits = dtable->max_bits; 1811 *state = STREAM_read_bits(src, bits, offset); 1812 } 1813 1814 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, 1815 ostream_t *const out, 1816 istream_t *const in) { 1817 const size_t len = IO_istream_len(in); 1818 if (len == 0) { 1819 INP_SIZE(); 1820 } 1821 const u8 *const src = IO_get_read_ptr(in, len); 1822 1823 // "Each bitstream must be read backward, that is starting from the end down 1824 // to the beginning. Therefore it's necessary to know the size of each 1825 // bitstream. 1826 // 1827 // It's also necessary to know exactly which bit is the latest. This is 1828 // detected by a final bit flag : the highest bit of latest byte is a 1829 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the 1830 // final-bit-flag itself is not part of the useful bitstream. Hence, the 1831 // last byte contains between 0 and 7 useful bits." 1832 const int padding = 8 - highest_set_bit(src[len - 1]); 1833 1834 // Offset starts at the end because HUF streams are read backwards 1835 i64 bit_offset = len * 8 - padding; 1836 u16 state; 1837 1838 HUF_init_state(dtable, &state, src, &bit_offset); 1839 1840 size_t symbols_written = 0; 1841 while (bit_offset > -dtable->max_bits) { 1842 // Iterate over the stream, decoding one symbol at a time 1843 IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset)); 1844 symbols_written++; 1845 } 1846 // "The process continues up to reading the required number of symbols per 1847 // stream. If a bitstream is not entirely and exactly consumed, hence 1848 // reaching exactly its beginning position with all bits consumed, the 1849 // decoding process is considered faulty." 1850 1851 // When all symbols have been decoded, the final state value shouldn't have 1852 // any data from the stream, so it should have "read" dtable->max_bits from 1853 // before the start of `src` 1854 // Therefore `offset`, the edge to start reading new bits at, should be 1855 // dtable->max_bits before the start of the stream 1856 if (bit_offset != -dtable->max_bits) { 1857 CORRUPTION(); 1858 } 1859 1860 return symbols_written; 1861 } 1862 1863 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, 1864 ostream_t *const out, istream_t *const in) { 1865 // "Compressed size is provided explicitly : in the 4-streams variant, 1866 // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each 1867 // value represents the compressed size of one stream, in order. The last 1868 // stream size is deducted from total compressed size and from previously 1869 // decoded stream sizes" 1870 const size_t csize1 = IO_read_bits(in, 16); 1871 const size_t csize2 = IO_read_bits(in, 16); 1872 const size_t csize3 = IO_read_bits(in, 16); 1873 1874 istream_t in1 = IO_make_sub_istream(in, csize1); 1875 istream_t in2 = IO_make_sub_istream(in, csize2); 1876 istream_t in3 = IO_make_sub_istream(in, csize3); 1877 istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in)); 1878 1879 size_t total_output = 0; 1880 // Decode each stream independently for simplicity 1881 // If we wanted to we could decode all 4 at the same time for speed, 1882 // utilizing more execution units 1883 total_output += HUF_decompress_1stream(dtable, out, &in1); 1884 total_output += HUF_decompress_1stream(dtable, out, &in2); 1885 total_output += HUF_decompress_1stream(dtable, out, &in3); 1886 total_output += HUF_decompress_1stream(dtable, out, &in4); 1887 1888 return total_output; 1889 } 1890 1891 /// Initializes a Huffman table using canonical Huffman codes 1892 /// For more explanation on canonical Huffman codes see 1893 /// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html 1894 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get 1895 /// earlier codes) 1896 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, 1897 const int num_symbs) { 1898 memset(table, 0, sizeof(HUF_dtable)); 1899 if (num_symbs > HUF_MAX_SYMBS) { 1900 ERROR("Too many symbols for Huffman"); 1901 } 1902 1903 u8 max_bits = 0; 1904 u16 rank_count[HUF_MAX_BITS + 1]; 1905 memset(rank_count, 0, sizeof(rank_count)); 1906 1907 // Count the number of symbols for each number of bits, and determine the 1908 // depth of the tree 1909 for (int i = 0; i < num_symbs; i++) { 1910 if (bits[i] > HUF_MAX_BITS) { 1911 ERROR("Huffman table depth too large"); 1912 } 1913 max_bits = MAX(max_bits, bits[i]); 1914 rank_count[bits[i]]++; 1915 } 1916 1917 const size_t table_size = 1 << max_bits; 1918 table->max_bits = max_bits; 1919 table->symbols = malloc(table_size); 1920 table->num_bits = malloc(table_size); 1921 1922 if (!table->symbols || !table->num_bits) { 1923 free(table->symbols); 1924 free(table->num_bits); 1925 BAD_ALLOC(); 1926 } 1927 1928 // "Symbols are sorted by Weight. Within same Weight, symbols keep natural 1929 // order. Symbols with a Weight of zero are removed. Then, starting from 1930 // lowest weight, prefix codes are distributed in order." 1931 1932 u32 rank_idx[HUF_MAX_BITS + 1]; 1933 // Initialize the starting codes for each rank (number of bits) 1934 rank_idx[max_bits] = 0; 1935 for (int i = max_bits; i >= 1; i--) { 1936 rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i)); 1937 // The entire range takes the same number of bits so we can memset it 1938 memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]); 1939 } 1940 1941 if (rank_idx[0] != table_size) { 1942 CORRUPTION(); 1943 } 1944 1945 // Allocate codes and fill in the table 1946 for (int i = 0; i < num_symbs; i++) { 1947 if (bits[i] != 0) { 1948 // Allocate a code for this symbol and set its range in the table 1949 const u16 code = rank_idx[bits[i]]; 1950 // Since the code doesn't care about the bottom `max_bits - bits[i]` 1951 // bits of state, it gets a range that spans all possible values of 1952 // the lower bits 1953 const u16 len = 1 << (max_bits - bits[i]); 1954 memset(&table->symbols[code], i, len); 1955 rank_idx[bits[i]] += len; 1956 } 1957 } 1958 } 1959 1960 static void HUF_init_dtable_usingweights(HUF_dtable *const table, 1961 const u8 *const weights, 1962 const int num_symbs) { 1963 // +1 because the last weight is not transmitted in the header 1964 if (num_symbs + 1 > HUF_MAX_SYMBS) { 1965 ERROR("Too many symbols for Huffman"); 1966 } 1967 1968 u8 bits[HUF_MAX_SYMBS]; 1969 1970 u64 weight_sum = 0; 1971 for (int i = 0; i < num_symbs; i++) { 1972 // Weights are in the same range as bit count 1973 if (weights[i] > HUF_MAX_BITS) { 1974 CORRUPTION(); 1975 } 1976 weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0; 1977 } 1978 1979 // Find the first power of 2 larger than the sum 1980 const int max_bits = highest_set_bit(weight_sum) + 1; 1981 const u64 left_over = ((u64)1 << max_bits) - weight_sum; 1982 // If the left over isn't a power of 2, the weights are invalid 1983 if (left_over & (left_over - 1)) { 1984 CORRUPTION(); 1985 } 1986 1987 // left_over is used to find the last weight as it's not transmitted 1988 // by inverting 2^(weight - 1) we can determine the value of last_weight 1989 const int last_weight = highest_set_bit(left_over) + 1; 1990 1991 for (int i = 0; i < num_symbs; i++) { 1992 // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0" 1993 bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0; 1994 } 1995 bits[num_symbs] = 1996 max_bits + 1 - last_weight; // Last weight is always non-zero 1997 1998 HUF_init_dtable(table, bits, num_symbs + 1); 1999 } 2000 2001 static void HUF_free_dtable(HUF_dtable *const dtable) { 2002 free(dtable->symbols); 2003 free(dtable->num_bits); 2004 memset(dtable, 0, sizeof(HUF_dtable)); 2005 } 2006 /******* END HUFFMAN PRIMITIVES ***********************************************/ 2007 2008 /******* FSE PRIMITIVES *******************************************************/ 2009 /// For more description of FSE see 2010 /// https://github.com/Cyan4973/FiniteStateEntropy/ 2011 2012 /// Allow a symbol to be decoded without updating state 2013 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, 2014 const u16 state) { 2015 return dtable->symbols[state]; 2016 } 2017 2018 /// Consumes bits from the input and uses the current state to determine the 2019 /// next state 2020 static inline void FSE_update_state(const FSE_dtable *const dtable, 2021 u16 *const state, const u8 *const src, 2022 i64 *const offset) { 2023 const u8 bits = dtable->num_bits[*state]; 2024 const u16 rest = STREAM_read_bits(src, bits, offset); 2025 *state = dtable->new_state_base[*state] + rest; 2026 } 2027 2028 /// Decodes a single FSE symbol and updates the offset 2029 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, 2030 u16 *const state, const u8 *const src, 2031 i64 *const offset) { 2032 const u8 symb = FSE_peek_symbol(dtable, *state); 2033 FSE_update_state(dtable, state, src, offset); 2034 return symb; 2035 } 2036 2037 static inline void FSE_init_state(const FSE_dtable *const dtable, 2038 u16 *const state, const u8 *const src, 2039 i64 *const offset) { 2040 // Read in a full `accuracy_log` bits to initialize the state 2041 const u8 bits = dtable->accuracy_log; 2042 *state = STREAM_read_bits(src, bits, offset); 2043 } 2044 2045 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, 2046 ostream_t *const out, 2047 istream_t *const in) { 2048 const size_t len = IO_istream_len(in); 2049 if (len == 0) { 2050 INP_SIZE(); 2051 } 2052 const u8 *const src = IO_get_read_ptr(in, len); 2053 2054 // "Each bitstream must be read backward, that is starting from the end down 2055 // to the beginning. Therefore it's necessary to know the size of each 2056 // bitstream. 2057 // 2058 // It's also necessary to know exactly which bit is the latest. This is 2059 // detected by a final bit flag : the highest bit of latest byte is a 2060 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the 2061 // final-bit-flag itself is not part of the useful bitstream. Hence, the 2062 // last byte contains between 0 and 7 useful bits." 2063 const int padding = 8 - highest_set_bit(src[len - 1]); 2064 i64 offset = len * 8 - padding; 2065 2066 u16 state1, state2; 2067 // "The first state (State1) encodes the even indexed symbols, and the 2068 // second (State2) encodes the odd indexes. State1 is initialized first, and 2069 // then State2, and they take turns decoding a single symbol and updating 2070 // their state." 2071 FSE_init_state(dtable, &state1, src, &offset); 2072 FSE_init_state(dtable, &state2, src, &offset); 2073 2074 // Decode until we overflow the stream 2075 // Since we decode in reverse order, overflowing the stream is offset going 2076 // negative 2077 size_t symbols_written = 0; 2078 while (1) { 2079 // "The number of symbols to decode is determined by tracking bitStream 2080 // overflow condition: If updating state after decoding a symbol would 2081 // require more bits than remain in the stream, it is assumed the extra 2082 // bits are 0. Then, the symbols for each of the final states are 2083 // decoded and the process is complete." 2084 IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset)); 2085 symbols_written++; 2086 if (offset < 0) { 2087 // There's still a symbol to decode in state2 2088 IO_write_byte(out, FSE_peek_symbol(dtable, state2)); 2089 symbols_written++; 2090 break; 2091 } 2092 2093 IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset)); 2094 symbols_written++; 2095 if (offset < 0) { 2096 // There's still a symbol to decode in state1 2097 IO_write_byte(out, FSE_peek_symbol(dtable, state1)); 2098 symbols_written++; 2099 break; 2100 } 2101 } 2102 2103 return symbols_written; 2104 } 2105 2106 static void FSE_init_dtable(FSE_dtable *const dtable, 2107 const i16 *const norm_freqs, const int num_symbs, 2108 const int accuracy_log) { 2109 if (accuracy_log > FSE_MAX_ACCURACY_LOG) { 2110 ERROR("FSE accuracy too large"); 2111 } 2112 if (num_symbs > FSE_MAX_SYMBS) { 2113 ERROR("Too many symbols for FSE"); 2114 } 2115 2116 dtable->accuracy_log = accuracy_log; 2117 2118 const size_t size = (size_t)1 << accuracy_log; 2119 dtable->symbols = malloc(size * sizeof(u8)); 2120 dtable->num_bits = malloc(size * sizeof(u8)); 2121 dtable->new_state_base = malloc(size * sizeof(u16)); 2122 2123 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { 2124 BAD_ALLOC(); 2125 } 2126 2127 // Used to determine how many bits need to be read for each state, 2128 // and where the destination range should start 2129 // Needs to be u16 because max value is 2 * max number of symbols, 2130 // which can be larger than a byte can store 2131 u16 state_desc[FSE_MAX_SYMBS]; 2132 2133 // "Symbols are scanned in their natural order for "less than 1" 2134 // probabilities. Symbols with this probability are being attributed a 2135 // single cell, starting from the end of the table. These symbols define a 2136 // full state reset, reading Accuracy_Log bits." 2137 int high_threshold = size; 2138 for (int s = 0; s < num_symbs; s++) { 2139 // Scan for low probability symbols to put at the top 2140 if (norm_freqs[s] == -1) { 2141 dtable->symbols[--high_threshold] = s; 2142 state_desc[s] = 1; 2143 } 2144 } 2145 2146 // "All remaining symbols are sorted in their natural order. Starting from 2147 // symbol 0 and table position 0, each symbol gets attributed as many cells 2148 // as its probability. Cell allocation is spreaded, not linear." 2149 // Place the rest in the table 2150 const u16 step = (size >> 1) + (size >> 3) + 3; 2151 const u16 mask = size - 1; 2152 u16 pos = 0; 2153 for (int s = 0; s < num_symbs; s++) { 2154 if (norm_freqs[s] <= 0) { 2155 continue; 2156 } 2157 2158 state_desc[s] = norm_freqs[s]; 2159 2160 for (int i = 0; i < norm_freqs[s]; i++) { 2161 // Give `norm_freqs[s]` states to symbol s 2162 dtable->symbols[pos] = s; 2163 // "A position is skipped if already occupied, typically by a "less 2164 // than 1" probability symbol." 2165 do { 2166 pos = (pos + step) & mask; 2167 } while (pos >= 2168 high_threshold); 2169 // Note: no other collision checking is necessary as `step` is 2170 // coprime to `size`, so the cycle will visit each position exactly 2171 // once 2172 } 2173 } 2174 if (pos != 0) { 2175 CORRUPTION(); 2176 } 2177 2178 // Now we can fill baseline and num bits 2179 for (size_t i = 0; i < size; i++) { 2180 u8 symbol = dtable->symbols[i]; 2181 u16 next_state_desc = state_desc[symbol]++; 2182 // Fills in the table appropriately, next_state_desc increases by symbol 2183 // over time, decreasing number of bits 2184 dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc)); 2185 // Baseline increases until the bit threshold is passed, at which point 2186 // it resets to 0 2187 dtable->new_state_base[i] = 2188 ((u16)next_state_desc << dtable->num_bits[i]) - size; 2189 } 2190 } 2191 2192 /// Decode an FSE header as defined in the Zstandard format specification and 2193 /// use the decoded frequencies to initialize a decoding table. 2194 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, 2195 const int max_accuracy_log) { 2196 // "An FSE distribution table describes the probabilities of all symbols 2197 // from 0 to the last present one (included) on a normalized scale of 1 << 2198 // Accuracy_Log . 2199 // 2200 // It's a bitstream which is read forward, in little-endian fashion. It's 2201 // not necessary to know its exact size, since it will be discovered and 2202 // reported by the decoding process. 2203 if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { 2204 ERROR("FSE accuracy too large"); 2205 } 2206 2207 // The bitstream starts by reporting on which scale it operates. 2208 // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal 2209 // and match lengths is 9, and for offsets is 8. Higher values are 2210 // considered errors." 2211 const int accuracy_log = 5 + IO_read_bits(in, 4); 2212 if (accuracy_log > max_accuracy_log) { 2213 ERROR("FSE accuracy too large"); 2214 } 2215 2216 // "Then follows each symbol value, from 0 to last present one. The number 2217 // of bits used by each field is variable. It depends on : 2218 // 2219 // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8, 2220 // and presuming 100 probabilities points have already been distributed, the 2221 // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive). 2222 // Therefore, it must read log2sup(156) == 8 bits. 2223 // 2224 // Value decoded : small values use 1 less bit : example : Presuming values 2225 // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining 2226 // in an 8-bits field. They are used this way : first 99 values (hence from 2227 // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. " 2228 2229 i32 remaining = 1 << accuracy_log; 2230 i16 frequencies[FSE_MAX_SYMBS]; 2231 2232 int symb = 0; 2233 while (remaining > 0 && symb < FSE_MAX_SYMBS) { 2234 // Log of the number of possible values we could read 2235 int bits = highest_set_bit(remaining + 1) + 1; 2236 2237 u16 val = IO_read_bits(in, bits); 2238 2239 // Try to mask out the lower bits to see if it qualifies for the "small 2240 // value" threshold 2241 const u16 lower_mask = ((u16)1 << (bits - 1)) - 1; 2242 const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1); 2243 2244 if ((val & lower_mask) < threshold) { 2245 IO_rewind_bits(in, 1); 2246 val = val & lower_mask; 2247 } else if (val > lower_mask) { 2248 val = val - threshold; 2249 } 2250 2251 // "Probability is obtained from Value decoded by following formula : 2252 // Proba = value - 1" 2253 const i16 proba = (i16)val - 1; 2254 2255 // "It means value 0 becomes negative probability -1. -1 is a special 2256 // probability, which means "less than 1". Its effect on distribution 2257 // table is described in next paragraph. For the purpose of calculating 2258 // cumulated distribution, it counts as one." 2259 remaining -= proba < 0 ? -proba : proba; 2260 2261 frequencies[symb] = proba; 2262 symb++; 2263 2264 // "When a symbol has a probability of zero, it is followed by a 2-bits 2265 // repeat flag. This repeat flag tells how many probabilities of zeroes 2266 // follow the current one. It provides a number ranging from 0 to 3. If 2267 // it is a 3, another 2-bits repeat flag follows, and so on." 2268 if (proba == 0) { 2269 // Read the next two bits to see how many more 0s 2270 int repeat = IO_read_bits(in, 2); 2271 2272 while (1) { 2273 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) { 2274 frequencies[symb++] = 0; 2275 } 2276 if (repeat == 3) { 2277 repeat = IO_read_bits(in, 2); 2278 } else { 2279 break; 2280 } 2281 } 2282 } 2283 } 2284 IO_align_stream(in); 2285 2286 // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding 2287 // is complete. If the last symbol makes cumulated total go above 1 << 2288 // Accuracy_Log, distribution is considered corrupted." 2289 if (remaining != 0 || symb >= FSE_MAX_SYMBS) { 2290 CORRUPTION(); 2291 } 2292 2293 // Initialize the decoding table using the determined weights 2294 FSE_init_dtable(dtable, frequencies, symb, accuracy_log); 2295 } 2296 2297 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) { 2298 dtable->symbols = malloc(sizeof(u8)); 2299 dtable->num_bits = malloc(sizeof(u8)); 2300 dtable->new_state_base = malloc(sizeof(u16)); 2301 2302 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { 2303 BAD_ALLOC(); 2304 } 2305 2306 // This setup will always have a state of 0, always return symbol `symb`, 2307 // and never consume any bits 2308 dtable->symbols[0] = symb; 2309 dtable->num_bits[0] = 0; 2310 dtable->new_state_base[0] = 0; 2311 dtable->accuracy_log = 0; 2312 } 2313 2314 static void FSE_free_dtable(FSE_dtable *const dtable) { 2315 free(dtable->symbols); 2316 free(dtable->num_bits); 2317 free(dtable->new_state_base); 2318 memset(dtable, 0, sizeof(FSE_dtable)); 2319 } 2320 /******* END FSE PRIMITIVES ***************************************************/ 2321