1 /* 2 * Copyright (c) 2017-present, Facebook, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under both the BSD-style license (found in the 6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7 * in the COPYING file in the root directory of this source tree). 8 */ 9 10 /// Zstandard educational decoder implementation 11 /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md 12 13 #include <stdint.h> 14 #include <stdio.h> 15 #include <stdlib.h> 16 #include <string.h> 17 #include "zstd_decompress.h" 18 19 /******* UTILITY MACROS AND TYPES *********************************************/ 20 // Max block size decompressed size is 128 KB and literal blocks can't be 21 // larger than their block 22 #define MAX_LITERALS_SIZE ((size_t)128 * 1024) 23 24 #define MAX(a, b) ((a) > (b) ? (a) : (b)) 25 #define MIN(a, b) ((a) < (b) ? (a) : (b)) 26 27 /// This decoder calls exit(1) when it encounters an error, however a production 28 /// library should propagate error codes 29 #define ERROR(s) \ 30 do { \ 31 fprintf(stderr, "Error: %s\n", s); \ 32 exit(1); \ 33 } while (0) 34 #define INP_SIZE() \ 35 ERROR("Input buffer smaller than it should be or input is " \ 36 "corrupted") 37 #define OUT_SIZE() ERROR("Output buffer too small for output") 38 #define CORRUPTION() ERROR("Corruption detected while decompressing") 39 #define BAD_ALLOC() ERROR("Memory allocation error") 40 #define IMPOSSIBLE() ERROR("An impossibility has occurred") 41 42 typedef uint8_t u8; 43 typedef uint16_t u16; 44 typedef uint32_t u32; 45 typedef uint64_t u64; 46 47 typedef int8_t i8; 48 typedef int16_t i16; 49 typedef int32_t i32; 50 typedef int64_t i64; 51 /******* END UTILITY MACROS AND TYPES *****************************************/ 52 53 /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/ 54 /// The implementations for these functions can be found at the bottom of this 55 /// file. They implement low-level functionality needed for the higher level 56 /// decompression functions. 57 58 /*** IO STREAM OPERATIONS *************/ 59 60 /// ostream_t/istream_t are used to wrap the pointers/length data passed into 61 /// ZSTD_decompress, so that all IO operations are safely bounds checked 62 /// They are written/read forward, and reads are treated as little-endian 63 /// They should be used opaquely to ensure safety 64 typedef struct { 65 u8 *ptr; 66 size_t len; 67 } ostream_t; 68 69 typedef struct { 70 const u8 *ptr; 71 size_t len; 72 73 // Input often reads a few bits at a time, so maintain an internal offset 74 int bit_offset; 75 } istream_t; 76 77 /// The following two functions are the only ones that allow the istream to be 78 /// non-byte aligned 79 80 /// Reads `num` bits from a bitstream, and updates the internal offset 81 static inline u64 IO_read_bits(istream_t *const in, const int num_bits); 82 /// Backs-up the stream by `num` bits so they can be read again 83 static inline void IO_rewind_bits(istream_t *const in, const int num_bits); 84 /// If the remaining bits in a byte will be unused, advance to the end of the 85 /// byte 86 static inline void IO_align_stream(istream_t *const in); 87 88 /// Write the given byte into the output stream 89 static inline void IO_write_byte(ostream_t *const out, u8 symb); 90 91 /// Returns the number of bytes left to be read in this stream. The stream must 92 /// be byte aligned. 93 static inline size_t IO_istream_len(const istream_t *const in); 94 95 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that 96 /// was skipped. The stream must be byte aligned. 97 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len); 98 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that 99 /// was skipped so it can be written to. 100 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len); 101 102 /// Advance the inner state by `len` bytes. The stream must be byte aligned. 103 static inline void IO_advance_input(istream_t *const in, size_t len); 104 105 /// Returns an `ostream_t` constructed from the given pointer and length. 106 static inline ostream_t IO_make_ostream(u8 *out, size_t len); 107 /// Returns an `istream_t` constructed from the given pointer and length. 108 static inline istream_t IO_make_istream(const u8 *in, size_t len); 109 110 /// Returns an `istream_t` with the same base as `in`, and length `len`. 111 /// Then, advance `in` to account for the consumed bytes. 112 /// `in` must be byte aligned. 113 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len); 114 /*** END IO STREAM OPERATIONS *********/ 115 116 /*** BITSTREAM OPERATIONS *************/ 117 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits, 118 /// and return them interpreted as a little-endian unsigned integer. 119 static inline u64 read_bits_LE(const u8 *src, const int num_bits, 120 const size_t offset); 121 122 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so 123 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from 124 /// `src + offset`. If the offset becomes negative, the extra bits at the 125 /// bottom are filled in with `0` bits instead of reading from before `src`. 126 static inline u64 STREAM_read_bits(const u8 *src, const int bits, 127 i64 *const offset); 128 /*** END BITSTREAM OPERATIONS *********/ 129 130 /*** BIT COUNTING OPERATIONS **********/ 131 /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0` 132 static inline int highest_set_bit(const u64 num); 133 /*** END BIT COUNTING OPERATIONS ******/ 134 135 /*** HUFFMAN PRIMITIVES ***************/ 136 // Table decode method uses exponential memory, so we need to limit depth 137 #define HUF_MAX_BITS (16) 138 139 // Limit the maximum number of symbols to 256 so we can store a symbol in a byte 140 #define HUF_MAX_SYMBS (256) 141 142 /// Structure containing all tables necessary for efficient Huffman decoding 143 typedef struct { 144 u8 *symbols; 145 u8 *num_bits; 146 int max_bits; 147 } HUF_dtable; 148 149 /// Decode a single symbol and read in enough bits to refresh the state 150 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, 151 u16 *const state, const u8 *const src, 152 i64 *const offset); 153 /// Read in a full state's worth of bits to initialize it 154 static inline void HUF_init_state(const HUF_dtable *const dtable, 155 u16 *const state, const u8 *const src, 156 i64 *const offset); 157 158 /// Decompresses a single Huffman stream, returns the number of bytes decoded. 159 /// `src_len` must be the exact length of the Huffman-coded block. 160 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, 161 ostream_t *const out, istream_t *const in); 162 /// Same as previous but decodes 4 streams, formatted as in the Zstandard 163 /// specification. 164 /// `src_len` must be the exact length of the Huffman-coded block. 165 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, 166 ostream_t *const out, istream_t *const in); 167 168 /// Initialize a Huffman decoding table using the table of bit counts provided 169 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, 170 const int num_symbs); 171 /// Initialize a Huffman decoding table using the table of weights provided 172 /// Weights follow the definition provided in the Zstandard specification 173 static void HUF_init_dtable_usingweights(HUF_dtable *const table, 174 const u8 *const weights, 175 const int num_symbs); 176 177 /// Free the malloc'ed parts of a decoding table 178 static void HUF_free_dtable(HUF_dtable *const dtable); 179 180 /// Deep copy a decoding table, so that it can be used and free'd without 181 /// impacting the source table. 182 static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src); 183 /*** END HUFFMAN PRIMITIVES ***********/ 184 185 /*** FSE PRIMITIVES *******************/ 186 /// For more description of FSE see 187 /// https://github.com/Cyan4973/FiniteStateEntropy/ 188 189 // FSE table decoding uses exponential memory, so limit the maximum accuracy 190 #define FSE_MAX_ACCURACY_LOG (15) 191 // Limit the maximum number of symbols so they can be stored in a single byte 192 #define FSE_MAX_SYMBS (256) 193 194 /// The tables needed to decode FSE encoded streams 195 typedef struct { 196 u8 *symbols; 197 u8 *num_bits; 198 u16 *new_state_base; 199 int accuracy_log; 200 } FSE_dtable; 201 202 /// Return the symbol for the current state 203 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, 204 const u16 state); 205 /// Read the number of bits necessary to update state, update, and shift offset 206 /// back to reflect the bits read 207 static inline void FSE_update_state(const FSE_dtable *const dtable, 208 u16 *const state, const u8 *const src, 209 i64 *const offset); 210 211 /// Combine peek and update: decode a symbol and update the state 212 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, 213 u16 *const state, const u8 *const src, 214 i64 *const offset); 215 216 /// Read bits from the stream to initialize the state and shift offset back 217 static inline void FSE_init_state(const FSE_dtable *const dtable, 218 u16 *const state, const u8 *const src, 219 i64 *const offset); 220 221 /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights) 222 /// using an FSE decoding table. `src_len` must be the exact length of the 223 /// block. 224 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, 225 ostream_t *const out, 226 istream_t *const in); 227 228 /// Initialize a decoding table using normalized frequencies. 229 static void FSE_init_dtable(FSE_dtable *const dtable, 230 const i16 *const norm_freqs, const int num_symbs, 231 const int accuracy_log); 232 233 /// Decode an FSE header as defined in the Zstandard format specification and 234 /// use the decoded frequencies to initialize a decoding table. 235 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, 236 const int max_accuracy_log); 237 238 /// Initialize an FSE table that will always return the same symbol and consume 239 /// 0 bits per symbol, to be used for RLE mode in sequence commands 240 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb); 241 242 /// Free the malloc'ed parts of a decoding table 243 static void FSE_free_dtable(FSE_dtable *const dtable); 244 245 /// Deep copy a decoding table, so that it can be used and free'd without 246 /// impacting the source table. 247 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src); 248 /*** END FSE PRIMITIVES ***************/ 249 250 /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/ 251 252 /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/ 253 254 /// A small structure that can be reused in various places that need to access 255 /// frame header information 256 typedef struct { 257 // The size of window that we need to be able to contiguously store for 258 // references 259 size_t window_size; 260 // The total output size of this compressed frame 261 size_t frame_content_size; 262 263 // The dictionary id if this frame uses one 264 u32 dictionary_id; 265 266 // Whether or not the content of this frame has a checksum 267 int content_checksum_flag; 268 // Whether or not the output for this frame is in a single segment 269 int single_segment_flag; 270 } frame_header_t; 271 272 /// The context needed to decode blocks in a frame 273 typedef struct { 274 frame_header_t header; 275 276 // The total amount of data available for backreferences, to determine if an 277 // offset too large to be correct 278 size_t current_total_output; 279 280 const u8 *dict_content; 281 size_t dict_content_len; 282 283 // Entropy encoding tables so they can be repeated by future blocks instead 284 // of retransmitting 285 HUF_dtable literals_dtable; 286 FSE_dtable ll_dtable; 287 FSE_dtable ml_dtable; 288 FSE_dtable of_dtable; 289 290 // The last 3 offsets for the special "repeat offsets". 291 u64 previous_offsets[3]; 292 } frame_context_t; 293 294 /// The decoded contents of a dictionary so that it doesn't have to be repeated 295 /// for each frame that uses it 296 struct dictionary_s { 297 // Entropy tables 298 HUF_dtable literals_dtable; 299 FSE_dtable ll_dtable; 300 FSE_dtable ml_dtable; 301 FSE_dtable of_dtable; 302 303 // Raw content for backreferences 304 u8 *content; 305 size_t content_size; 306 307 // Offset history to prepopulate the frame's history 308 u64 previous_offsets[3]; 309 310 u32 dictionary_id; 311 }; 312 313 /// A tuple containing the parts necessary to decode and execute a ZSTD sequence 314 /// command 315 typedef struct { 316 u32 literal_length; 317 u32 match_length; 318 u32 offset; 319 } sequence_command_t; 320 321 /// The decoder works top-down, starting at the high level like Zstd frames, and 322 /// working down to lower more technical levels such as blocks, literals, and 323 /// sequences. The high-level functions roughly follow the outline of the 324 /// format specification: 325 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md 326 327 /// Before the implementation of each high-level function declared here, the 328 /// prototypes for their helper functions are defined and explained 329 330 /// Decode a single Zstd frame, or error if the input is not a valid frame. 331 /// Accepts a dict argument, which may be NULL indicating no dictionary. 332 /// See 333 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation 334 static void decode_frame(ostream_t *const out, istream_t *const in, 335 const dictionary_t *const dict); 336 337 // Decode data in a compressed block 338 static void decompress_block(frame_context_t *const ctx, ostream_t *const out, 339 istream_t *const in); 340 341 // Decode the literals section of a block 342 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, 343 u8 **const literals); 344 345 // Decode the sequences part of a block 346 static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in, 347 sequence_command_t **const sequences); 348 349 // Execute the decoded sequences on the literals block 350 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, 351 const u8 *const literals, 352 const size_t literals_len, 353 const sequence_command_t *const sequences, 354 const size_t num_sequences); 355 356 // Copies literals and returns the total literal length that was copied 357 static u32 copy_literals(const size_t seq, istream_t *litstream, 358 ostream_t *const out); 359 360 // Given an offset code from a sequence command (either an actual offset value 361 // or an index for previous offset), computes the correct offset and updates 362 // the offset history 363 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist); 364 365 // Given an offset, match length, and total output, as well as the frame 366 // context for the dictionary, determines if the dictionary is used and 367 // executes the copy operation 368 static void execute_match_copy(frame_context_t *const ctx, size_t offset, 369 size_t match_length, size_t total_output, 370 ostream_t *const out); 371 372 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ 373 374 size_t ZSTD_decompress(void *const dst, const size_t dst_len, 375 const void *const src, const size_t src_len) { 376 dictionary_t* uninit_dict = create_dictionary(); 377 size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src, 378 src_len, uninit_dict); 379 free_dictionary(uninit_dict); 380 return decomp_size; 381 } 382 383 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, 384 const void *const src, const size_t src_len, 385 dictionary_t* parsed_dict) { 386 387 istream_t in = IO_make_istream(src, src_len); 388 ostream_t out = IO_make_ostream(dst, dst_len); 389 390 // "A content compressed by Zstandard is transformed into a Zstandard frame. 391 // Multiple frames can be appended into a single file or stream. A frame is 392 // totally independent, has a defined beginning and end, and a set of 393 // parameters which tells the decoder how to decompress it." 394 395 /* this decoder assumes decompression of a single frame */ 396 decode_frame(&out, &in, parsed_dict); 397 398 return out.ptr - (u8 *)dst; 399 } 400 401 /******* FRAME DECODING ******************************************************/ 402 403 static void decode_data_frame(ostream_t *const out, istream_t *const in, 404 const dictionary_t *const dict); 405 static void init_frame_context(frame_context_t *const context, 406 istream_t *const in, 407 const dictionary_t *const dict); 408 static void free_frame_context(frame_context_t *const context); 409 static void parse_frame_header(frame_header_t *const header, 410 istream_t *const in); 411 static void frame_context_apply_dict(frame_context_t *const ctx, 412 const dictionary_t *const dict); 413 414 static void decompress_data(frame_context_t *const ctx, ostream_t *const out, 415 istream_t *const in); 416 417 static void decode_frame(ostream_t *const out, istream_t *const in, 418 const dictionary_t *const dict) { 419 const u32 magic_number = IO_read_bits(in, 32); 420 // Zstandard frame 421 // 422 // "Magic_Number 423 // 424 // 4 Bytes, little-endian format. Value : 0xFD2FB528" 425 if (magic_number == 0xFD2FB528U) { 426 // ZSTD frame 427 decode_data_frame(out, in, dict); 428 429 return; 430 } 431 432 // not a real frame or a skippable frame 433 ERROR("Tried to decode non-ZSTD frame"); 434 } 435 436 /// Decode a frame that contains compressed data. Not all frames do as there 437 /// are skippable frames. 438 /// See 439 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format 440 static void decode_data_frame(ostream_t *const out, istream_t *const in, 441 const dictionary_t *const dict) { 442 frame_context_t ctx; 443 444 // Initialize the context that needs to be carried from block to block 445 init_frame_context(&ctx, in, dict); 446 447 if (ctx.header.frame_content_size != 0 && 448 ctx.header.frame_content_size > out->len) { 449 OUT_SIZE(); 450 } 451 452 decompress_data(&ctx, out, in); 453 454 free_frame_context(&ctx); 455 } 456 457 /// Takes the information provided in the header and dictionary, and initializes 458 /// the context for this frame 459 static void init_frame_context(frame_context_t *const context, 460 istream_t *const in, 461 const dictionary_t *const dict) { 462 // Most fields in context are correct when initialized to 0 463 memset(context, 0, sizeof(frame_context_t)); 464 465 // Parse data from the frame header 466 parse_frame_header(&context->header, in); 467 468 // Set up the offset history for the repeat offset commands 469 context->previous_offsets[0] = 1; 470 context->previous_offsets[1] = 4; 471 context->previous_offsets[2] = 8; 472 473 // Apply details from the dict if it exists 474 frame_context_apply_dict(context, dict); 475 } 476 477 static void free_frame_context(frame_context_t *const context) { 478 HUF_free_dtable(&context->literals_dtable); 479 480 FSE_free_dtable(&context->ll_dtable); 481 FSE_free_dtable(&context->ml_dtable); 482 FSE_free_dtable(&context->of_dtable); 483 484 memset(context, 0, sizeof(frame_context_t)); 485 } 486 487 static void parse_frame_header(frame_header_t *const header, 488 istream_t *const in) { 489 // "The first header's byte is called the Frame_Header_Descriptor. It tells 490 // which other fields are present. Decoding this byte is enough to tell the 491 // size of Frame_Header. 492 // 493 // Bit number Field name 494 // 7-6 Frame_Content_Size_flag 495 // 5 Single_Segment_flag 496 // 4 Unused_bit 497 // 3 Reserved_bit 498 // 2 Content_Checksum_flag 499 // 1-0 Dictionary_ID_flag" 500 const u8 descriptor = IO_read_bits(in, 8); 501 502 // decode frame header descriptor into flags 503 const u8 frame_content_size_flag = descriptor >> 6; 504 const u8 single_segment_flag = (descriptor >> 5) & 1; 505 const u8 reserved_bit = (descriptor >> 3) & 1; 506 const u8 content_checksum_flag = (descriptor >> 2) & 1; 507 const u8 dictionary_id_flag = descriptor & 3; 508 509 if (reserved_bit != 0) { 510 CORRUPTION(); 511 } 512 513 header->single_segment_flag = single_segment_flag; 514 header->content_checksum_flag = content_checksum_flag; 515 516 // decode window size 517 if (!single_segment_flag) { 518 // "Provides guarantees on maximum back-reference distance that will be 519 // used within compressed data. This information is important for 520 // decoders to allocate enough memory. 521 // 522 // Bit numbers 7-3 2-0 523 // Field name Exponent Mantissa" 524 u8 window_descriptor = IO_read_bits(in, 8); 525 u8 exponent = window_descriptor >> 3; 526 u8 mantissa = window_descriptor & 7; 527 528 // Use the algorithm from the specification to compute window size 529 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor 530 size_t window_base = (size_t)1 << (10 + exponent); 531 size_t window_add = (window_base / 8) * mantissa; 532 header->window_size = window_base + window_add; 533 } 534 535 // decode dictionary id if it exists 536 if (dictionary_id_flag) { 537 // "This is a variable size field, which contains the ID of the 538 // dictionary required to properly decode the frame. Note that this 539 // field is optional. When it's not present, it's up to the caller to 540 // make sure it uses the correct dictionary. Format is little-endian." 541 const int bytes_array[] = {0, 1, 2, 4}; 542 const int bytes = bytes_array[dictionary_id_flag]; 543 544 header->dictionary_id = IO_read_bits(in, bytes * 8); 545 } else { 546 header->dictionary_id = 0; 547 } 548 549 // decode frame content size if it exists 550 if (single_segment_flag || frame_content_size_flag) { 551 // "This is the original (uncompressed) size. This information is 552 // optional. The Field_Size is provided according to value of 553 // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not 554 // present), 1, 2, 4 or 8 bytes. Format is little-endian." 555 // 556 // if frame_content_size_flag == 0 but single_segment_flag is set, we 557 // still have a 1 byte field 558 const int bytes_array[] = {1, 2, 4, 8}; 559 const int bytes = bytes_array[frame_content_size_flag]; 560 561 header->frame_content_size = IO_read_bits(in, bytes * 8); 562 if (bytes == 2) { 563 // "When Field_Size is 2, the offset of 256 is added." 564 header->frame_content_size += 256; 565 } 566 } else { 567 header->frame_content_size = 0; 568 } 569 570 if (single_segment_flag) { 571 // "The Window_Descriptor byte is optional. It is absent when 572 // Single_Segment_flag is set. In this case, the maximum back-reference 573 // distance is the content size itself, which can be any value from 1 to 574 // 2^64-1 bytes (16 EB)." 575 header->window_size = header->frame_content_size; 576 } 577 } 578 579 /// A dictionary acts as initializing values for the frame context before 580 /// decompression, so we implement it by applying it's predetermined 581 /// tables and content to the context before beginning decompression 582 static void frame_context_apply_dict(frame_context_t *const ctx, 583 const dictionary_t *const dict) { 584 // If the content pointer is NULL then it must be an empty dict 585 if (!dict || !dict->content) 586 return; 587 588 // If the requested dictionary_id is non-zero, the correct dictionary must 589 // be present 590 if (ctx->header.dictionary_id != 0 && 591 ctx->header.dictionary_id != dict->dictionary_id) { 592 ERROR("Wrong dictionary provided"); 593 } 594 595 // Copy the dict content to the context for references during sequence 596 // execution 597 ctx->dict_content = dict->content; 598 ctx->dict_content_len = dict->content_size; 599 600 // If it's a formatted dict copy the precomputed tables in so they can 601 // be used in the table repeat modes 602 if (dict->dictionary_id != 0) { 603 // Deep copy the entropy tables so they can be freed independently of 604 // the dictionary struct 605 HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable); 606 FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable); 607 FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable); 608 FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable); 609 610 // Copy the repeated offsets 611 memcpy(ctx->previous_offsets, dict->previous_offsets, 612 sizeof(ctx->previous_offsets)); 613 } 614 } 615 616 /// Decompress the data from a frame block by block 617 static void decompress_data(frame_context_t *const ctx, ostream_t *const out, 618 istream_t *const in) { 619 // "A frame encapsulates one or multiple blocks. Each block can be 620 // compressed or not, and has a guaranteed maximum content size, which 621 // depends on frame parameters. Unlike frames, each block depends on 622 // previous blocks for proper decoding. However, each block can be 623 // decompressed without waiting for its successor, allowing streaming 624 // operations." 625 int last_block = 0; 626 do { 627 // "Last_Block 628 // 629 // The lowest bit signals if this block is the last one. Frame ends 630 // right after this block. 631 // 632 // Block_Type and Block_Size 633 // 634 // The next 2 bits represent the Block_Type, while the remaining 21 bits 635 // represent the Block_Size. Format is little-endian." 636 last_block = IO_read_bits(in, 1); 637 const int block_type = IO_read_bits(in, 2); 638 const size_t block_len = IO_read_bits(in, 21); 639 640 switch (block_type) { 641 case 0: { 642 // "Raw_Block - this is an uncompressed block. Block_Size is the 643 // number of bytes to read and copy." 644 const u8 *const read_ptr = IO_get_read_ptr(in, block_len); 645 u8 *const write_ptr = IO_get_write_ptr(out, block_len); 646 647 // Copy the raw data into the output 648 memcpy(write_ptr, read_ptr, block_len); 649 650 ctx->current_total_output += block_len; 651 break; 652 } 653 case 1: { 654 // "RLE_Block - this is a single byte, repeated N times. In which 655 // case, Block_Size is the size to regenerate, while the 656 // "compressed" block is just 1 byte (the byte to repeat)." 657 const u8 *const read_ptr = IO_get_read_ptr(in, 1); 658 u8 *const write_ptr = IO_get_write_ptr(out, block_len); 659 660 // Copy `block_len` copies of `read_ptr[0]` to the output 661 memset(write_ptr, read_ptr[0], block_len); 662 663 ctx->current_total_output += block_len; 664 break; 665 } 666 case 2: { 667 // "Compressed_Block - this is a Zstandard compressed block, 668 // detailed in another section of this specification. Block_Size is 669 // the compressed size. 670 671 // Create a sub-stream for the block 672 istream_t block_stream = IO_make_sub_istream(in, block_len); 673 decompress_block(ctx, out, &block_stream); 674 break; 675 } 676 case 3: 677 // "Reserved - this is not a block. This value cannot be used with 678 // current version of this specification." 679 CORRUPTION(); 680 break; 681 default: 682 IMPOSSIBLE(); 683 } 684 } while (!last_block); 685 686 if (ctx->header.content_checksum_flag) { 687 // This program does not support checking the checksum, so skip over it 688 // if it's present 689 IO_advance_input(in, 4); 690 } 691 } 692 /******* END FRAME DECODING ***************************************************/ 693 694 /******* BLOCK DECOMPRESSION **************************************************/ 695 static void decompress_block(frame_context_t *const ctx, ostream_t *const out, 696 istream_t *const in) { 697 // "A compressed block consists of 2 sections : 698 // 699 // Literals_Section 700 // Sequences_Section" 701 702 703 // Part 1: decode the literals block 704 u8 *literals = NULL; 705 const size_t literals_size = decode_literals(ctx, in, &literals); 706 707 // Part 2: decode the sequences block 708 sequence_command_t *sequences = NULL; 709 const size_t num_sequences = 710 decode_sequences(ctx, in, &sequences); 711 712 // Part 3: combine literals and sequence commands to generate output 713 execute_sequences(ctx, out, literals, literals_size, sequences, 714 num_sequences); 715 free(literals); 716 free(sequences); 717 } 718 /******* END BLOCK DECOMPRESSION **********************************************/ 719 720 /******* LITERALS DECODING ****************************************************/ 721 static size_t decode_literals_simple(istream_t *const in, u8 **const literals, 722 const int block_type, 723 const int size_format); 724 static size_t decode_literals_compressed(frame_context_t *const ctx, 725 istream_t *const in, 726 u8 **const literals, 727 const int block_type, 728 const int size_format); 729 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in); 730 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, 731 int *const num_symbs); 732 733 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, 734 u8 **const literals) { 735 // "Literals can be stored uncompressed or compressed using Huffman prefix 736 // codes. When compressed, an optional tree description can be present, 737 // followed by 1 or 4 streams." 738 // 739 // "Literals_Section_Header 740 // 741 // Header is in charge of describing how literals are packed. It's a 742 // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using 743 // little-endian convention." 744 // 745 // "Literals_Block_Type 746 // 747 // This field uses 2 lowest bits of first byte, describing 4 different block 748 // types" 749 // 750 // size_format takes between 1 and 2 bits 751 int block_type = IO_read_bits(in, 2); 752 int size_format = IO_read_bits(in, 2); 753 754 if (block_type <= 1) { 755 // Raw or RLE literals block 756 return decode_literals_simple(in, literals, block_type, 757 size_format); 758 } else { 759 // Huffman compressed literals 760 return decode_literals_compressed(ctx, in, literals, block_type, 761 size_format); 762 } 763 } 764 765 /// Decodes literals blocks in raw or RLE form 766 static size_t decode_literals_simple(istream_t *const in, u8 **const literals, 767 const int block_type, 768 const int size_format) { 769 size_t size; 770 switch (size_format) { 771 // These cases are in the form ?0 772 // In this case, the ? bit is actually part of the size field 773 case 0: 774 case 2: 775 // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)." 776 IO_rewind_bits(in, 1); 777 size = IO_read_bits(in, 5); 778 break; 779 case 1: 780 // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)." 781 size = IO_read_bits(in, 12); 782 break; 783 case 3: 784 // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)." 785 size = IO_read_bits(in, 20); 786 break; 787 default: 788 // Size format is in range 0-3 789 IMPOSSIBLE(); 790 } 791 792 if (size > MAX_LITERALS_SIZE) { 793 CORRUPTION(); 794 } 795 796 *literals = malloc(size); 797 if (!*literals) { 798 BAD_ALLOC(); 799 } 800 801 switch (block_type) { 802 case 0: { 803 // "Raw_Literals_Block - Literals are stored uncompressed." 804 const u8 *const read_ptr = IO_get_read_ptr(in, size); 805 memcpy(*literals, read_ptr, size); 806 break; 807 } 808 case 1: { 809 // "RLE_Literals_Block - Literals consist of a single byte value repeated N times." 810 const u8 *const read_ptr = IO_get_read_ptr(in, 1); 811 memset(*literals, read_ptr[0], size); 812 break; 813 } 814 default: 815 IMPOSSIBLE(); 816 } 817 818 return size; 819 } 820 821 /// Decodes Huffman compressed literals 822 static size_t decode_literals_compressed(frame_context_t *const ctx, 823 istream_t *const in, 824 u8 **const literals, 825 const int block_type, 826 const int size_format) { 827 size_t regenerated_size, compressed_size; 828 // Only size_format=0 has 1 stream, so default to 4 829 int num_streams = 4; 830 switch (size_format) { 831 case 0: 832 // "A single stream. Both Compressed_Size and Regenerated_Size use 10 833 // bits (0-1023)." 834 num_streams = 1; 835 // Fall through as it has the same size format 836 case 1: 837 // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits 838 // (0-1023)." 839 regenerated_size = IO_read_bits(in, 10); 840 compressed_size = IO_read_bits(in, 10); 841 break; 842 case 2: 843 // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits 844 // (0-16383)." 845 regenerated_size = IO_read_bits(in, 14); 846 compressed_size = IO_read_bits(in, 14); 847 break; 848 case 3: 849 // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits 850 // (0-262143)." 851 regenerated_size = IO_read_bits(in, 18); 852 compressed_size = IO_read_bits(in, 18); 853 break; 854 default: 855 // Impossible 856 IMPOSSIBLE(); 857 } 858 if (regenerated_size > MAX_LITERALS_SIZE || 859 compressed_size >= regenerated_size) { 860 CORRUPTION(); 861 } 862 863 *literals = malloc(regenerated_size); 864 if (!*literals) { 865 BAD_ALLOC(); 866 } 867 868 ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size); 869 istream_t huf_stream = IO_make_sub_istream(in, compressed_size); 870 871 if (block_type == 2) { 872 // Decode the provided Huffman table 873 // "This section is only present when Literals_Block_Type type is 874 // Compressed_Literals_Block (2)." 875 876 HUF_free_dtable(&ctx->literals_dtable); 877 decode_huf_table(&ctx->literals_dtable, &huf_stream); 878 } else { 879 // If the previous Huffman table is being repeated, ensure it exists 880 if (!ctx->literals_dtable.symbols) { 881 CORRUPTION(); 882 } 883 } 884 885 size_t symbols_decoded; 886 if (num_streams == 1) { 887 symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream); 888 } else { 889 symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream); 890 } 891 892 if (symbols_decoded != regenerated_size) { 893 CORRUPTION(); 894 } 895 896 return regenerated_size; 897 } 898 899 // Decode the Huffman table description 900 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) { 901 // "All literal values from zero (included) to last present one (excluded) 902 // are represented by Weight with values from 0 to Max_Number_of_Bits." 903 904 // "This is a single byte value (0-255), which describes how to decode the list of weights." 905 const u8 header = IO_read_bits(in, 8); 906 907 u8 weights[HUF_MAX_SYMBS]; 908 memset(weights, 0, sizeof(weights)); 909 910 int num_symbs; 911 912 if (header >= 128) { 913 // "This is a direct representation, where each Weight is written 914 // directly as a 4 bits field (0-15). The full representation occupies 915 // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte 916 // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte - 917 // 127" 918 num_symbs = header - 127; 919 const size_t bytes = (num_symbs + 1) / 2; 920 921 const u8 *const weight_src = IO_get_read_ptr(in, bytes); 922 923 for (int i = 0; i < num_symbs; i++) { 924 // "They are encoded forward, 2 925 // weights to a byte with the first weight taking the top four bits 926 // and the second taking the bottom four (e.g. the following 927 // operations could be used to read the weights: Weight[0] = 928 // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)." 929 if (i % 2 == 0) { 930 weights[i] = weight_src[i / 2] >> 4; 931 } else { 932 weights[i] = weight_src[i / 2] & 0xf; 933 } 934 } 935 } else { 936 // The weights are FSE encoded, decode them before we can construct the 937 // table 938 istream_t fse_stream = IO_make_sub_istream(in, header); 939 ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS); 940 fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs); 941 } 942 943 // Construct the table using the decoded weights 944 HUF_init_dtable_usingweights(dtable, weights, num_symbs); 945 } 946 947 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, 948 int *const num_symbs) { 949 const int MAX_ACCURACY_LOG = 7; 950 951 FSE_dtable dtable; 952 953 // "An FSE bitstream starts by a header, describing probabilities 954 // distribution. It will create a Decoding Table. For a list of Huffman 955 // weights, maximum accuracy is 7 bits." 956 FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG); 957 958 // Decode the weights 959 *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in); 960 961 FSE_free_dtable(&dtable); 962 } 963 /******* END LITERALS DECODING ************************************************/ 964 965 /******* SEQUENCE DECODING ****************************************************/ 966 /// The combination of FSE states needed to decode sequences 967 typedef struct { 968 FSE_dtable ll_table; 969 FSE_dtable of_table; 970 FSE_dtable ml_table; 971 972 u16 ll_state; 973 u16 of_state; 974 u16 ml_state; 975 } sequence_states_t; 976 977 /// Different modes to signal to decode_seq_tables what to do 978 typedef enum { 979 seq_literal_length = 0, 980 seq_offset = 1, 981 seq_match_length = 2, 982 } seq_part_t; 983 984 typedef enum { 985 seq_predefined = 0, 986 seq_rle = 1, 987 seq_fse = 2, 988 seq_repeat = 3, 989 } seq_mode_t; 990 991 /// The predefined FSE distribution tables for `seq_predefined` mode 992 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = { 993 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 994 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1}; 995 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = { 996 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 997 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1}; 998 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = { 999 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1000 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1001 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1}; 1002 1003 /// The sequence decoding baseline and number of additional bits to read/add 1004 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets 1005 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = { 1006 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1007 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 1008 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65538}; 1009 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = { 1010 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1011 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 1012 1013 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = { 1014 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1015 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 1016 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 1017 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539}; 1018 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { 1019 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1020 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1021 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; 1022 1023 /// Offset decoding is simpler so we just need a maximum code value 1024 static const u8 SEQ_MAX_CODES[3] = {35, -1, 52}; 1025 1026 static void decompress_sequences(frame_context_t *const ctx, 1027 istream_t *const in, 1028 sequence_command_t *const sequences, 1029 const size_t num_sequences); 1030 static sequence_command_t decode_sequence(sequence_states_t *const state, 1031 const u8 *const src, 1032 i64 *const offset); 1033 static void decode_seq_table(FSE_dtable *const table, istream_t *const in, 1034 const seq_part_t type, const seq_mode_t mode); 1035 1036 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in, 1037 sequence_command_t **const sequences) { 1038 // "A compressed block is a succession of sequences . A sequence is a 1039 // literal copy command, followed by a match copy command. A literal copy 1040 // command specifies a length. It is the number of bytes to be copied (or 1041 // extracted) from the literal section. A match copy command specifies an 1042 // offset and a length. The offset gives the position to copy from, which 1043 // can be within a previous block." 1044 1045 size_t num_sequences; 1046 1047 // "Number_of_Sequences 1048 // 1049 // This is a variable size field using between 1 and 3 bytes. Let's call its 1050 // first byte byte0." 1051 u8 header = IO_read_bits(in, 8); 1052 if (header == 0) { 1053 // "There are no sequences. The sequence section stops there. 1054 // Regenerated content is defined entirely by literals section." 1055 *sequences = NULL; 1056 return 0; 1057 } else if (header < 128) { 1058 // "Number_of_Sequences = byte0 . Uses 1 byte." 1059 num_sequences = header; 1060 } else if (header < 255) { 1061 // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes." 1062 num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8); 1063 } else { 1064 // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes." 1065 num_sequences = IO_read_bits(in, 16) + 0x7F00; 1066 } 1067 1068 *sequences = malloc(num_sequences * sizeof(sequence_command_t)); 1069 if (!*sequences) { 1070 BAD_ALLOC(); 1071 } 1072 1073 decompress_sequences(ctx, in, *sequences, num_sequences); 1074 return num_sequences; 1075 } 1076 1077 /// Decompress the FSE encoded sequence commands 1078 static void decompress_sequences(frame_context_t *const ctx, istream_t *in, 1079 sequence_command_t *const sequences, 1080 const size_t num_sequences) { 1081 // "The Sequences_Section regroup all symbols required to decode commands. 1082 // There are 3 symbol types : literals lengths, offsets and match lengths. 1083 // They are encoded together, interleaved, in a single bitstream." 1084 1085 // "Symbol compression modes 1086 // 1087 // This is a single byte, defining the compression mode of each symbol 1088 // type." 1089 // 1090 // Bit number : Field name 1091 // 7-6 : Literals_Lengths_Mode 1092 // 5-4 : Offsets_Mode 1093 // 3-2 : Match_Lengths_Mode 1094 // 1-0 : Reserved 1095 u8 compression_modes = IO_read_bits(in, 8); 1096 1097 if ((compression_modes & 3) != 0) { 1098 // Reserved bits set 1099 CORRUPTION(); 1100 } 1101 1102 // "Following the header, up to 3 distribution tables can be described. When 1103 // present, they are in this order : 1104 // 1105 // Literals lengths 1106 // Offsets 1107 // Match Lengths" 1108 // Update the tables we have stored in the context 1109 decode_seq_table(&ctx->ll_dtable, in, seq_literal_length, 1110 (compression_modes >> 6) & 3); 1111 1112 decode_seq_table(&ctx->of_dtable, in, seq_offset, 1113 (compression_modes >> 4) & 3); 1114 1115 decode_seq_table(&ctx->ml_dtable, in, seq_match_length, 1116 (compression_modes >> 2) & 3); 1117 1118 1119 sequence_states_t states; 1120 1121 // Initialize the decoding tables 1122 { 1123 states.ll_table = ctx->ll_dtable; 1124 states.of_table = ctx->of_dtable; 1125 states.ml_table = ctx->ml_dtable; 1126 } 1127 1128 const size_t len = IO_istream_len(in); 1129 const u8 *const src = IO_get_read_ptr(in, len); 1130 1131 // "After writing the last bit containing information, the compressor writes 1132 // a single 1-bit and then fills the byte with 0-7 0 bits of padding." 1133 const int padding = 8 - highest_set_bit(src[len - 1]); 1134 // The offset starts at the end because FSE streams are read backwards 1135 i64 bit_offset = len * 8 - padding; 1136 1137 // "The bitstream starts with initial state values, each using the required 1138 // number of bits in their respective accuracy, decoded previously from 1139 // their normalized distribution. 1140 // 1141 // It starts by Literals_Length_State, followed by Offset_State, and finally 1142 // Match_Length_State." 1143 FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset); 1144 FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset); 1145 FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset); 1146 1147 for (size_t i = 0; i < num_sequences; i++) { 1148 // Decode sequences one by one 1149 sequences[i] = decode_sequence(&states, src, &bit_offset); 1150 } 1151 1152 if (bit_offset != 0) { 1153 CORRUPTION(); 1154 } 1155 } 1156 1157 // Decode a single sequence and update the state 1158 static sequence_command_t decode_sequence(sequence_states_t *const states, 1159 const u8 *const src, 1160 i64 *const offset) { 1161 // "Each symbol is a code in its own context, which specifies Baseline and 1162 // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw 1163 // additional bits in the same bitstream." 1164 1165 // Decode symbols, but don't update states 1166 const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state); 1167 const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state); 1168 const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state); 1169 1170 // Offset doesn't need a max value as it's not decoded using a table 1171 if (ll_code > SEQ_MAX_CODES[seq_literal_length] || 1172 ml_code > SEQ_MAX_CODES[seq_match_length]) { 1173 CORRUPTION(); 1174 } 1175 1176 // Read the interleaved bits 1177 sequence_command_t seq; 1178 // "Decoding starts by reading the Number_of_Bits required to decode Offset. 1179 // It then does the same for Match_Length, and then for Literals_Length." 1180 seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset); 1181 1182 seq.match_length = 1183 SEQ_MATCH_LENGTH_BASELINES[ml_code] + 1184 STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset); 1185 1186 seq.literal_length = 1187 SEQ_LITERAL_LENGTH_BASELINES[ll_code] + 1188 STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset); 1189 1190 // "If it is not the last sequence in the block, the next operation is to 1191 // update states. Using the rules pre-calculated in the decoding tables, 1192 // Literals_Length_State is updated, followed by Match_Length_State, and 1193 // then Offset_State." 1194 // If the stream is complete don't read bits to update state 1195 if (*offset != 0) { 1196 FSE_update_state(&states->ll_table, &states->ll_state, src, offset); 1197 FSE_update_state(&states->ml_table, &states->ml_state, src, offset); 1198 FSE_update_state(&states->of_table, &states->of_state, src, offset); 1199 } 1200 1201 return seq; 1202 } 1203 1204 /// Given a sequence part and table mode, decode the FSE distribution 1205 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table` 1206 static void decode_seq_table(FSE_dtable *const table, istream_t *const in, 1207 const seq_part_t type, const seq_mode_t mode) { 1208 // Constant arrays indexed by seq_part_t 1209 const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST, 1210 SEQ_OFFSET_DEFAULT_DIST, 1211 SEQ_MATCH_LENGTH_DEFAULT_DIST}; 1212 const size_t default_distribution_lengths[] = {36, 29, 53}; 1213 const size_t default_distribution_accuracies[] = {6, 5, 6}; 1214 1215 const size_t max_accuracies[] = {9, 8, 9}; 1216 1217 if (mode != seq_repeat) { 1218 // Free old one before overwriting 1219 FSE_free_dtable(table); 1220 } 1221 1222 switch (mode) { 1223 case seq_predefined: { 1224 // "Predefined_Mode : uses a predefined distribution table." 1225 const i16 *distribution = default_distributions[type]; 1226 const size_t symbs = default_distribution_lengths[type]; 1227 const size_t accuracy_log = default_distribution_accuracies[type]; 1228 1229 FSE_init_dtable(table, distribution, symbs, accuracy_log); 1230 break; 1231 } 1232 case seq_rle: { 1233 // "RLE_Mode : it's a single code, repeated Number_of_Sequences times." 1234 const u8 symb = IO_get_read_ptr(in, 1)[0]; 1235 FSE_init_dtable_rle(table, symb); 1236 break; 1237 } 1238 case seq_fse: { 1239 // "FSE_Compressed_Mode : standard FSE compression. A distribution table 1240 // will be present " 1241 FSE_decode_header(table, in, max_accuracies[type]); 1242 break; 1243 } 1244 case seq_repeat: 1245 // "Repeat_Mode : re-use distribution table from previous compressed 1246 // block." 1247 // Nothing to do here, table will be unchanged 1248 if (!table->symbols) { 1249 // This mode is invalid if we don't already have a table 1250 CORRUPTION(); 1251 } 1252 break; 1253 default: 1254 // Impossible, as mode is from 0-3 1255 IMPOSSIBLE(); 1256 break; 1257 } 1258 1259 } 1260 /******* END SEQUENCE DECODING ************************************************/ 1261 1262 /******* SEQUENCE EXECUTION ***************************************************/ 1263 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, 1264 const u8 *const literals, 1265 const size_t literals_len, 1266 const sequence_command_t *const sequences, 1267 const size_t num_sequences) { 1268 istream_t litstream = IO_make_istream(literals, literals_len); 1269 1270 u64 *const offset_hist = ctx->previous_offsets; 1271 size_t total_output = ctx->current_total_output; 1272 1273 for (size_t i = 0; i < num_sequences; i++) { 1274 const sequence_command_t seq = sequences[i]; 1275 { 1276 const u32 literals_size = copy_literals(seq.literal_length, &litstream, out); 1277 total_output += literals_size; 1278 } 1279 1280 size_t const offset = compute_offset(seq, offset_hist); 1281 1282 size_t const match_length = seq.match_length; 1283 1284 execute_match_copy(ctx, offset, match_length, total_output, out); 1285 1286 total_output += match_length; 1287 } 1288 1289 // Copy any leftover literals 1290 { 1291 size_t len = IO_istream_len(&litstream); 1292 copy_literals(len, &litstream, out); 1293 total_output += len; 1294 } 1295 1296 ctx->current_total_output = total_output; 1297 } 1298 1299 static u32 copy_literals(const size_t literal_length, istream_t *litstream, 1300 ostream_t *const out) { 1301 // If the sequence asks for more literals than are left, the 1302 // sequence must be corrupted 1303 if (literal_length > IO_istream_len(litstream)) { 1304 CORRUPTION(); 1305 } 1306 1307 u8 *const write_ptr = IO_get_write_ptr(out, literal_length); 1308 const u8 *const read_ptr = 1309 IO_get_read_ptr(litstream, literal_length); 1310 // Copy literals to output 1311 memcpy(write_ptr, read_ptr, literal_length); 1312 1313 return literal_length; 1314 } 1315 1316 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) { 1317 size_t offset; 1318 // Offsets are special, we need to handle the repeat offsets 1319 if (seq.offset <= 3) { 1320 // "The first 3 values define a repeated offset and we will call 1321 // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. 1322 // They are sorted in recency order, with Repeated_Offset1 meaning 1323 // 'most recent one'". 1324 1325 // Use 0 indexing for the array 1326 u32 idx = seq.offset - 1; 1327 if (seq.literal_length == 0) { 1328 // "There is an exception though, when current sequence's 1329 // literals length is 0. In this case, repeated offsets are 1330 // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, 1331 // Repeated_Offset2 becomes Repeated_Offset3, and 1332 // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." 1333 idx++; 1334 } 1335 1336 if (idx == 0) { 1337 offset = offset_hist[0]; 1338 } else { 1339 // If idx == 3 then literal length was 0 and the offset was 3, 1340 // as per the exception listed above 1341 offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1; 1342 1343 // If idx == 1 we don't need to modify offset_hist[2], since 1344 // we're using the second-most recent code 1345 if (idx > 1) { 1346 offset_hist[2] = offset_hist[1]; 1347 } 1348 offset_hist[1] = offset_hist[0]; 1349 offset_hist[0] = offset; 1350 } 1351 } else { 1352 // When it's not a repeat offset: 1353 // "if (Offset_Value > 3) offset = Offset_Value - 3;" 1354 offset = seq.offset - 3; 1355 1356 // Shift back history 1357 offset_hist[2] = offset_hist[1]; 1358 offset_hist[1] = offset_hist[0]; 1359 offset_hist[0] = offset; 1360 } 1361 return offset; 1362 } 1363 1364 static void execute_match_copy(frame_context_t *const ctx, size_t offset, 1365 size_t match_length, size_t total_output, 1366 ostream_t *const out) { 1367 u8 *write_ptr = IO_get_write_ptr(out, match_length); 1368 if (total_output <= ctx->header.window_size) { 1369 // In this case offset might go back into the dictionary 1370 if (offset > total_output + ctx->dict_content_len) { 1371 // The offset goes beyond even the dictionary 1372 CORRUPTION(); 1373 } 1374 1375 if (offset > total_output) { 1376 // "The rest of the dictionary is its content. The content act 1377 // as a "past" in front of data to compress or decompress, so it 1378 // can be referenced in sequence commands." 1379 const size_t dict_copy = 1380 MIN(offset - total_output, match_length); 1381 const size_t dict_offset = 1382 ctx->dict_content_len - (offset - total_output); 1383 1384 memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); 1385 write_ptr += dict_copy; 1386 match_length -= dict_copy; 1387 } 1388 } else if (offset > ctx->header.window_size) { 1389 CORRUPTION(); 1390 } 1391 1392 // We must copy byte by byte because the match length might be larger 1393 // than the offset 1394 // ex: if the output so far was "abc", a command with offset=3 and 1395 // match_length=6 would produce "abcabcabc" as the new output 1396 for (size_t j = 0; j < match_length; j++) { 1397 *write_ptr = *(write_ptr - offset); 1398 write_ptr++; 1399 } 1400 } 1401 /******* END SEQUENCE EXECUTION ***********************************************/ 1402 1403 /******* OUTPUT SIZE COUNTING *************************************************/ 1404 /// Get the decompressed size of an input stream so memory can be allocated in 1405 /// advance. 1406 /// This implementation assumes `src` points to a single ZSTD-compressed frame 1407 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { 1408 istream_t in = IO_make_istream(src, src_len); 1409 1410 // get decompressed size from ZSTD frame header 1411 { 1412 const u32 magic_number = IO_read_bits(&in, 32); 1413 1414 if (magic_number == 0xFD2FB528U) { 1415 // ZSTD frame 1416 frame_header_t header; 1417 parse_frame_header(&header, &in); 1418 1419 if (header.frame_content_size == 0 && !header.single_segment_flag) { 1420 // Content size not provided, we can't tell 1421 return -1; 1422 } 1423 1424 return header.frame_content_size; 1425 } else { 1426 // not a real frame or skippable frame 1427 ERROR("ZSTD frame magic number did not match"); 1428 } 1429 } 1430 } 1431 /******* END OUTPUT SIZE COUNTING *********************************************/ 1432 1433 /******* DICTIONARY PARSING ***************************************************/ 1434 #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes") 1435 #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src"); 1436 1437 dictionary_t* create_dictionary() { 1438 dictionary_t* dict = calloc(1, sizeof(dictionary_t)); 1439 if (!dict) { 1440 BAD_ALLOC(); 1441 } 1442 return dict; 1443 } 1444 1445 static void init_dictionary_content(dictionary_t *const dict, 1446 istream_t *const in); 1447 1448 void parse_dictionary(dictionary_t *const dict, const void *src, 1449 size_t src_len) { 1450 const u8 *byte_src = (const u8 *)src; 1451 memset(dict, 0, sizeof(dictionary_t)); 1452 if (src == NULL) { /* cannot initialize dictionary with null src */ 1453 NULL_SRC(); 1454 } 1455 if (src_len < 8) { 1456 DICT_SIZE_ERROR(); 1457 } 1458 1459 istream_t in = IO_make_istream(byte_src, src_len); 1460 1461 const u32 magic_number = IO_read_bits(&in, 32); 1462 if (magic_number != 0xEC30A437) { 1463 // raw content dict 1464 IO_rewind_bits(&in, 32); 1465 init_dictionary_content(dict, &in); 1466 return; 1467 } 1468 1469 dict->dictionary_id = IO_read_bits(&in, 32); 1470 1471 // "Entropy_Tables : following the same format as the tables in compressed 1472 // blocks. They are stored in following order : Huffman tables for literals, 1473 // FSE table for offsets, FSE table for match lengths, and FSE table for 1474 // literals lengths. It's finally followed by 3 offset values, populating 1475 // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes 1476 // little-endian each, for a total of 12 bytes. Each recent offset must have 1477 // a value < dictionary size." 1478 decode_huf_table(&dict->literals_dtable, &in); 1479 decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse); 1480 decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse); 1481 decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse); 1482 1483 // Read in the previous offset history 1484 dict->previous_offsets[0] = IO_read_bits(&in, 32); 1485 dict->previous_offsets[1] = IO_read_bits(&in, 32); 1486 dict->previous_offsets[2] = IO_read_bits(&in, 32); 1487 1488 // Ensure the provided offsets aren't too large 1489 // "Each recent offset must have a value < dictionary size." 1490 for (int i = 0; i < 3; i++) { 1491 if (dict->previous_offsets[i] > src_len) { 1492 ERROR("Dictionary corrupted"); 1493 } 1494 } 1495 1496 // "Content : The rest of the dictionary is its content. The content act as 1497 // a "past" in front of data to compress or decompress, so it can be 1498 // referenced in sequence commands." 1499 init_dictionary_content(dict, &in); 1500 } 1501 1502 static void init_dictionary_content(dictionary_t *const dict, 1503 istream_t *const in) { 1504 // Copy in the content 1505 dict->content_size = IO_istream_len(in); 1506 dict->content = malloc(dict->content_size); 1507 if (!dict->content) { 1508 BAD_ALLOC(); 1509 } 1510 1511 const u8 *const content = IO_get_read_ptr(in, dict->content_size); 1512 1513 memcpy(dict->content, content, dict->content_size); 1514 } 1515 1516 /// Free an allocated dictionary 1517 void free_dictionary(dictionary_t *const dict) { 1518 HUF_free_dtable(&dict->literals_dtable); 1519 FSE_free_dtable(&dict->ll_dtable); 1520 FSE_free_dtable(&dict->of_dtable); 1521 FSE_free_dtable(&dict->ml_dtable); 1522 1523 free(dict->content); 1524 1525 memset(dict, 0, sizeof(dictionary_t)); 1526 1527 free(dict); 1528 } 1529 /******* END DICTIONARY PARSING ***********************************************/ 1530 1531 /******* IO STREAM OPERATIONS *************************************************/ 1532 #define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream") 1533 /// Reads `num` bits from a bitstream, and updates the internal offset 1534 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) { 1535 if (num_bits > 64 || num_bits <= 0) { 1536 ERROR("Attempt to read an invalid number of bits"); 1537 } 1538 1539 const size_t bytes = (num_bits + in->bit_offset + 7) / 8; 1540 const size_t full_bytes = (num_bits + in->bit_offset) / 8; 1541 if (bytes > in->len) { 1542 INP_SIZE(); 1543 } 1544 1545 const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset); 1546 1547 in->bit_offset = (num_bits + in->bit_offset) % 8; 1548 in->ptr += full_bytes; 1549 in->len -= full_bytes; 1550 1551 return result; 1552 } 1553 1554 /// If a non-zero number of bits have been read from the current byte, advance 1555 /// the offset to the next byte 1556 static inline void IO_rewind_bits(istream_t *const in, int num_bits) { 1557 if (num_bits < 0) { 1558 ERROR("Attempting to rewind stream by a negative number of bits"); 1559 } 1560 1561 // move the offset back by `num_bits` bits 1562 const int new_offset = in->bit_offset - num_bits; 1563 // determine the number of whole bytes we have to rewind, rounding up to an 1564 // integer number (e.g. if `new_offset == -5`, `bytes == 1`) 1565 const i64 bytes = -(new_offset - 7) / 8; 1566 1567 in->ptr -= bytes; 1568 in->len += bytes; 1569 // make sure the resulting `bit_offset` is positive, as mod in C does not 1570 // convert numbers from negative to positive (e.g. -22 % 8 == -6) 1571 in->bit_offset = ((new_offset % 8) + 8) % 8; 1572 } 1573 1574 /// If the remaining bits in a byte will be unused, advance to the end of the 1575 /// byte 1576 static inline void IO_align_stream(istream_t *const in) { 1577 if (in->bit_offset != 0) { 1578 if (in->len == 0) { 1579 INP_SIZE(); 1580 } 1581 in->ptr++; 1582 in->len--; 1583 in->bit_offset = 0; 1584 } 1585 } 1586 1587 /// Write the given byte into the output stream 1588 static inline void IO_write_byte(ostream_t *const out, u8 symb) { 1589 if (out->len == 0) { 1590 OUT_SIZE(); 1591 } 1592 1593 out->ptr[0] = symb; 1594 out->ptr++; 1595 out->len--; 1596 } 1597 1598 /// Returns the number of bytes left to be read in this stream. The stream must 1599 /// be byte aligned. 1600 static inline size_t IO_istream_len(const istream_t *const in) { 1601 return in->len; 1602 } 1603 1604 /// Returns a pointer where `len` bytes can be read, and advances the internal 1605 /// state. The stream must be byte aligned. 1606 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) { 1607 if (len > in->len) { 1608 INP_SIZE(); 1609 } 1610 if (in->bit_offset != 0) { 1611 UNALIGNED(); 1612 } 1613 const u8 *const ptr = in->ptr; 1614 in->ptr += len; 1615 in->len -= len; 1616 1617 return ptr; 1618 } 1619 /// Returns a pointer to write `len` bytes to, and advances the internal state 1620 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) { 1621 if (len > out->len) { 1622 OUT_SIZE(); 1623 } 1624 u8 *const ptr = out->ptr; 1625 out->ptr += len; 1626 out->len -= len; 1627 1628 return ptr; 1629 } 1630 1631 /// Advance the inner state by `len` bytes 1632 static inline void IO_advance_input(istream_t *const in, size_t len) { 1633 if (len > in->len) { 1634 INP_SIZE(); 1635 } 1636 if (in->bit_offset != 0) { 1637 UNALIGNED(); 1638 } 1639 1640 in->ptr += len; 1641 in->len -= len; 1642 } 1643 1644 /// Returns an `ostream_t` constructed from the given pointer and length 1645 static inline ostream_t IO_make_ostream(u8 *out, size_t len) { 1646 return (ostream_t) { out, len }; 1647 } 1648 1649 /// Returns an `istream_t` constructed from the given pointer and length 1650 static inline istream_t IO_make_istream(const u8 *in, size_t len) { 1651 return (istream_t) { in, len, 0 }; 1652 } 1653 1654 /// Returns an `istream_t` with the same base as `in`, and length `len` 1655 /// Then, advance `in` to account for the consumed bytes 1656 /// `in` must be byte aligned 1657 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) { 1658 // Consume `len` bytes of the parent stream 1659 const u8 *const ptr = IO_get_read_ptr(in, len); 1660 1661 // Make a substream using the pointer to those `len` bytes 1662 return IO_make_istream(ptr, len); 1663 } 1664 /******* END IO STREAM OPERATIONS *********************************************/ 1665 1666 /******* BITSTREAM OPERATIONS *************************************************/ 1667 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits 1668 static inline u64 read_bits_LE(const u8 *src, const int num_bits, 1669 const size_t offset) { 1670 if (num_bits > 64) { 1671 ERROR("Attempt to read an invalid number of bits"); 1672 } 1673 1674 // Skip over bytes that aren't in range 1675 src += offset / 8; 1676 size_t bit_offset = offset % 8; 1677 u64 res = 0; 1678 1679 int shift = 0; 1680 int left = num_bits; 1681 while (left > 0) { 1682 u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1); 1683 // Read the next byte, shift it to account for the offset, and then mask 1684 // out the top part if we don't need all the bits 1685 res += (((u64)*src++ >> bit_offset) & mask) << shift; 1686 shift += 8 - bit_offset; 1687 left -= 8 - bit_offset; 1688 bit_offset = 0; 1689 } 1690 1691 return res; 1692 } 1693 1694 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so 1695 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from 1696 /// `src + offset`. If the offset becomes negative, the extra bits at the 1697 /// bottom are filled in with `0` bits instead of reading from before `src`. 1698 static inline u64 STREAM_read_bits(const u8 *const src, const int bits, 1699 i64 *const offset) { 1700 *offset = *offset - bits; 1701 size_t actual_off = *offset; 1702 size_t actual_bits = bits; 1703 // Don't actually read bits from before the start of src, so if `*offset < 1704 // 0` fix actual_off and actual_bits to reflect the quantity to read 1705 if (*offset < 0) { 1706 actual_bits += *offset; 1707 actual_off = 0; 1708 } 1709 u64 res = read_bits_LE(src, actual_bits, actual_off); 1710 1711 if (*offset < 0) { 1712 // Fill in the bottom "overflowed" bits with 0's 1713 res = -*offset >= 64 ? 0 : (res << -*offset); 1714 } 1715 return res; 1716 } 1717 /******* END BITSTREAM OPERATIONS *********************************************/ 1718 1719 /******* BIT COUNTING OPERATIONS **********************************************/ 1720 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to 1721 /// `num`, or `-1` if `num == 0`. 1722 static inline int highest_set_bit(const u64 num) { 1723 for (int i = 63; i >= 0; i--) { 1724 if (((u64)1 << i) <= num) { 1725 return i; 1726 } 1727 } 1728 return -1; 1729 } 1730 /******* END BIT COUNTING OPERATIONS ******************************************/ 1731 1732 /******* HUFFMAN PRIMITIVES ***************************************************/ 1733 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, 1734 u16 *const state, const u8 *const src, 1735 i64 *const offset) { 1736 // Look up the symbol and number of bits to read 1737 const u8 symb = dtable->symbols[*state]; 1738 const u8 bits = dtable->num_bits[*state]; 1739 const u16 rest = STREAM_read_bits(src, bits, offset); 1740 // Shift `bits` bits out of the state, keeping the low order bits that 1741 // weren't necessary to determine this symbol. Then add in the new bits 1742 // read from the stream. 1743 *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1); 1744 1745 return symb; 1746 } 1747 1748 static inline void HUF_init_state(const HUF_dtable *const dtable, 1749 u16 *const state, const u8 *const src, 1750 i64 *const offset) { 1751 // Read in a full `dtable->max_bits` bits to initialize the state 1752 const u8 bits = dtable->max_bits; 1753 *state = STREAM_read_bits(src, bits, offset); 1754 } 1755 1756 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, 1757 ostream_t *const out, 1758 istream_t *const in) { 1759 const size_t len = IO_istream_len(in); 1760 if (len == 0) { 1761 INP_SIZE(); 1762 } 1763 const u8 *const src = IO_get_read_ptr(in, len); 1764 1765 // "Each bitstream must be read backward, that is starting from the end down 1766 // to the beginning. Therefore it's necessary to know the size of each 1767 // bitstream. 1768 // 1769 // It's also necessary to know exactly which bit is the latest. This is 1770 // detected by a final bit flag : the highest bit of latest byte is a 1771 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the 1772 // final-bit-flag itself is not part of the useful bitstream. Hence, the 1773 // last byte contains between 0 and 7 useful bits." 1774 const int padding = 8 - highest_set_bit(src[len - 1]); 1775 1776 // Offset starts at the end because HUF streams are read backwards 1777 i64 bit_offset = len * 8 - padding; 1778 u16 state; 1779 1780 HUF_init_state(dtable, &state, src, &bit_offset); 1781 1782 size_t symbols_written = 0; 1783 while (bit_offset > -dtable->max_bits) { 1784 // Iterate over the stream, decoding one symbol at a time 1785 IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset)); 1786 symbols_written++; 1787 } 1788 // "The process continues up to reading the required number of symbols per 1789 // stream. If a bitstream is not entirely and exactly consumed, hence 1790 // reaching exactly its beginning position with all bits consumed, the 1791 // decoding process is considered faulty." 1792 1793 // When all symbols have been decoded, the final state value shouldn't have 1794 // any data from the stream, so it should have "read" dtable->max_bits from 1795 // before the start of `src` 1796 // Therefore `offset`, the edge to start reading new bits at, should be 1797 // dtable->max_bits before the start of the stream 1798 if (bit_offset != -dtable->max_bits) { 1799 CORRUPTION(); 1800 } 1801 1802 return symbols_written; 1803 } 1804 1805 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, 1806 ostream_t *const out, istream_t *const in) { 1807 // "Compressed size is provided explicitly : in the 4-streams variant, 1808 // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each 1809 // value represents the compressed size of one stream, in order. The last 1810 // stream size is deducted from total compressed size and from previously 1811 // decoded stream sizes" 1812 const size_t csize1 = IO_read_bits(in, 16); 1813 const size_t csize2 = IO_read_bits(in, 16); 1814 const size_t csize3 = IO_read_bits(in, 16); 1815 1816 istream_t in1 = IO_make_sub_istream(in, csize1); 1817 istream_t in2 = IO_make_sub_istream(in, csize2); 1818 istream_t in3 = IO_make_sub_istream(in, csize3); 1819 istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in)); 1820 1821 size_t total_output = 0; 1822 // Decode each stream independently for simplicity 1823 // If we wanted to we could decode all 4 at the same time for speed, 1824 // utilizing more execution units 1825 total_output += HUF_decompress_1stream(dtable, out, &in1); 1826 total_output += HUF_decompress_1stream(dtable, out, &in2); 1827 total_output += HUF_decompress_1stream(dtable, out, &in3); 1828 total_output += HUF_decompress_1stream(dtable, out, &in4); 1829 1830 return total_output; 1831 } 1832 1833 /// Initializes a Huffman table using canonical Huffman codes 1834 /// For more explanation on canonical Huffman codes see 1835 /// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html 1836 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get 1837 /// earlier codes) 1838 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, 1839 const int num_symbs) { 1840 memset(table, 0, sizeof(HUF_dtable)); 1841 if (num_symbs > HUF_MAX_SYMBS) { 1842 ERROR("Too many symbols for Huffman"); 1843 } 1844 1845 u8 max_bits = 0; 1846 u16 rank_count[HUF_MAX_BITS + 1]; 1847 memset(rank_count, 0, sizeof(rank_count)); 1848 1849 // Count the number of symbols for each number of bits, and determine the 1850 // depth of the tree 1851 for (int i = 0; i < num_symbs; i++) { 1852 if (bits[i] > HUF_MAX_BITS) { 1853 ERROR("Huffman table depth too large"); 1854 } 1855 max_bits = MAX(max_bits, bits[i]); 1856 rank_count[bits[i]]++; 1857 } 1858 1859 const size_t table_size = 1 << max_bits; 1860 table->max_bits = max_bits; 1861 table->symbols = malloc(table_size); 1862 table->num_bits = malloc(table_size); 1863 1864 if (!table->symbols || !table->num_bits) { 1865 free(table->symbols); 1866 free(table->num_bits); 1867 BAD_ALLOC(); 1868 } 1869 1870 // "Symbols are sorted by Weight. Within same Weight, symbols keep natural 1871 // order. Symbols with a Weight of zero are removed. Then, starting from 1872 // lowest weight, prefix codes are distributed in order." 1873 1874 u32 rank_idx[HUF_MAX_BITS + 1]; 1875 // Initialize the starting codes for each rank (number of bits) 1876 rank_idx[max_bits] = 0; 1877 for (int i = max_bits; i >= 1; i--) { 1878 rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i)); 1879 // The entire range takes the same number of bits so we can memset it 1880 memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]); 1881 } 1882 1883 if (rank_idx[0] != table_size) { 1884 CORRUPTION(); 1885 } 1886 1887 // Allocate codes and fill in the table 1888 for (int i = 0; i < num_symbs; i++) { 1889 if (bits[i] != 0) { 1890 // Allocate a code for this symbol and set its range in the table 1891 const u16 code = rank_idx[bits[i]]; 1892 // Since the code doesn't care about the bottom `max_bits - bits[i]` 1893 // bits of state, it gets a range that spans all possible values of 1894 // the lower bits 1895 const u16 len = 1 << (max_bits - bits[i]); 1896 memset(&table->symbols[code], i, len); 1897 rank_idx[bits[i]] += len; 1898 } 1899 } 1900 } 1901 1902 static void HUF_init_dtable_usingweights(HUF_dtable *const table, 1903 const u8 *const weights, 1904 const int num_symbs) { 1905 // +1 because the last weight is not transmitted in the header 1906 if (num_symbs + 1 > HUF_MAX_SYMBS) { 1907 ERROR("Too many symbols for Huffman"); 1908 } 1909 1910 u8 bits[HUF_MAX_SYMBS]; 1911 1912 u64 weight_sum = 0; 1913 for (int i = 0; i < num_symbs; i++) { 1914 // Weights are in the same range as bit count 1915 if (weights[i] > HUF_MAX_BITS) { 1916 CORRUPTION(); 1917 } 1918 weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0; 1919 } 1920 1921 // Find the first power of 2 larger than the sum 1922 const int max_bits = highest_set_bit(weight_sum) + 1; 1923 const u64 left_over = ((u64)1 << max_bits) - weight_sum; 1924 // If the left over isn't a power of 2, the weights are invalid 1925 if (left_over & (left_over - 1)) { 1926 CORRUPTION(); 1927 } 1928 1929 // left_over is used to find the last weight as it's not transmitted 1930 // by inverting 2^(weight - 1) we can determine the value of last_weight 1931 const int last_weight = highest_set_bit(left_over) + 1; 1932 1933 for (int i = 0; i < num_symbs; i++) { 1934 // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0" 1935 bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0; 1936 } 1937 bits[num_symbs] = 1938 max_bits + 1 - last_weight; // Last weight is always non-zero 1939 1940 HUF_init_dtable(table, bits, num_symbs + 1); 1941 } 1942 1943 static void HUF_free_dtable(HUF_dtable *const dtable) { 1944 free(dtable->symbols); 1945 free(dtable->num_bits); 1946 memset(dtable, 0, sizeof(HUF_dtable)); 1947 } 1948 1949 static void HUF_copy_dtable(HUF_dtable *const dst, 1950 const HUF_dtable *const src) { 1951 if (src->max_bits == 0) { 1952 memset(dst, 0, sizeof(HUF_dtable)); 1953 return; 1954 } 1955 1956 const size_t size = (size_t)1 << src->max_bits; 1957 dst->max_bits = src->max_bits; 1958 1959 dst->symbols = malloc(size); 1960 dst->num_bits = malloc(size); 1961 if (!dst->symbols || !dst->num_bits) { 1962 BAD_ALLOC(); 1963 } 1964 1965 memcpy(dst->symbols, src->symbols, size); 1966 memcpy(dst->num_bits, src->num_bits, size); 1967 } 1968 /******* END HUFFMAN PRIMITIVES ***********************************************/ 1969 1970 /******* FSE PRIMITIVES *******************************************************/ 1971 /// For more description of FSE see 1972 /// https://github.com/Cyan4973/FiniteStateEntropy/ 1973 1974 /// Allow a symbol to be decoded without updating state 1975 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, 1976 const u16 state) { 1977 return dtable->symbols[state]; 1978 } 1979 1980 /// Consumes bits from the input and uses the current state to determine the 1981 /// next state 1982 static inline void FSE_update_state(const FSE_dtable *const dtable, 1983 u16 *const state, const u8 *const src, 1984 i64 *const offset) { 1985 const u8 bits = dtable->num_bits[*state]; 1986 const u16 rest = STREAM_read_bits(src, bits, offset); 1987 *state = dtable->new_state_base[*state] + rest; 1988 } 1989 1990 /// Decodes a single FSE symbol and updates the offset 1991 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, 1992 u16 *const state, const u8 *const src, 1993 i64 *const offset) { 1994 const u8 symb = FSE_peek_symbol(dtable, *state); 1995 FSE_update_state(dtable, state, src, offset); 1996 return symb; 1997 } 1998 1999 static inline void FSE_init_state(const FSE_dtable *const dtable, 2000 u16 *const state, const u8 *const src, 2001 i64 *const offset) { 2002 // Read in a full `accuracy_log` bits to initialize the state 2003 const u8 bits = dtable->accuracy_log; 2004 *state = STREAM_read_bits(src, bits, offset); 2005 } 2006 2007 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, 2008 ostream_t *const out, 2009 istream_t *const in) { 2010 const size_t len = IO_istream_len(in); 2011 if (len == 0) { 2012 INP_SIZE(); 2013 } 2014 const u8 *const src = IO_get_read_ptr(in, len); 2015 2016 // "Each bitstream must be read backward, that is starting from the end down 2017 // to the beginning. Therefore it's necessary to know the size of each 2018 // bitstream. 2019 // 2020 // It's also necessary to know exactly which bit is the latest. This is 2021 // detected by a final bit flag : the highest bit of latest byte is a 2022 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the 2023 // final-bit-flag itself is not part of the useful bitstream. Hence, the 2024 // last byte contains between 0 and 7 useful bits." 2025 const int padding = 8 - highest_set_bit(src[len - 1]); 2026 i64 offset = len * 8 - padding; 2027 2028 u16 state1, state2; 2029 // "The first state (State1) encodes the even indexed symbols, and the 2030 // second (State2) encodes the odd indexes. State1 is initialized first, and 2031 // then State2, and they take turns decoding a single symbol and updating 2032 // their state." 2033 FSE_init_state(dtable, &state1, src, &offset); 2034 FSE_init_state(dtable, &state2, src, &offset); 2035 2036 // Decode until we overflow the stream 2037 // Since we decode in reverse order, overflowing the stream is offset going 2038 // negative 2039 size_t symbols_written = 0; 2040 while (1) { 2041 // "The number of symbols to decode is determined by tracking bitStream 2042 // overflow condition: If updating state after decoding a symbol would 2043 // require more bits than remain in the stream, it is assumed the extra 2044 // bits are 0. Then, the symbols for each of the final states are 2045 // decoded and the process is complete." 2046 IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset)); 2047 symbols_written++; 2048 if (offset < 0) { 2049 // There's still a symbol to decode in state2 2050 IO_write_byte(out, FSE_peek_symbol(dtable, state2)); 2051 symbols_written++; 2052 break; 2053 } 2054 2055 IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset)); 2056 symbols_written++; 2057 if (offset < 0) { 2058 // There's still a symbol to decode in state1 2059 IO_write_byte(out, FSE_peek_symbol(dtable, state1)); 2060 symbols_written++; 2061 break; 2062 } 2063 } 2064 2065 return symbols_written; 2066 } 2067 2068 static void FSE_init_dtable(FSE_dtable *const dtable, 2069 const i16 *const norm_freqs, const int num_symbs, 2070 const int accuracy_log) { 2071 if (accuracy_log > FSE_MAX_ACCURACY_LOG) { 2072 ERROR("FSE accuracy too large"); 2073 } 2074 if (num_symbs > FSE_MAX_SYMBS) { 2075 ERROR("Too many symbols for FSE"); 2076 } 2077 2078 dtable->accuracy_log = accuracy_log; 2079 2080 const size_t size = (size_t)1 << accuracy_log; 2081 dtable->symbols = malloc(size * sizeof(u8)); 2082 dtable->num_bits = malloc(size * sizeof(u8)); 2083 dtable->new_state_base = malloc(size * sizeof(u16)); 2084 2085 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { 2086 BAD_ALLOC(); 2087 } 2088 2089 // Used to determine how many bits need to be read for each state, 2090 // and where the destination range should start 2091 // Needs to be u16 because max value is 2 * max number of symbols, 2092 // which can be larger than a byte can store 2093 u16 state_desc[FSE_MAX_SYMBS]; 2094 2095 // "Symbols are scanned in their natural order for "less than 1" 2096 // probabilities. Symbols with this probability are being attributed a 2097 // single cell, starting from the end of the table. These symbols define a 2098 // full state reset, reading Accuracy_Log bits." 2099 int high_threshold = size; 2100 for (int s = 0; s < num_symbs; s++) { 2101 // Scan for low probability symbols to put at the top 2102 if (norm_freqs[s] == -1) { 2103 dtable->symbols[--high_threshold] = s; 2104 state_desc[s] = 1; 2105 } 2106 } 2107 2108 // "All remaining symbols are sorted in their natural order. Starting from 2109 // symbol 0 and table position 0, each symbol gets attributed as many cells 2110 // as its probability. Cell allocation is spreaded, not linear." 2111 // Place the rest in the table 2112 const u16 step = (size >> 1) + (size >> 3) + 3; 2113 const u16 mask = size - 1; 2114 u16 pos = 0; 2115 for (int s = 0; s < num_symbs; s++) { 2116 if (norm_freqs[s] <= 0) { 2117 continue; 2118 } 2119 2120 state_desc[s] = norm_freqs[s]; 2121 2122 for (int i = 0; i < norm_freqs[s]; i++) { 2123 // Give `norm_freqs[s]` states to symbol s 2124 dtable->symbols[pos] = s; 2125 // "A position is skipped if already occupied, typically by a "less 2126 // than 1" probability symbol." 2127 do { 2128 pos = (pos + step) & mask; 2129 } while (pos >= 2130 high_threshold); 2131 // Note: no other collision checking is necessary as `step` is 2132 // coprime to `size`, so the cycle will visit each position exactly 2133 // once 2134 } 2135 } 2136 if (pos != 0) { 2137 CORRUPTION(); 2138 } 2139 2140 // Now we can fill baseline and num bits 2141 for (size_t i = 0; i < size; i++) { 2142 u8 symbol = dtable->symbols[i]; 2143 u16 next_state_desc = state_desc[symbol]++; 2144 // Fills in the table appropriately, next_state_desc increases by symbol 2145 // over time, decreasing number of bits 2146 dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc)); 2147 // Baseline increases until the bit threshold is passed, at which point 2148 // it resets to 0 2149 dtable->new_state_base[i] = 2150 ((u16)next_state_desc << dtable->num_bits[i]) - size; 2151 } 2152 } 2153 2154 /// Decode an FSE header as defined in the Zstandard format specification and 2155 /// use the decoded frequencies to initialize a decoding table. 2156 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, 2157 const int max_accuracy_log) { 2158 // "An FSE distribution table describes the probabilities of all symbols 2159 // from 0 to the last present one (included) on a normalized scale of 1 << 2160 // Accuracy_Log . 2161 // 2162 // It's a bitstream which is read forward, in little-endian fashion. It's 2163 // not necessary to know its exact size, since it will be discovered and 2164 // reported by the decoding process. 2165 if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { 2166 ERROR("FSE accuracy too large"); 2167 } 2168 2169 // The bitstream starts by reporting on which scale it operates. 2170 // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal 2171 // and match lengths is 9, and for offsets is 8. Higher values are 2172 // considered errors." 2173 const int accuracy_log = 5 + IO_read_bits(in, 4); 2174 if (accuracy_log > max_accuracy_log) { 2175 ERROR("FSE accuracy too large"); 2176 } 2177 2178 // "Then follows each symbol value, from 0 to last present one. The number 2179 // of bits used by each field is variable. It depends on : 2180 // 2181 // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8, 2182 // and presuming 100 probabilities points have already been distributed, the 2183 // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive). 2184 // Therefore, it must read log2sup(156) == 8 bits. 2185 // 2186 // Value decoded : small values use 1 less bit : example : Presuming values 2187 // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining 2188 // in an 8-bits field. They are used this way : first 99 values (hence from 2189 // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. " 2190 2191 i32 remaining = 1 << accuracy_log; 2192 i16 frequencies[FSE_MAX_SYMBS]; 2193 2194 int symb = 0; 2195 while (remaining > 0 && symb < FSE_MAX_SYMBS) { 2196 // Log of the number of possible values we could read 2197 int bits = highest_set_bit(remaining + 1) + 1; 2198 2199 u16 val = IO_read_bits(in, bits); 2200 2201 // Try to mask out the lower bits to see if it qualifies for the "small 2202 // value" threshold 2203 const u16 lower_mask = ((u16)1 << (bits - 1)) - 1; 2204 const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1); 2205 2206 if ((val & lower_mask) < threshold) { 2207 IO_rewind_bits(in, 1); 2208 val = val & lower_mask; 2209 } else if (val > lower_mask) { 2210 val = val - threshold; 2211 } 2212 2213 // "Probability is obtained from Value decoded by following formula : 2214 // Proba = value - 1" 2215 const i16 proba = (i16)val - 1; 2216 2217 // "It means value 0 becomes negative probability -1. -1 is a special 2218 // probability, which means "less than 1". Its effect on distribution 2219 // table is described in next paragraph. For the purpose of calculating 2220 // cumulated distribution, it counts as one." 2221 remaining -= proba < 0 ? -proba : proba; 2222 2223 frequencies[symb] = proba; 2224 symb++; 2225 2226 // "When a symbol has a probability of zero, it is followed by a 2-bits 2227 // repeat flag. This repeat flag tells how many probabilities of zeroes 2228 // follow the current one. It provides a number ranging from 0 to 3. If 2229 // it is a 3, another 2-bits repeat flag follows, and so on." 2230 if (proba == 0) { 2231 // Read the next two bits to see how many more 0s 2232 int repeat = IO_read_bits(in, 2); 2233 2234 while (1) { 2235 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) { 2236 frequencies[symb++] = 0; 2237 } 2238 if (repeat == 3) { 2239 repeat = IO_read_bits(in, 2); 2240 } else { 2241 break; 2242 } 2243 } 2244 } 2245 } 2246 IO_align_stream(in); 2247 2248 // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding 2249 // is complete. If the last symbol makes cumulated total go above 1 << 2250 // Accuracy_Log, distribution is considered corrupted." 2251 if (remaining != 0 || symb >= FSE_MAX_SYMBS) { 2252 CORRUPTION(); 2253 } 2254 2255 // Initialize the decoding table using the determined weights 2256 FSE_init_dtable(dtable, frequencies, symb, accuracy_log); 2257 } 2258 2259 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) { 2260 dtable->symbols = malloc(sizeof(u8)); 2261 dtable->num_bits = malloc(sizeof(u8)); 2262 dtable->new_state_base = malloc(sizeof(u16)); 2263 2264 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { 2265 BAD_ALLOC(); 2266 } 2267 2268 // This setup will always have a state of 0, always return symbol `symb`, 2269 // and never consume any bits 2270 dtable->symbols[0] = symb; 2271 dtable->num_bits[0] = 0; 2272 dtable->new_state_base[0] = 0; 2273 dtable->accuracy_log = 0; 2274 } 2275 2276 static void FSE_free_dtable(FSE_dtable *const dtable) { 2277 free(dtable->symbols); 2278 free(dtable->num_bits); 2279 free(dtable->new_state_base); 2280 memset(dtable, 0, sizeof(FSE_dtable)); 2281 } 2282 2283 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) { 2284 if (src->accuracy_log == 0) { 2285 memset(dst, 0, sizeof(FSE_dtable)); 2286 return; 2287 } 2288 2289 size_t size = (size_t)1 << src->accuracy_log; 2290 dst->accuracy_log = src->accuracy_log; 2291 2292 dst->symbols = malloc(size); 2293 dst->num_bits = malloc(size); 2294 dst->new_state_base = malloc(size * sizeof(u16)); 2295 if (!dst->symbols || !dst->num_bits || !dst->new_state_base) { 2296 BAD_ALLOC(); 2297 } 2298 2299 memcpy(dst->symbols, src->symbols, size); 2300 memcpy(dst->num_bits, src->num_bits, size); 2301 memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); 2302 } 2303 /******* END FSE PRIMITIVES ***************************************************/ 2304