xref: /freebsd/sys/contrib/zstd/doc/educational_decoder/zstd_decompress.c (revision 68d75eff68281c1b445e3010bb975eae07aac225)
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  */
9 
10 /// Zstandard educational decoder implementation
11 /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
12 
13 #include <stdint.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include "zstd_decompress.h"
18 
19 /******* UTILITY MACROS AND TYPES *********************************************/
// A block decompresses to at most 128 KB, and the literals of a block can
// never be larger than the block that contains them
#define MAX_LITERALS_SIZE ((size_t)128 * 1024)

// Note: both macros evaluate their arguments more than once
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

/// Every error in this decoder is fatal: the message is printed to stderr and
/// the process exits with status 1.  A production library should propagate
/// error codes instead.
#define ERROR(s)                                                               \
    do {                                                                       \
        fprintf(stderr, "Error: %s\n", s);                                     \
        exit(1);                                                               \
    } while (0)

/// Shorthands for the fatal errors raised throughout the decoder
#define INP_SIZE()                                                             \
    ERROR("Input buffer smaller than it should be or input is "                \
          "corrupted")
#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")
41 
// Short aliases for the fixed-width integer types from <stdint.h>:
// uN is an unsigned N-bit integer, iN is a signed N-bit integer
typedef uint8_t u8;
typedef int8_t i8;

typedef uint16_t u16;
typedef int16_t i16;

typedef uint32_t u32;
typedef int32_t i32;

typedef uint64_t u64;
typedef int64_t i64;
51 /******* END UTILITY MACROS AND TYPES *****************************************/
52 
53 /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
54 /// The implementations for these functions can be found at the bottom of this
55 /// file.  They implement low-level functionality needed for the higher level
56 /// decompression functions.
57 
58 /*** IO STREAM OPERATIONS *************/
59 
60 /// ostream_t/istream_t are used to wrap the pointers/length data passed into
61 /// ZSTD_decompress, so that all IO operations are safely bounds checked
62 /// They are written/read forward, and reads are treated as little-endian
63 /// They should be used opaquely to ensure safety
64 typedef struct {
65     u8 *ptr;
66     size_t len;
67 } ostream_t;
68 
69 typedef struct {
70     const u8 *ptr;
71     size_t len;
72 
73     // Input often reads a few bits at a time, so maintain an internal offset
74     int bit_offset;
75 } istream_t;
76 
77 /// The following two functions are the only ones that allow the istream to be
78 /// non-byte aligned
79 
80 /// Reads `num` bits from a bitstream, and updates the internal offset
81 static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
82 /// Backs-up the stream by `num` bits so they can be read again
83 static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
84 /// If the remaining bits in a byte will be unused, advance to the end of the
85 /// byte
86 static inline void IO_align_stream(istream_t *const in);
87 
88 /// Write the given byte into the output stream
89 static inline void IO_write_byte(ostream_t *const out, u8 symb);
90 
91 /// Returns the number of bytes left to be read in this stream.  The stream must
92 /// be byte aligned.
93 static inline size_t IO_istream_len(const istream_t *const in);
94 
95 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
96 /// was skipped.  The stream must be byte aligned.
97 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
98 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
99 /// was skipped so it can be written to.
100 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
101 
102 /// Advance the inner state by `len` bytes.  The stream must be byte aligned.
103 static inline void IO_advance_input(istream_t *const in, size_t len);
104 
105 /// Returns an `ostream_t` constructed from the given pointer and length.
106 static inline ostream_t IO_make_ostream(u8 *out, size_t len);
107 /// Returns an `istream_t` constructed from the given pointer and length.
108 static inline istream_t IO_make_istream(const u8 *in, size_t len);
109 
110 /// Returns an `istream_t` with the same base as `in`, and length `len`.
111 /// Then, advance `in` to account for the consumed bytes.
112 /// `in` must be byte aligned.
113 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
114 /*** END IO STREAM OPERATIONS *********/
115 
116 /*** BITSTREAM OPERATIONS *************/
117 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits,
118 /// and return them interpreted as a little-endian unsigned integer.
119 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
120                                const size_t offset);
121 
122 /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
123 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
124 /// `src + offset`.  If the offset becomes negative, the extra bits at the
125 /// bottom are filled in with `0` bits instead of reading from before `src`.
126 static inline u64 STREAM_read_bits(const u8 *src, const int bits,
127                                    i64 *const offset);
128 /*** END BITSTREAM OPERATIONS *********/
129 
130 /*** BIT COUNTING OPERATIONS **********/
131 /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
132 static inline int highest_set_bit(const u64 num);
133 /*** END BIT COUNTING OPERATIONS ******/
134 
135 /*** HUFFMAN PRIMITIVES ***************/
136 // Table decode method uses exponential memory, so we need to limit depth
137 #define HUF_MAX_BITS (16)
138 
139 // Limit the maximum number of symbols to 256 so we can store a symbol in a byte
140 #define HUF_MAX_SYMBS (256)
141 
142 /// Structure containing all tables necessary for efficient Huffman decoding
143 typedef struct {
144     u8 *symbols;
145     u8 *num_bits;
146     int max_bits;
147 } HUF_dtable;
148 
149 /// Decode a single symbol and read in enough bits to refresh the state
150 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
151                                    u16 *const state, const u8 *const src,
152                                    i64 *const offset);
153 /// Read in a full state's worth of bits to initialize it
154 static inline void HUF_init_state(const HUF_dtable *const dtable,
155                                   u16 *const state, const u8 *const src,
156                                   i64 *const offset);
157 
/// Decompresses a single Huffman stream, returns the number of bytes decoded.
/// `in` must span exactly the Huffman-coded block.
160 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
161                                      ostream_t *const out, istream_t *const in);
162 /// Same as previous but decodes 4 streams, formatted as in the Zstandard
163 /// specification.
/// `in` must span exactly the Huffman-coded block.
165 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
166                                      ostream_t *const out, istream_t *const in);
167 
168 /// Initialize a Huffman decoding table using the table of bit counts provided
169 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
170                             const int num_symbs);
171 /// Initialize a Huffman decoding table using the table of weights provided
172 /// Weights follow the definition provided in the Zstandard specification
173 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
174                                          const u8 *const weights,
175                                          const int num_symbs);
176 
177 /// Free the malloc'ed parts of a decoding table
178 static void HUF_free_dtable(HUF_dtable *const dtable);
179 
180 /// Deep copy a decoding table, so that it can be used and free'd without
181 /// impacting the source table.
182 static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src);
183 /*** END HUFFMAN PRIMITIVES ***********/
184 
185 /*** FSE PRIMITIVES *******************/
186 /// For more description of FSE see
187 /// https://github.com/Cyan4973/FiniteStateEntropy/
188 
189 // FSE table decoding uses exponential memory, so limit the maximum accuracy
190 #define FSE_MAX_ACCURACY_LOG (15)
191 // Limit the maximum number of symbols so they can be stored in a single byte
192 #define FSE_MAX_SYMBS (256)
193 
194 /// The tables needed to decode FSE encoded streams
195 typedef struct {
196     u8 *symbols;
197     u8 *num_bits;
198     u16 *new_state_base;
199     int accuracy_log;
200 } FSE_dtable;
201 
202 /// Return the symbol for the current state
203 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
204                                  const u16 state);
205 /// Read the number of bits necessary to update state, update, and shift offset
206 /// back to reflect the bits read
207 static inline void FSE_update_state(const FSE_dtable *const dtable,
208                                     u16 *const state, const u8 *const src,
209                                     i64 *const offset);
210 
211 /// Combine peek and update: decode a symbol and update the state
212 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
213                                    u16 *const state, const u8 *const src,
214                                    i64 *const offset);
215 
216 /// Read bits from the stream to initialize the state and shift offset back
217 static inline void FSE_init_state(const FSE_dtable *const dtable,
218                                   u16 *const state, const u8 *const src,
219                                   i64 *const offset);
220 
/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
/// using an FSE decoding table.  `in` must span exactly the compressed
/// block.
224 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
225                                           ostream_t *const out,
226                                           istream_t *const in);
227 
228 /// Initialize a decoding table using normalized frequencies.
229 static void FSE_init_dtable(FSE_dtable *const dtable,
230                             const i16 *const norm_freqs, const int num_symbs,
231                             const int accuracy_log);
232 
233 /// Decode an FSE header as defined in the Zstandard format specification and
234 /// use the decoded frequencies to initialize a decoding table.
235 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
236                                 const int max_accuracy_log);
237 
238 /// Initialize an FSE table that will always return the same symbol and consume
239 /// 0 bits per symbol, to be used for RLE mode in sequence commands
240 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
241 
242 /// Free the malloc'ed parts of a decoding table
243 static void FSE_free_dtable(FSE_dtable *const dtable);
244 
245 /// Deep copy a decoding table, so that it can be used and free'd without
246 /// impacting the source table.
247 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src);
248 /*** END FSE PRIMITIVES ***************/
249 
250 /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
251 
252 /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
253 
254 /// A small structure that can be reused in various places that need to access
255 /// frame header information
256 typedef struct {
257     // The size of window that we need to be able to contiguously store for
258     // references
259     size_t window_size;
260     // The total output size of this compressed frame
261     size_t frame_content_size;
262 
263     // The dictionary id if this frame uses one
264     u32 dictionary_id;
265 
266     // Whether or not the content of this frame has a checksum
267     int content_checksum_flag;
268     // Whether or not the output for this frame is in a single segment
269     int single_segment_flag;
270 } frame_header_t;
271 
272 /// The context needed to decode blocks in a frame
273 typedef struct {
274     frame_header_t header;
275 
276     // The total amount of data available for backreferences, to determine if an
277     // offset too large to be correct
278     size_t current_total_output;
279 
280     const u8 *dict_content;
281     size_t dict_content_len;
282 
283     // Entropy encoding tables so they can be repeated by future blocks instead
284     // of retransmitting
285     HUF_dtable literals_dtable;
286     FSE_dtable ll_dtable;
287     FSE_dtable ml_dtable;
288     FSE_dtable of_dtable;
289 
290     // The last 3 offsets for the special "repeat offsets".
291     u64 previous_offsets[3];
292 } frame_context_t;
293 
294 /// The decoded contents of a dictionary so that it doesn't have to be repeated
295 /// for each frame that uses it
296 struct dictionary_s {
297     // Entropy tables
298     HUF_dtable literals_dtable;
299     FSE_dtable ll_dtable;
300     FSE_dtable ml_dtable;
301     FSE_dtable of_dtable;
302 
303     // Raw content for backreferences
304     u8 *content;
305     size_t content_size;
306 
307     // Offset history to prepopulate the frame's history
308     u64 previous_offsets[3];
309 
310     u32 dictionary_id;
311 };
312 
313 /// A tuple containing the parts necessary to decode and execute a ZSTD sequence
314 /// command
315 typedef struct {
316     u32 literal_length;
317     u32 match_length;
318     u32 offset;
319 } sequence_command_t;
320 
321 /// The decoder works top-down, starting at the high level like Zstd frames, and
322 /// working down to lower more technical levels such as blocks, literals, and
323 /// sequences.  The high-level functions roughly follow the outline of the
324 /// format specification:
325 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
326 
327 /// Before the implementation of each high-level function declared here, the
328 /// prototypes for their helper functions are defined and explained
329 
330 /// Decode a single Zstd frame, or error if the input is not a valid frame.
331 /// Accepts a dict argument, which may be NULL indicating no dictionary.
332 /// See
333 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
334 static void decode_frame(ostream_t *const out, istream_t *const in,
335                          const dictionary_t *const dict);
336 
337 // Decode data in a compressed block
338 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
339                              istream_t *const in);
340 
341 // Decode the literals section of a block
342 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
343                               u8 **const literals);
344 
345 // Decode the sequences part of a block
346 static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
347                                sequence_command_t **const sequences);
348 
349 // Execute the decoded sequences on the literals block
350 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
351                               const u8 *const literals,
352                               const size_t literals_len,
353                               const sequence_command_t *const sequences,
354                               const size_t num_sequences);
355 
356 // Copies literals and returns the total literal length that was copied
357 static u32 copy_literals(const size_t seq, istream_t *litstream,
358                          ostream_t *const out);
359 
360 // Given an offset code from a sequence command (either an actual offset value
361 // or an index for previous offset), computes the correct offset and updates
362 // the offset history
363 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
364 
365 // Given an offset, match length, and total output, as well as the frame
366 // context for the dictionary, determines if the dictionary is used and
367 // executes the copy operation
368 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
369                               size_t match_length, size_t total_output,
370                               ostream_t *const out);
371 
372 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
373 
374 size_t ZSTD_decompress(void *const dst, const size_t dst_len,
375                        const void *const src, const size_t src_len) {
376     dictionary_t* uninit_dict = create_dictionary();
377     size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
378                                                          src_len, uninit_dict);
379     free_dictionary(uninit_dict);
380     return decomp_size;
381 }
382 
383 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
384                                  const void *const src, const size_t src_len,
385                                  dictionary_t* parsed_dict) {
386 
387     istream_t in = IO_make_istream(src, src_len);
388     ostream_t out = IO_make_ostream(dst, dst_len);
389 
390     // "A content compressed by Zstandard is transformed into a Zstandard frame.
391     // Multiple frames can be appended into a single file or stream. A frame is
392     // totally independent, has a defined beginning and end, and a set of
393     // parameters which tells the decoder how to decompress it."
394 
395     /* this decoder assumes decompression of a single frame */
396     decode_frame(&out, &in, parsed_dict);
397 
398     return (size_t)(out.ptr - (u8 *)dst);
399 }
400 
401 /******* FRAME DECODING ******************************************************/
402 
403 static void decode_data_frame(ostream_t *const out, istream_t *const in,
404                               const dictionary_t *const dict);
405 static void init_frame_context(frame_context_t *const context,
406                                istream_t *const in,
407                                const dictionary_t *const dict);
408 static void free_frame_context(frame_context_t *const context);
409 static void parse_frame_header(frame_header_t *const header,
410                                istream_t *const in);
411 static void frame_context_apply_dict(frame_context_t *const ctx,
412                                      const dictionary_t *const dict);
413 
414 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
415                             istream_t *const in);
416 
417 static void decode_frame(ostream_t *const out, istream_t *const in,
418                          const dictionary_t *const dict) {
419     const u32 magic_number = (u32)IO_read_bits(in, 32);
420     // Zstandard frame
421     //
422     // "Magic_Number
423     //
424     // 4 Bytes, little-endian format. Value : 0xFD2FB528"
425     if (magic_number == 0xFD2FB528U) {
426         // ZSTD frame
427         decode_data_frame(out, in, dict);
428 
429         return;
430     }
431 
432     // not a real frame or a skippable frame
433     ERROR("Tried to decode non-ZSTD frame");
434 }
435 
436 /// Decode a frame that contains compressed data.  Not all frames do as there
437 /// are skippable frames.
438 /// See
439 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
440 static void decode_data_frame(ostream_t *const out, istream_t *const in,
441                               const dictionary_t *const dict) {
442     frame_context_t ctx;
443 
444     // Initialize the context that needs to be carried from block to block
445     init_frame_context(&ctx, in, dict);
446 
447     if (ctx.header.frame_content_size != 0 &&
448         ctx.header.frame_content_size > out->len) {
449         OUT_SIZE();
450     }
451 
452     decompress_data(&ctx, out, in);
453 
454     free_frame_context(&ctx);
455 }
456 
457 /// Takes the information provided in the header and dictionary, and initializes
458 /// the context for this frame
459 static void init_frame_context(frame_context_t *const context,
460                                istream_t *const in,
461                                const dictionary_t *const dict) {
462     // Most fields in context are correct when initialized to 0
463     memset(context, 0, sizeof(frame_context_t));
464 
465     // Parse data from the frame header
466     parse_frame_header(&context->header, in);
467 
468     // Set up the offset history for the repeat offset commands
469     context->previous_offsets[0] = 1;
470     context->previous_offsets[1] = 4;
471     context->previous_offsets[2] = 8;
472 
473     // Apply details from the dict if it exists
474     frame_context_apply_dict(context, dict);
475 }
476 
477 static void free_frame_context(frame_context_t *const context) {
478     HUF_free_dtable(&context->literals_dtable);
479 
480     FSE_free_dtable(&context->ll_dtable);
481     FSE_free_dtable(&context->ml_dtable);
482     FSE_free_dtable(&context->of_dtable);
483 
484     memset(context, 0, sizeof(frame_context_t));
485 }
486 
487 static void parse_frame_header(frame_header_t *const header,
488                                istream_t *const in) {
489     // "The first header's byte is called the Frame_Header_Descriptor. It tells
490     // which other fields are present. Decoding this byte is enough to tell the
491     // size of Frame_Header.
492     //
493     // Bit number   Field name
494     // 7-6  Frame_Content_Size_flag
495     // 5    Single_Segment_flag
496     // 4    Unused_bit
497     // 3    Reserved_bit
498     // 2    Content_Checksum_flag
499     // 1-0  Dictionary_ID_flag"
500     const u8 descriptor = (u8)IO_read_bits(in, 8);
501 
502     // decode frame header descriptor into flags
503     const u8 frame_content_size_flag = descriptor >> 6;
504     const u8 single_segment_flag = (descriptor >> 5) & 1;
505     const u8 reserved_bit = (descriptor >> 3) & 1;
506     const u8 content_checksum_flag = (descriptor >> 2) & 1;
507     const u8 dictionary_id_flag = descriptor & 3;
508 
509     if (reserved_bit != 0) {
510         CORRUPTION();
511     }
512 
513     header->single_segment_flag = single_segment_flag;
514     header->content_checksum_flag = content_checksum_flag;
515 
516     // decode window size
517     if (!single_segment_flag) {
518         // "Provides guarantees on maximum back-reference distance that will be
519         // used within compressed data. This information is important for
520         // decoders to allocate enough memory.
521         //
522         // Bit numbers  7-3         2-0
523         // Field name   Exponent    Mantissa"
524         u8 window_descriptor = (u8)IO_read_bits(in, 8);
525         u8 exponent = window_descriptor >> 3;
526         u8 mantissa = window_descriptor & 7;
527 
528         // Use the algorithm from the specification to compute window size
529         // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
530         size_t window_base = (size_t)1 << (10 + exponent);
531         size_t window_add = (window_base / 8) * mantissa;
532         header->window_size = window_base + window_add;
533     }
534 
535     // decode dictionary id if it exists
536     if (dictionary_id_flag) {
537         // "This is a variable size field, which contains the ID of the
538         // dictionary required to properly decode the frame. Note that this
539         // field is optional. When it's not present, it's up to the caller to
540         // make sure it uses the correct dictionary. Format is little-endian."
541         const int bytes_array[] = {0, 1, 2, 4};
542         const int bytes = bytes_array[dictionary_id_flag];
543 
544         header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
545     } else {
546         header->dictionary_id = 0;
547     }
548 
549     // decode frame content size if it exists
550     if (single_segment_flag || frame_content_size_flag) {
551         // "This is the original (uncompressed) size. This information is
552         // optional. The Field_Size is provided according to value of
553         // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
554         // present), 1, 2, 4 or 8 bytes. Format is little-endian."
555         //
556         // if frame_content_size_flag == 0 but single_segment_flag is set, we
557         // still have a 1 byte field
558         const int bytes_array[] = {1, 2, 4, 8};
559         const int bytes = bytes_array[frame_content_size_flag];
560 
561         header->frame_content_size = IO_read_bits(in, bytes * 8);
562         if (bytes == 2) {
563             // "When Field_Size is 2, the offset of 256 is added."
564             header->frame_content_size += 256;
565         }
566     } else {
567         header->frame_content_size = 0;
568     }
569 
570     if (single_segment_flag) {
571         // "The Window_Descriptor byte is optional. It is absent when
572         // Single_Segment_flag is set. In this case, the maximum back-reference
573         // distance is the content size itself, which can be any value from 1 to
574         // 2^64-1 bytes (16 EB)."
575         header->window_size = header->frame_content_size;
576     }
577 }
578 
579 /// A dictionary acts as initializing values for the frame context before
580 /// decompression, so we implement it by applying it's predetermined
581 /// tables and content to the context before beginning decompression
582 static void frame_context_apply_dict(frame_context_t *const ctx,
583                                      const dictionary_t *const dict) {
584     // If the content pointer is NULL then it must be an empty dict
585     if (!dict || !dict->content)
586         return;
587 
588     // If the requested dictionary_id is non-zero, the correct dictionary must
589     // be present
590     if (ctx->header.dictionary_id != 0 &&
591         ctx->header.dictionary_id != dict->dictionary_id) {
592         ERROR("Wrong dictionary provided");
593     }
594 
595     // Copy the dict content to the context for references during sequence
596     // execution
597     ctx->dict_content = dict->content;
598     ctx->dict_content_len = dict->content_size;
599 
600     // If it's a formatted dict copy the precomputed tables in so they can
601     // be used in the table repeat modes
602     if (dict->dictionary_id != 0) {
603         // Deep copy the entropy tables so they can be freed independently of
604         // the dictionary struct
605         HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
606         FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
607         FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
608         FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
609 
610         // Copy the repeated offsets
611         memcpy(ctx->previous_offsets, dict->previous_offsets,
612                sizeof(ctx->previous_offsets));
613     }
614 }
615 
616 /// Decompress the data from a frame block by block
617 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
618                             istream_t *const in) {
619     // "A frame encapsulates one or multiple blocks. Each block can be
620     // compressed or not, and has a guaranteed maximum content size, which
621     // depends on frame parameters. Unlike frames, each block depends on
622     // previous blocks for proper decoding. However, each block can be
623     // decompressed without waiting for its successor, allowing streaming
624     // operations."
625     int last_block = 0;
626     do {
627         // "Last_Block
628         //
629         // The lowest bit signals if this block is the last one. Frame ends
630         // right after this block.
631         //
632         // Block_Type and Block_Size
633         //
634         // The next 2 bits represent the Block_Type, while the remaining 21 bits
635         // represent the Block_Size. Format is little-endian."
636         last_block = (int)IO_read_bits(in, 1);
637         const int block_type = (int)IO_read_bits(in, 2);
638         const size_t block_len = IO_read_bits(in, 21);
639 
640         switch (block_type) {
641         case 0: {
642             // "Raw_Block - this is an uncompressed block. Block_Size is the
643             // number of bytes to read and copy."
644             const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
645             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
646 
647             // Copy the raw data into the output
648             memcpy(write_ptr, read_ptr, block_len);
649 
650             ctx->current_total_output += block_len;
651             break;
652         }
653         case 1: {
654             // "RLE_Block - this is a single byte, repeated N times. In which
655             // case, Block_Size is the size to regenerate, while the
656             // "compressed" block is just 1 byte (the byte to repeat)."
657             const u8 *const read_ptr = IO_get_read_ptr(in, 1);
658             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
659 
660             // Copy `block_len` copies of `read_ptr[0]` to the output
661             memset(write_ptr, read_ptr[0], block_len);
662 
663             ctx->current_total_output += block_len;
664             break;
665         }
666         case 2: {
667             // "Compressed_Block - this is a Zstandard compressed block,
668             // detailed in another section of this specification. Block_Size is
669             // the compressed size.
670 
671             // Create a sub-stream for the block
672             istream_t block_stream = IO_make_sub_istream(in, block_len);
673             decompress_block(ctx, out, &block_stream);
674             break;
675         }
676         case 3:
677             // "Reserved - this is not a block. This value cannot be used with
678             // current version of this specification."
679             CORRUPTION();
680             break;
681         default:
682             IMPOSSIBLE();
683         }
684     } while (!last_block);
685 
686     if (ctx->header.content_checksum_flag) {
687         // This program does not support checking the checksum, so skip over it
688         // if it's present
689         IO_advance_input(in, 4);
690     }
691 }
692 /******* END FRAME DECODING ***************************************************/
693 
694 /******* BLOCK DECOMPRESSION **************************************************/
695 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
696                              istream_t *const in) {
697     // "A compressed block consists of 2 sections :
698     //
699     // Literals_Section
700     // Sequences_Section"
701 
702 
703     // Part 1: decode the literals block
704     u8 *literals = NULL;
705     const size_t literals_size = decode_literals(ctx, in, &literals);
706 
707     // Part 2: decode the sequences block
708     sequence_command_t *sequences = NULL;
709     const size_t num_sequences =
710         decode_sequences(ctx, in, &sequences);
711 
712     // Part 3: combine literals and sequence commands to generate output
713     execute_sequences(ctx, out, literals, literals_size, sequences,
714                       num_sequences);
715     free(literals);
716     free(sequences);
717 }
718 /******* END BLOCK DECOMPRESSION **********************************************/
719 
720 /******* LITERALS DECODING ****************************************************/
721 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
722                                      const int block_type,
723                                      const int size_format);
724 static size_t decode_literals_compressed(frame_context_t *const ctx,
725                                          istream_t *const in,
726                                          u8 **const literals,
727                                          const int block_type,
728                                          const int size_format);
729 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
730 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
731                                     int *const num_symbs);
732 
733 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
734                               u8 **const literals) {
735     // "Literals can be stored uncompressed or compressed using Huffman prefix
736     // codes. When compressed, an optional tree description can be present,
737     // followed by 1 or 4 streams."
738     //
739     // "Literals_Section_Header
740     //
741     // Header is in charge of describing how literals are packed. It's a
742     // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
743     // little-endian convention."
744     //
745     // "Literals_Block_Type
746     //
747     // This field uses 2 lowest bits of first byte, describing 4 different block
748     // types"
749     //
750     // size_format takes between 1 and 2 bits
751     int block_type = (int)IO_read_bits(in, 2);
752     int size_format = (int)IO_read_bits(in, 2);
753 
754     if (block_type <= 1) {
755         // Raw or RLE literals block
756         return decode_literals_simple(in, literals, block_type,
757                                       size_format);
758     } else {
759         // Huffman compressed literals
760         return decode_literals_compressed(ctx, in, literals, block_type,
761                                           size_format);
762     }
763 }
764 
765 /// Decodes literals blocks in raw or RLE form
766 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
767                                      const int block_type,
768                                      const int size_format) {
769     size_t size;
770     switch (size_format) {
771     // These cases are in the form ?0
772     // In this case, the ? bit is actually part of the size field
773     case 0:
774     case 2:
775         // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
776         IO_rewind_bits(in, 1);
777         size = IO_read_bits(in, 5);
778         break;
779     case 1:
780         // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
781         size = IO_read_bits(in, 12);
782         break;
783     case 3:
784         // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
785         size = IO_read_bits(in, 20);
786         break;
787     default:
788         // Size format is in range 0-3
789         IMPOSSIBLE();
790     }
791 
792     if (size > MAX_LITERALS_SIZE) {
793         CORRUPTION();
794     }
795 
796     *literals = malloc(size);
797     if (!*literals) {
798         BAD_ALLOC();
799     }
800 
801     switch (block_type) {
802     case 0: {
803         // "Raw_Literals_Block - Literals are stored uncompressed."
804         const u8 *const read_ptr = IO_get_read_ptr(in, size);
805         memcpy(*literals, read_ptr, size);
806         break;
807     }
808     case 1: {
809         // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
810         const u8 *const read_ptr = IO_get_read_ptr(in, 1);
811         memset(*literals, read_ptr[0], size);
812         break;
813     }
814     default:
815         IMPOSSIBLE();
816     }
817 
818     return size;
819 }
820 
821 /// Decodes Huffman compressed literals
822 static size_t decode_literals_compressed(frame_context_t *const ctx,
823                                          istream_t *const in,
824                                          u8 **const literals,
825                                          const int block_type,
826                                          const int size_format) {
827     size_t regenerated_size, compressed_size;
828     // Only size_format=0 has 1 stream, so default to 4
829     int num_streams = 4;
830     switch (size_format) {
831     case 0:
832         // "A single stream. Both Compressed_Size and Regenerated_Size use 10
833         // bits (0-1023)."
834         num_streams = 1;
835     // Fall through as it has the same size format
836         /* fallthrough */
837     case 1:
838         // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
839         // (0-1023)."
840         regenerated_size = IO_read_bits(in, 10);
841         compressed_size = IO_read_bits(in, 10);
842         break;
843     case 2:
844         // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
845         // (0-16383)."
846         regenerated_size = IO_read_bits(in, 14);
847         compressed_size = IO_read_bits(in, 14);
848         break;
849     case 3:
850         // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
851         // (0-262143)."
852         regenerated_size = IO_read_bits(in, 18);
853         compressed_size = IO_read_bits(in, 18);
854         break;
855     default:
856         // Impossible
857         IMPOSSIBLE();
858     }
859     if (regenerated_size > MAX_LITERALS_SIZE) {
860         CORRUPTION();
861     }
862 
863     *literals = malloc(regenerated_size);
864     if (!*literals) {
865         BAD_ALLOC();
866     }
867 
868     ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
869     istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
870 
871     if (block_type == 2) {
872         // Decode the provided Huffman table
873         // "This section is only present when Literals_Block_Type type is
874         // Compressed_Literals_Block (2)."
875 
876         HUF_free_dtable(&ctx->literals_dtable);
877         decode_huf_table(&ctx->literals_dtable, &huf_stream);
878     } else {
879         // If the previous Huffman table is being repeated, ensure it exists
880         if (!ctx->literals_dtable.symbols) {
881             CORRUPTION();
882         }
883     }
884 
885     size_t symbols_decoded;
886     if (num_streams == 1) {
887         symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
888     } else {
889         symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
890     }
891 
892     if (symbols_decoded != regenerated_size) {
893         CORRUPTION();
894     }
895 
896     return regenerated_size;
897 }
898 
899 // Decode the Huffman table description
900 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
901     // "All literal values from zero (included) to last present one (excluded)
902     // are represented by Weight with values from 0 to Max_Number_of_Bits."
903 
904     // "This is a single byte value (0-255), which describes how to decode the list of weights."
905     const u8 header = IO_read_bits(in, 8);
906 
907     u8 weights[HUF_MAX_SYMBS];
908     memset(weights, 0, sizeof(weights));
909 
910     int num_symbs;
911 
912     if (header >= 128) {
913         // "This is a direct representation, where each Weight is written
914         // directly as a 4 bits field (0-15). The full representation occupies
915         // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
916         // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
917         // 127"
918         num_symbs = header - 127;
919         const size_t bytes = (num_symbs + 1) / 2;
920 
921         const u8 *const weight_src = IO_get_read_ptr(in, bytes);
922 
923         for (int i = 0; i < num_symbs; i++) {
924             // "They are encoded forward, 2
925             // weights to a byte with the first weight taking the top four bits
926             // and the second taking the bottom four (e.g. the following
927             // operations could be used to read the weights: Weight[0] =
928             // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
929             if (i % 2 == 0) {
930                 weights[i] = weight_src[i / 2] >> 4;
931             } else {
932                 weights[i] = weight_src[i / 2] & 0xf;
933             }
934         }
935     } else {
936         // The weights are FSE encoded, decode them before we can construct the
937         // table
938         istream_t fse_stream = IO_make_sub_istream(in, header);
939         ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
940         fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
941     }
942 
943     // Construct the table using the decoded weights
944     HUF_init_dtable_usingweights(dtable, weights, num_symbs);
945 }
946 
947 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
948                                     int *const num_symbs) {
949     const int MAX_ACCURACY_LOG = 7;
950 
951     FSE_dtable dtable;
952 
953     // "An FSE bitstream starts by a header, describing probabilities
954     // distribution. It will create a Decoding Table. For a list of Huffman
955     // weights, maximum accuracy is 7 bits."
956     FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
957 
958     // Decode the weights
959     *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
960 
961     FSE_free_dtable(&dtable);
962 }
963 /******* END LITERALS DECODING ************************************************/
964 
965 /******* SEQUENCE DECODING ****************************************************/
966 /// The combination of FSE states needed to decode sequences
967 typedef struct {
968     FSE_dtable ll_table;
969     FSE_dtable of_table;
970     FSE_dtable ml_table;
971 
972     u16 ll_state;
973     u16 of_state;
974     u16 ml_state;
975 } sequence_states_t;
976 
/// Different modes to signal to decode_seq_tables what to do
/// The numeric values matter: they are used as indices into per-part arrays
typedef enum {
    seq_literal_length = 0, // literal length codes
    seq_offset = 1,         // offset codes
    seq_match_length = 2    // match length codes
} seq_part_t;
983 
/// How the FSE table of one sequence part is obtained; the values match the
/// 2-bit mode fields of the symbol compression modes byte
typedef enum {
    seq_predefined = 0, // Predefined_Mode
    seq_rle = 1,        // RLE_Mode
    seq_fse = 2,        // FSE_Compressed_Mode
    seq_repeat = 3      // Repeat_Mode
} seq_mode_t;
990 
991 /// The predefined FSE distribution tables for `seq_predefined` mode
992 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
993     4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,  1,  2,  2,
994     2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
995 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
996     1, 1, 1, 1, 1, 1, 2, 2, 2, 1,  1,  1,  1,  1, 1,
997     1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
998 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
999     1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1,  1,  1,  1,  1,  1,  1, 1,
1000     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1,
1001     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
1002 
1003 /// The sequence decoding baseline and number of additional bits to read/add
1004 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
1005 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
1006     0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
1007     12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
1008     48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
1009 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
1010     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
1011     1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1012 
1013 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
1014     3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
1015     17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
1016     31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
1017     99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
1018 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
1019     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
1020     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1, 1,
1021     2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1022 
1023 /// Offset decoding is simpler so we just need a maximum code value
1024 static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
1025 
1026 static void decompress_sequences(frame_context_t *const ctx,
1027                                  istream_t *const in,
1028                                  sequence_command_t *const sequences,
1029                                  const size_t num_sequences);
1030 static sequence_command_t decode_sequence(sequence_states_t *const state,
1031                                           const u8 *const src,
1032                                           i64 *const offset);
1033 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1034                                const seq_part_t type, const seq_mode_t mode);
1035 
1036 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
1037                                sequence_command_t **const sequences) {
1038     // "A compressed block is a succession of sequences . A sequence is a
1039     // literal copy command, followed by a match copy command. A literal copy
1040     // command specifies a length. It is the number of bytes to be copied (or
1041     // extracted) from the literal section. A match copy command specifies an
1042     // offset and a length. The offset gives the position to copy from, which
1043     // can be within a previous block."
1044 
1045     size_t num_sequences;
1046 
1047     // "Number_of_Sequences
1048     //
1049     // This is a variable size field using between 1 and 3 bytes. Let's call its
1050     // first byte byte0."
1051     u8 header = IO_read_bits(in, 8);
1052     if (header == 0) {
1053         // "There are no sequences. The sequence section stops there.
1054         // Regenerated content is defined entirely by literals section."
1055         *sequences = NULL;
1056         return 0;
1057     } else if (header < 128) {
1058         // "Number_of_Sequences = byte0 . Uses 1 byte."
1059         num_sequences = header;
1060     } else if (header < 255) {
1061         // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
1062         num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
1063     } else {
1064         // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
1065         num_sequences = IO_read_bits(in, 16) + 0x7F00;
1066     }
1067 
1068     *sequences = malloc(num_sequences * sizeof(sequence_command_t));
1069     if (!*sequences) {
1070         BAD_ALLOC();
1071     }
1072 
1073     decompress_sequences(ctx, in, *sequences, num_sequences);
1074     return num_sequences;
1075 }
1076 
1077 /// Decompress the FSE encoded sequence commands
1078 static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1079                                  sequence_command_t *const sequences,
1080                                  const size_t num_sequences) {
1081     // "The Sequences_Section regroup all symbols required to decode commands.
1082     // There are 3 symbol types : literals lengths, offsets and match lengths.
1083     // They are encoded together, interleaved, in a single bitstream."
1084 
1085     // "Symbol compression modes
1086     //
1087     // This is a single byte, defining the compression mode of each symbol
1088     // type."
1089     //
1090     // Bit number : Field name
1091     // 7-6        : Literals_Lengths_Mode
1092     // 5-4        : Offsets_Mode
1093     // 3-2        : Match_Lengths_Mode
1094     // 1-0        : Reserved
1095     u8 compression_modes = IO_read_bits(in, 8);
1096 
1097     if ((compression_modes & 3) != 0) {
1098         // Reserved bits set
1099         CORRUPTION();
1100     }
1101 
1102     // "Following the header, up to 3 distribution tables can be described. When
1103     // present, they are in this order :
1104     //
1105     // Literals lengths
1106     // Offsets
1107     // Match Lengths"
1108     // Update the tables we have stored in the context
1109     decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1110                      (compression_modes >> 6) & 3);
1111 
1112     decode_seq_table(&ctx->of_dtable, in, seq_offset,
1113                      (compression_modes >> 4) & 3);
1114 
1115     decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1116                      (compression_modes >> 2) & 3);
1117 
1118 
1119     sequence_states_t states;
1120 
1121     // Initialize the decoding tables
1122     {
1123         states.ll_table = ctx->ll_dtable;
1124         states.of_table = ctx->of_dtable;
1125         states.ml_table = ctx->ml_dtable;
1126     }
1127 
1128     const size_t len = IO_istream_len(in);
1129     const u8 *const src = IO_get_read_ptr(in, len);
1130 
1131     // "After writing the last bit containing information, the compressor writes
1132     // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1133     const int padding = 8 - highest_set_bit(src[len - 1]);
1134     // The offset starts at the end because FSE streams are read backwards
1135     i64 bit_offset = (i64)(len * 8 - (size_t)padding);
1136 
1137     // "The bitstream starts with initial state values, each using the required
1138     // number of bits in their respective accuracy, decoded previously from
1139     // their normalized distribution.
1140     //
1141     // It starts by Literals_Length_State, followed by Offset_State, and finally
1142     // Match_Length_State."
1143     FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1144     FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1145     FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1146 
1147     for (size_t i = 0; i < num_sequences; i++) {
1148         // Decode sequences one by one
1149         sequences[i] = decode_sequence(&states, src, &bit_offset);
1150     }
1151 
1152     if (bit_offset != 0) {
1153         CORRUPTION();
1154     }
1155 }
1156 
1157 // Decode a single sequence and update the state
1158 static sequence_command_t decode_sequence(sequence_states_t *const states,
1159                                           const u8 *const src,
1160                                           i64 *const offset) {
1161     // "Each symbol is a code in its own context, which specifies Baseline and
1162     // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1163     // additional bits in the same bitstream."
1164 
1165     // Decode symbols, but don't update states
1166     const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1167     const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1168     const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1169 
1170     // Offset doesn't need a max value as it's not decoded using a table
1171     if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1172         ml_code > SEQ_MAX_CODES[seq_match_length]) {
1173         CORRUPTION();
1174     }
1175 
1176     // Read the interleaved bits
1177     sequence_command_t seq;
1178     // "Decoding starts by reading the Number_of_Bits required to decode Offset.
1179     // It then does the same for Match_Length, and then for Literals_Length."
1180     seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1181 
1182     seq.match_length =
1183         SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1184         STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1185 
1186     seq.literal_length =
1187         SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1188         STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1189 
1190     // "If it is not the last sequence in the block, the next operation is to
1191     // update states. Using the rules pre-calculated in the decoding tables,
1192     // Literals_Length_State is updated, followed by Match_Length_State, and
1193     // then Offset_State."
1194     // If the stream is complete don't read bits to update state
1195     if (*offset != 0) {
1196         FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1197         FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1198         FSE_update_state(&states->of_table, &states->of_state, src, offset);
1199     }
1200 
1201     return seq;
1202 }
1203 
1204 /// Given a sequence part and table mode, decode the FSE distribution
1205 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
1206 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1207                              const seq_part_t type, const seq_mode_t mode) {
1208     // Constant arrays indexed by seq_part_t
1209     const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1210                                                 SEQ_OFFSET_DEFAULT_DIST,
1211                                                 SEQ_MATCH_LENGTH_DEFAULT_DIST};
1212     const size_t default_distribution_lengths[] = {36, 29, 53};
1213     const size_t default_distribution_accuracies[] = {6, 5, 6};
1214 
1215     const size_t max_accuracies[] = {9, 8, 9};
1216 
1217     if (mode != seq_repeat) {
1218         // Free old one before overwriting
1219         FSE_free_dtable(table);
1220     }
1221 
1222     switch (mode) {
1223     case seq_predefined: {
1224         // "Predefined_Mode : uses a predefined distribution table."
1225         const i16 *distribution = default_distributions[type];
1226         const size_t symbs = default_distribution_lengths[type];
1227         const size_t accuracy_log = default_distribution_accuracies[type];
1228 
1229         FSE_init_dtable(table, distribution, symbs, accuracy_log);
1230         break;
1231     }
1232     case seq_rle: {
1233         // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1234         const u8 symb = IO_get_read_ptr(in, 1)[0];
1235         FSE_init_dtable_rle(table, symb);
1236         break;
1237     }
1238     case seq_fse: {
1239         // "FSE_Compressed_Mode : standard FSE compression. A distribution table
1240         // will be present "
1241         FSE_decode_header(table, in, max_accuracies[type]);
1242         break;
1243     }
1244     case seq_repeat:
1245         // "Repeat_Mode : re-use distribution table from previous compressed
1246         // block."
1247         // Nothing to do here, table will be unchanged
1248         if (!table->symbols) {
1249             // This mode is invalid if we don't already have a table
1250             CORRUPTION();
1251         }
1252         break;
1253     default:
1254         // Impossible, as mode is from 0-3
1255         IMPOSSIBLE();
1256         break;
1257     }
1258 
1259 }
1260 /******* END SEQUENCE DECODING ************************************************/
1261 
1262 /******* SEQUENCE EXECUTION ***************************************************/
1263 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
1264                               const u8 *const literals,
1265                               const size_t literals_len,
1266                               const sequence_command_t *const sequences,
1267                               const size_t num_sequences) {
1268     istream_t litstream = IO_make_istream(literals, literals_len);
1269 
1270     u64 *const offset_hist = ctx->previous_offsets;
1271     size_t total_output = ctx->current_total_output;
1272 
1273     for (size_t i = 0; i < num_sequences; i++) {
1274         const sequence_command_t seq = sequences[i];
1275         {
1276             const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
1277             total_output += literals_size;
1278         }
1279 
1280         size_t const offset = compute_offset(seq, offset_hist);
1281 
1282         size_t const match_length = seq.match_length;
1283 
1284         execute_match_copy(ctx, offset, match_length, total_output, out);
1285 
1286         total_output += match_length;
1287     }
1288 
1289     // Copy any leftover literals
1290     {
1291         size_t len = IO_istream_len(&litstream);
1292         copy_literals(len, &litstream, out);
1293         total_output += len;
1294     }
1295 
1296     ctx->current_total_output = total_output;
1297 }
1298 
1299 static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1300                          ostream_t *const out) {
1301     // If the sequence asks for more literals than are left, the
1302     // sequence must be corrupted
1303     if (literal_length > IO_istream_len(litstream)) {
1304         CORRUPTION();
1305     }
1306 
1307     u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1308     const u8 *const read_ptr =
1309          IO_get_read_ptr(litstream, literal_length);
1310     // Copy literals to output
1311     memcpy(write_ptr, read_ptr, literal_length);
1312 
1313     return literal_length;
1314 }
1315 
1316 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
1317     size_t offset;
1318     // Offsets are special, we need to handle the repeat offsets
1319     if (seq.offset <= 3) {
1320         // "The first 3 values define a repeated offset and we will call
1321         // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
1322         // They are sorted in recency order, with Repeated_Offset1 meaning
1323         // 'most recent one'".
1324 
1325         // Use 0 indexing for the array
1326         u32 idx = seq.offset - 1;
1327         if (seq.literal_length == 0) {
1328             // "There is an exception though, when current sequence's
1329             // literals length is 0. In this case, repeated offsets are
1330             // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
1331             // Repeated_Offset2 becomes Repeated_Offset3, and
1332             // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
1333             idx++;
1334         }
1335 
1336         if (idx == 0) {
1337             offset = offset_hist[0];
1338         } else {
1339             // If idx == 3 then literal length was 0 and the offset was 3,
1340             // as per the exception listed above
1341             offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
1342 
1343             // If idx == 1 we don't need to modify offset_hist[2], since
1344             // we're using the second-most recent code
1345             if (idx > 1) {
1346                 offset_hist[2] = offset_hist[1];
1347             }
1348             offset_hist[1] = offset_hist[0];
1349             offset_hist[0] = offset;
1350         }
1351     } else {
1352         // When it's not a repeat offset:
1353         // "if (Offset_Value > 3) offset = Offset_Value - 3;"
1354         offset = seq.offset - 3;
1355 
1356         // Shift back history
1357         offset_hist[2] = offset_hist[1];
1358         offset_hist[1] = offset_hist[0];
1359         offset_hist[0] = offset;
1360     }
1361     return offset;
1362 }
1363 
1364 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
1365                               size_t match_length, size_t total_output,
1366                               ostream_t *const out) {
1367     u8 *write_ptr = IO_get_write_ptr(out, match_length);
1368     if (total_output <= ctx->header.window_size) {
1369         // In this case offset might go back into the dictionary
1370         if (offset > total_output + ctx->dict_content_len) {
1371             // The offset goes beyond even the dictionary
1372             CORRUPTION();
1373         }
1374 
1375         if (offset > total_output) {
1376             // "The rest of the dictionary is its content. The content act
1377             // as a "past" in front of data to compress or decompress, so it
1378             // can be referenced in sequence commands."
1379             const size_t dict_copy =
1380                 MIN(offset - total_output, match_length);
1381             const size_t dict_offset =
1382                 ctx->dict_content_len - (offset - total_output);
1383 
1384             memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
1385             write_ptr += dict_copy;
1386             match_length -= dict_copy;
1387         }
1388     } else if (offset > ctx->header.window_size) {
1389         CORRUPTION();
1390     }
1391 
1392     // We must copy byte by byte because the match length might be larger
1393     // than the offset
1394     // ex: if the output so far was "abc", a command with offset=3 and
1395     // match_length=6 would produce "abcabcabc" as the new output
1396     for (size_t j = 0; j < match_length; j++) {
1397         *write_ptr = *(write_ptr - offset);
1398         write_ptr++;
1399     }
1400 }
1401 /******* END SEQUENCE EXECUTION ***********************************************/
1402 
1403 /******* OUTPUT SIZE COUNTING *************************************************/
1404 /// Get the decompressed size of an input stream so memory can be allocated in
1405 /// advance.
1406 /// This implementation assumes `src` points to a single ZSTD-compressed frame
1407 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1408     istream_t in = IO_make_istream(src, src_len);
1409 
1410     // get decompressed size from ZSTD frame header
1411     {
1412         const u32 magic_number = (u32)IO_read_bits(&in, 32);
1413 
1414         if (magic_number == 0xFD2FB528U) {
1415             // ZSTD frame
1416             frame_header_t header;
1417             parse_frame_header(&header, &in);
1418 
1419             if (header.frame_content_size == 0 && !header.single_segment_flag) {
1420                 // Content size not provided, we can't tell
1421                 return (size_t)-1;
1422             }
1423 
1424             return header.frame_content_size;
1425         } else {
1426             // not a real frame or skippable frame
1427             ERROR("ZSTD frame magic number did not match");
1428         }
1429     }
1430 }
1431 /******* END OUTPUT SIZE COUNTING *********************************************/
1432 
1433 /******* DICTIONARY PARSING ***************************************************/
// Error helpers used while parsing a dictionary. Neither expansion ends in a
// semicolon, so both behave like a single statement when written as
// `NAME();`, including in front of an `else`.
#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src")
1436 
1437 dictionary_t* create_dictionary() {
1438     dictionary_t* dict = calloc(1, sizeof(dictionary_t));
1439     if (!dict) {
1440         BAD_ALLOC();
1441     }
1442     return dict;
1443 }
1444 
1445 static void init_dictionary_content(dictionary_t *const dict,
1446                                     istream_t *const in);
1447 
1448 void parse_dictionary(dictionary_t *const dict, const void *src,
1449                              size_t src_len) {
1450     const u8 *byte_src = (const u8 *)src;
1451     memset(dict, 0, sizeof(dictionary_t));
1452     if (src == NULL) { /* cannot initialize dictionary with null src */
1453         NULL_SRC();
1454     }
1455     if (src_len < 8) {
1456         DICT_SIZE_ERROR();
1457     }
1458 
1459     istream_t in = IO_make_istream(byte_src, src_len);
1460 
1461     const u32 magic_number = IO_read_bits(&in, 32);
1462     if (magic_number != 0xEC30A437) {
1463         // raw content dict
1464         IO_rewind_bits(&in, 32);
1465         init_dictionary_content(dict, &in);
1466         return;
1467     }
1468 
1469     dict->dictionary_id = IO_read_bits(&in, 32);
1470 
1471     // "Entropy_Tables : following the same format as the tables in compressed
1472     // blocks. They are stored in following order : Huffman tables for literals,
1473     // FSE table for offsets, FSE table for match lengths, and FSE table for
1474     // literals lengths. It's finally followed by 3 offset values, populating
1475     // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
1476     // little-endian each, for a total of 12 bytes. Each recent offset must have
1477     // a value < dictionary size."
1478     decode_huf_table(&dict->literals_dtable, &in);
1479     decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
1480     decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
1481     decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
1482 
1483     // Read in the previous offset history
1484     dict->previous_offsets[0] = IO_read_bits(&in, 32);
1485     dict->previous_offsets[1] = IO_read_bits(&in, 32);
1486     dict->previous_offsets[2] = IO_read_bits(&in, 32);
1487 
1488     // Ensure the provided offsets aren't too large
1489     // "Each recent offset must have a value < dictionary size."
1490     for (int i = 0; i < 3; i++) {
1491         if (dict->previous_offsets[i] > src_len) {
1492             ERROR("Dictionary corrupted");
1493         }
1494     }
1495 
1496     // "Content : The rest of the dictionary is its content. The content act as
1497     // a "past" in front of data to compress or decompress, so it can be
1498     // referenced in sequence commands."
1499     init_dictionary_content(dict, &in);
1500 }
1501 
1502 static void init_dictionary_content(dictionary_t *const dict,
1503                                     istream_t *const in) {
1504     // Copy in the content
1505     dict->content_size = IO_istream_len(in);
1506     dict->content = malloc(dict->content_size);
1507     if (!dict->content) {
1508         BAD_ALLOC();
1509     }
1510 
1511     const u8 *const content = IO_get_read_ptr(in, dict->content_size);
1512 
1513     memcpy(dict->content, content, dict->content_size);
1514 }
1515 
1516 /// Free an allocated dictionary
1517 void free_dictionary(dictionary_t *const dict) {
1518     HUF_free_dtable(&dict->literals_dtable);
1519     FSE_free_dtable(&dict->ll_dtable);
1520     FSE_free_dtable(&dict->of_dtable);
1521     FSE_free_dtable(&dict->ml_dtable);
1522 
1523     free(dict->content);
1524 
1525     memset(dict, 0, sizeof(dictionary_t));
1526 
1527     free(dict);
1528 }
1529 /******* END DICTIONARY PARSING ***********************************************/
1530 
1531 /******* IO STREAM OPERATIONS *************************************************/
1532 
1533 /// Reads `num` bits from a bitstream, and updates the internal offset
1534 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1535     if (num_bits > 64 || num_bits <= 0) {
1536         ERROR("Attempt to read an invalid number of bits");
1537     }
1538 
1539     const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1540     const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1541     if (bytes > in->len) {
1542         INP_SIZE();
1543     }
1544 
1545     const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1546 
1547     in->bit_offset = (num_bits + in->bit_offset) % 8;
1548     in->ptr += full_bytes;
1549     in->len -= full_bytes;
1550 
1551     return result;
1552 }
1553 
1554 /// If a non-zero number of bits have been read from the current byte, advance
1555 /// the offset to the next byte
1556 static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1557     if (num_bits < 0) {
1558         ERROR("Attempting to rewind stream by a negative number of bits");
1559     }
1560 
1561     // move the offset back by `num_bits` bits
1562     const int new_offset = in->bit_offset - num_bits;
1563     // determine the number of whole bytes we have to rewind, rounding up to an
1564     // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1565     const i64 bytes = -(new_offset - 7) / 8;
1566 
1567     in->ptr -= bytes;
1568     in->len += bytes;
1569     // make sure the resulting `bit_offset` is positive, as mod in C does not
1570     // convert numbers from negative to positive (e.g. -22 % 8 == -6)
1571     in->bit_offset = ((new_offset % 8) + 8) % 8;
1572 }
1573 
1574 /// If the remaining bits in a byte will be unused, advance to the end of the
1575 /// byte
1576 static inline void IO_align_stream(istream_t *const in) {
1577     if (in->bit_offset != 0) {
1578         if (in->len == 0) {
1579             INP_SIZE();
1580         }
1581         in->ptr++;
1582         in->len--;
1583         in->bit_offset = 0;
1584     }
1585 }
1586 
1587 /// Write the given byte into the output stream
1588 static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1589     if (out->len == 0) {
1590         OUT_SIZE();
1591     }
1592 
1593     out->ptr[0] = symb;
1594     out->ptr++;
1595     out->len--;
1596 }
1597 
/// Returns the number of bytes left to be read in this stream.  The stream must
/// be byte aligned.
/// Note: alignment is not verified here; the stored length is returned as-is.
static inline size_t IO_istream_len(const istream_t *const in) {
    return in->len;
}
1603 
1604 /// Returns a pointer where `len` bytes can be read, and advances the internal
1605 /// state.  The stream must be byte aligned.
1606 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1607     if (len > in->len) {
1608         INP_SIZE();
1609     }
1610     if (in->bit_offset != 0) {
1611         ERROR("Attempting to operate on a non-byte aligned stream");
1612     }
1613     const u8 *const ptr = in->ptr;
1614     in->ptr += len;
1615     in->len -= len;
1616 
1617     return ptr;
1618 }
1619 /// Returns a pointer to write `len` bytes to, and advances the internal state
1620 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1621     if (len > out->len) {
1622         OUT_SIZE();
1623     }
1624     u8 *const ptr = out->ptr;
1625     out->ptr += len;
1626     out->len -= len;
1627 
1628     return ptr;
1629 }
1630 
1631 /// Advance the inner state by `len` bytes
1632 static inline void IO_advance_input(istream_t *const in, size_t len) {
1633     if (len > in->len) {
1634          INP_SIZE();
1635     }
1636     if (in->bit_offset != 0) {
1637         ERROR("Attempting to operate on a non-byte aligned stream");
1638     }
1639 
1640     in->ptr += len;
1641     in->len -= len;
1642 }
1643 
1644 /// Returns an `ostream_t` constructed from the given pointer and length
1645 static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1646     return (ostream_t) { out, len };
1647 }
1648 
1649 /// Returns an `istream_t` constructed from the given pointer and length
1650 static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1651     return (istream_t) { in, len, 0 };
1652 }
1653 
1654 /// Returns an `istream_t` with the same base as `in`, and length `len`
1655 /// Then, advance `in` to account for the consumed bytes
1656 /// `in` must be byte aligned
1657 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1658     // Consume `len` bytes of the parent stream
1659     const u8 *const ptr = IO_get_read_ptr(in, len);
1660 
1661     // Make a substream using the pointer to those `len` bytes
1662     return IO_make_istream(ptr, len);
1663 }
1664 /******* END IO STREAM OPERATIONS *********************************************/
1665 
1666 /******* BITSTREAM OPERATIONS *************************************************/
1667 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
1668 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1669                                const size_t offset) {
1670     if (num_bits > 64) {
1671         ERROR("Attempt to read an invalid number of bits");
1672     }
1673 
1674     // Skip over bytes that aren't in range
1675     src += offset / 8;
1676     size_t bit_offset = offset % 8;
1677     u64 res = 0;
1678 
1679     int shift = 0;
1680     int left = num_bits;
1681     while (left > 0) {
1682         u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1683         // Read the next byte, shift it to account for the offset, and then mask
1684         // out the top part if we don't need all the bits
1685         res += (((u64)*src++ >> bit_offset) & mask) << shift;
1686         shift += 8 - bit_offset;
1687         left -= 8 - bit_offset;
1688         bit_offset = 0;
1689     }
1690 
1691     return res;
1692 }
1693 
1694 /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
1695 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1696 /// `src + offset`.  If the offset becomes negative, the extra bits at the
1697 /// bottom are filled in with `0` bits instead of reading from before `src`.
1698 static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1699                                    i64 *const offset) {
1700     *offset = *offset - bits;
1701     size_t actual_off = *offset;
1702     size_t actual_bits = bits;
1703     // Don't actually read bits from before the start of src, so if `*offset <
1704     // 0` fix actual_off and actual_bits to reflect the quantity to read
1705     if (*offset < 0) {
1706         actual_bits += *offset;
1707         actual_off = 0;
1708     }
1709     u64 res = read_bits_LE(src, actual_bits, actual_off);
1710 
1711     if (*offset < 0) {
1712         // Fill in the bottom "overflowed" bits with 0's
1713         res = -*offset >= 64 ? 0 : (res << -*offset);
1714     }
1715     return res;
1716 }
1717 /******* END BITSTREAM OPERATIONS *********************************************/
1718 
1719 /******* BIT COUNTING OPERATIONS **********************************************/
1720 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1721 /// `num`, or `-1` if `num == 0`.
1722 static inline int highest_set_bit(const u64 num) {
1723     for (int i = 63; i >= 0; i--) {
1724         if (((u64)1 << i) <= num) {
1725             return i;
1726         }
1727     }
1728     return -1;
1729 }
1730 /******* END BIT COUNTING OPERATIONS ******************************************/
1731 
1732 /******* HUFFMAN PRIMITIVES ***************************************************/
1733 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1734                                    u16 *const state, const u8 *const src,
1735                                    i64 *const offset) {
1736     // Look up the symbol and number of bits to read
1737     const u8 symb = dtable->symbols[*state];
1738     const u8 bits = dtable->num_bits[*state];
1739     const u16 rest = STREAM_read_bits(src, bits, offset);
1740     // Shift `bits` bits out of the state, keeping the low order bits that
1741     // weren't necessary to determine this symbol.  Then add in the new bits
1742     // read from the stream.
1743     *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
1744 
1745     return symb;
1746 }
1747 
1748 static inline void HUF_init_state(const HUF_dtable *const dtable,
1749                                   u16 *const state, const u8 *const src,
1750                                   i64 *const offset) {
1751     // Read in a full `dtable->max_bits` bits to initialize the state
1752     const u8 bits = dtable->max_bits;
1753     *state = STREAM_read_bits(src, bits, offset);
1754 }
1755 
1756 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1757                                      ostream_t *const out,
1758                                      istream_t *const in) {
1759     const size_t len = IO_istream_len(in);
1760     if (len == 0) {
1761         INP_SIZE();
1762     }
1763     const u8 *const src = IO_get_read_ptr(in, len);
1764 
1765     // "Each bitstream must be read backward, that is starting from the end down
1766     // to the beginning. Therefore it's necessary to know the size of each
1767     // bitstream.
1768     //
1769     // It's also necessary to know exactly which bit is the latest. This is
1770     // detected by a final bit flag : the highest bit of latest byte is a
1771     // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
1772     // final-bit-flag itself is not part of the useful bitstream. Hence, the
1773     // last byte contains between 0 and 7 useful bits."
1774     const int padding = 8 - highest_set_bit(src[len - 1]);
1775 
1776     // Offset starts at the end because HUF streams are read backwards
1777     i64 bit_offset = len * 8 - padding;
1778     u16 state;
1779 
1780     HUF_init_state(dtable, &state, src, &bit_offset);
1781 
1782     size_t symbols_written = 0;
1783     while (bit_offset > -dtable->max_bits) {
1784         // Iterate over the stream, decoding one symbol at a time
1785         IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
1786         symbols_written++;
1787     }
1788     // "The process continues up to reading the required number of symbols per
1789     // stream. If a bitstream is not entirely and exactly consumed, hence
1790     // reaching exactly its beginning position with all bits consumed, the
1791     // decoding process is considered faulty."
1792 
1793     // When all symbols have been decoded, the final state value shouldn't have
1794     // any data from the stream, so it should have "read" dtable->max_bits from
1795     // before the start of `src`
1796     // Therefore `offset`, the edge to start reading new bits at, should be
1797     // dtable->max_bits before the start of the stream
1798     if (bit_offset != -dtable->max_bits) {
1799         CORRUPTION();
1800     }
1801 
1802     return symbols_written;
1803 }
1804 
1805 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1806                                      ostream_t *const out, istream_t *const in) {
1807     // "Compressed size is provided explicitly : in the 4-streams variant,
1808     // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
1809     // value represents the compressed size of one stream, in order. The last
1810     // stream size is deducted from total compressed size and from previously
1811     // decoded stream sizes"
1812     const size_t csize1 = IO_read_bits(in, 16);
1813     const size_t csize2 = IO_read_bits(in, 16);
1814     const size_t csize3 = IO_read_bits(in, 16);
1815 
1816     istream_t in1 = IO_make_sub_istream(in, csize1);
1817     istream_t in2 = IO_make_sub_istream(in, csize2);
1818     istream_t in3 = IO_make_sub_istream(in, csize3);
1819     istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
1820 
1821     size_t total_output = 0;
1822     // Decode each stream independently for simplicity
1823     // If we wanted to we could decode all 4 at the same time for speed,
1824     // utilizing more execution units
1825     total_output += HUF_decompress_1stream(dtable, out, &in1);
1826     total_output += HUF_decompress_1stream(dtable, out, &in2);
1827     total_output += HUF_decompress_1stream(dtable, out, &in3);
1828     total_output += HUF_decompress_1stream(dtable, out, &in4);
1829 
1830     return total_output;
1831 }
1832 
1833 /// Initializes a Huffman table using canonical Huffman codes
1834 /// For more explanation on canonical Huffman codes see
1835 /// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1836 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get
1837 /// earlier codes)
1838 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1839                             const int num_symbs) {
1840     memset(table, 0, sizeof(HUF_dtable));
1841     if (num_symbs > HUF_MAX_SYMBS) {
1842         ERROR("Too many symbols for Huffman");
1843     }
1844 
1845     u8 max_bits = 0;
1846     u16 rank_count[HUF_MAX_BITS + 1];
1847     memset(rank_count, 0, sizeof(rank_count));
1848 
1849     // Count the number of symbols for each number of bits, and determine the
1850     // depth of the tree
1851     for (int i = 0; i < num_symbs; i++) {
1852         if (bits[i] > HUF_MAX_BITS) {
1853             ERROR("Huffman table depth too large");
1854         }
1855         max_bits = MAX(max_bits, bits[i]);
1856         rank_count[bits[i]]++;
1857     }
1858 
1859     const size_t table_size = 1 << max_bits;
1860     table->max_bits = max_bits;
1861     table->symbols = malloc(table_size);
1862     table->num_bits = malloc(table_size);
1863 
1864     if (!table->symbols || !table->num_bits) {
1865         free(table->symbols);
1866         free(table->num_bits);
1867         BAD_ALLOC();
1868     }
1869 
1870     // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1871     // order. Symbols with a Weight of zero are removed. Then, starting from
1872     // lowest weight, prefix codes are distributed in order."
1873 
1874     u32 rank_idx[HUF_MAX_BITS + 1];
1875     // Initialize the starting codes for each rank (number of bits)
1876     rank_idx[max_bits] = 0;
1877     for (int i = max_bits; i >= 1; i--) {
1878         rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1879         // The entire range takes the same number of bits so we can memset it
1880         memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
1881     }
1882 
1883     if (rank_idx[0] != table_size) {
1884         CORRUPTION();
1885     }
1886 
1887     // Allocate codes and fill in the table
1888     for (int i = 0; i < num_symbs; i++) {
1889         if (bits[i] != 0) {
1890             // Allocate a code for this symbol and set its range in the table
1891             const u16 code = rank_idx[bits[i]];
1892             // Since the code doesn't care about the bottom `max_bits - bits[i]`
1893             // bits of state, it gets a range that spans all possible values of
1894             // the lower bits
1895             const u16 len = 1 << (max_bits - bits[i]);
1896             memset(&table->symbols[code], i, len);
1897             rank_idx[bits[i]] += len;
1898         }
1899     }
1900 }
1901 
1902 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1903                                          const u8 *const weights,
1904                                          const int num_symbs) {
1905     // +1 because the last weight is not transmitted in the header
1906     if (num_symbs + 1 > HUF_MAX_SYMBS) {
1907         ERROR("Too many symbols for Huffman");
1908     }
1909 
1910     u8 bits[HUF_MAX_SYMBS];
1911 
1912     u64 weight_sum = 0;
1913     for (int i = 0; i < num_symbs; i++) {
1914         // Weights are in the same range as bit count
1915         if (weights[i] > HUF_MAX_BITS) {
1916             CORRUPTION();
1917         }
1918         weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
1919     }
1920 
1921     // Find the first power of 2 larger than the sum
1922     const int max_bits = highest_set_bit(weight_sum) + 1;
1923     const u64 left_over = ((u64)1 << max_bits) - weight_sum;
1924     // If the left over isn't a power of 2, the weights are invalid
1925     if (left_over & (left_over - 1)) {
1926         CORRUPTION();
1927     }
1928 
1929     // left_over is used to find the last weight as it's not transmitted
1930     // by inverting 2^(weight - 1) we can determine the value of last_weight
1931     const int last_weight = highest_set_bit(left_over) + 1;
1932 
1933     for (int i = 0; i < num_symbs; i++) {
1934         // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
1935         bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
1936     }
1937     bits[num_symbs] =
1938         max_bits + 1 - last_weight; // Last weight is always non-zero
1939 
1940     HUF_init_dtable(table, bits, num_symbs + 1);
1941 }
1942 
/// Releases the `symbols` and `num_bits` arrays of `dtable` and then zeroes
/// the whole structure so no stale pointers are left behind.  The structure
/// itself is not freed.
static void HUF_free_dtable(HUF_dtable *const dtable) {
    free(dtable->symbols);
    free(dtable->num_bits);
    memset(dtable, 0, sizeof(HUF_dtable));
}
1948 
1949 static void HUF_copy_dtable(HUF_dtable *const dst,
1950                             const HUF_dtable *const src) {
1951     if (src->max_bits == 0) {
1952         memset(dst, 0, sizeof(HUF_dtable));
1953         return;
1954     }
1955 
1956     const size_t size = (size_t)1 << src->max_bits;
1957     dst->max_bits = src->max_bits;
1958 
1959     dst->symbols = malloc(size);
1960     dst->num_bits = malloc(size);
1961     if (!dst->symbols || !dst->num_bits) {
1962         BAD_ALLOC();
1963     }
1964 
1965     memcpy(dst->symbols, src->symbols, size);
1966     memcpy(dst->num_bits, src->num_bits, size);
1967 }
1968 /******* END HUFFMAN PRIMITIVES ***********************************************/
1969 
1970 /******* FSE PRIMITIVES *******************************************************/
1971 /// For more description of FSE see
1972 /// https://github.com/Cyan4973/FiniteStateEntropy/
1973 
/// Allow a symbol to be decoded without updating state
/// Returns the symbol stored in the decoding table at index `state`; no input
/// is read and nothing is modified.
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
                                 const u16 state) {
    return dtable->symbols[state];
}
1979 
1980 /// Consumes bits from the input and uses the current state to determine the
1981 /// next state
1982 static inline void FSE_update_state(const FSE_dtable *const dtable,
1983                                     u16 *const state, const u8 *const src,
1984                                     i64 *const offset) {
1985     const u8 bits = dtable->num_bits[*state];
1986     const u16 rest = STREAM_read_bits(src, bits, offset);
1987     *state = dtable->new_state_base[*state] + rest;
1988 }
1989 
1990 /// Decodes a single FSE symbol and updates the offset
1991 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
1992                                    u16 *const state, const u8 *const src,
1993                                    i64 *const offset) {
1994     const u8 symb = FSE_peek_symbol(dtable, *state);
1995     FSE_update_state(dtable, state, src, offset);
1996     return symb;
1997 }
1998 
1999 static inline void FSE_init_state(const FSE_dtable *const dtable,
2000                                   u16 *const state, const u8 *const src,
2001                                   i64 *const offset) {
2002     // Read in a full `accuracy_log` bits to initialize the state
2003     const u8 bits = dtable->accuracy_log;
2004     *state = STREAM_read_bits(src, bits, offset);
2005 }
2006 
2007 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2008                                           ostream_t *const out,
2009                                           istream_t *const in) {
2010     const size_t len = IO_istream_len(in);
2011     if (len == 0) {
2012         INP_SIZE();
2013     }
2014     const u8 *const src = IO_get_read_ptr(in, len);
2015 
2016     // "Each bitstream must be read backward, that is starting from the end down
2017     // to the beginning. Therefore it's necessary to know the size of each
2018     // bitstream.
2019     //
2020     // It's also necessary to know exactly which bit is the latest. This is
2021     // detected by a final bit flag : the highest bit of latest byte is a
2022     // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
2023     // final-bit-flag itself is not part of the useful bitstream. Hence, the
2024     // last byte contains between 0 and 7 useful bits."
2025     const int padding = 8 - highest_set_bit(src[len - 1]);
2026     i64 offset = len * 8 - padding;
2027 
2028     u16 state1, state2;
2029     // "The first state (State1) encodes the even indexed symbols, and the
2030     // second (State2) encodes the odd indexes. State1 is initialized first, and
2031     // then State2, and they take turns decoding a single symbol and updating
2032     // their state."
2033     FSE_init_state(dtable, &state1, src, &offset);
2034     FSE_init_state(dtable, &state2, src, &offset);
2035 
2036     // Decode until we overflow the stream
2037     // Since we decode in reverse order, overflowing the stream is offset going
2038     // negative
2039     size_t symbols_written = 0;
2040     while (1) {
2041         // "The number of symbols to decode is determined by tracking bitStream
2042         // overflow condition: If updating state after decoding a symbol would
2043         // require more bits than remain in the stream, it is assumed the extra
2044         // bits are 0. Then, the symbols for each of the final states are
2045         // decoded and the process is complete."
2046         IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
2047         symbols_written++;
2048         if (offset < 0) {
2049             // There's still a symbol to decode in state2
2050             IO_write_byte(out, FSE_peek_symbol(dtable, state2));
2051             symbols_written++;
2052             break;
2053         }
2054 
2055         IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
2056         symbols_written++;
2057         if (offset < 0) {
2058             // There's still a symbol to decode in state1
2059             IO_write_byte(out, FSE_peek_symbol(dtable, state1));
2060             symbols_written++;
2061             break;
2062         }
2063     }
2064 
2065     return symbols_written;
2066 }
2067 
/// Builds an FSE decoding table of `1 << accuracy_log` states from the
/// normalized frequencies `norm_freqs[0 .. num_symbs - 1]`.
/// A frequency of -1 marks a "less than 1" probability symbol, which receives
/// a single cell at the top of the table; positive frequencies receive that
/// many cells, spread over the remaining positions.
/// Allocates `dtable->symbols`, `dtable->num_bits` and
/// `dtable->new_state_base`.  An `accuracy_log` or `num_symbs` above the
/// compile-time limits is an error, as is a spread that does not end back at
/// position 0.
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log) {
    if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
        ERROR("FSE accuracy too large");
    }
    if (num_symbs > FSE_MAX_SYMBS) {
        ERROR("Too many symbols for FSE");
    }

    dtable->accuracy_log = accuracy_log;

    // One table entry per state
    const size_t size = (size_t)1 << accuracy_log;
    dtable->symbols = malloc(size * sizeof(u8));
    dtable->num_bits = malloc(size * sizeof(u8));
    dtable->new_state_base = malloc(size * sizeof(u16));

    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
        BAD_ALLOC();
    }

    // Used to determine how many bits need to be read for each state,
    // and where the destination range should start
    // Needs to be u16 because max value is 2 * max number of symbols,
    // which can be larger than a byte can store
    u16 state_desc[FSE_MAX_SYMBS];

    // "Symbols are scanned in their natural order for "less than 1"
    // probabilities. Symbols with this probability are being attributed a
    // single cell, starting from the end of the table. These symbols define a
    // full state reset, reading Accuracy_Log bits."
    int high_threshold = size;
    for (int s = 0; s < num_symbs; s++) {
        // Scan for low probability symbols to put at the top
        if (norm_freqs[s] == -1) {
            dtable->symbols[--high_threshold] = s;
            state_desc[s] = 1;
        }
    }

    // "All remaining symbols are sorted in their natural order. Starting from
    // symbol 0 and table position 0, each symbol gets attributed as many cells
    // as its probability. Cell allocation is spreaded, not linear."
    // Place the rest in the table
    const u16 step = (size >> 1) + (size >> 3) + 3;
    const u16 mask = size - 1;
    u16 pos = 0;
    for (int s = 0; s < num_symbs; s++) {
        // Zero-frequency symbols get no cells; -1 symbols were placed above
        if (norm_freqs[s] <= 0) {
            continue;
        }

        state_desc[s] = norm_freqs[s];

        for (int i = 0; i < norm_freqs[s]; i++) {
            // Give `norm_freqs[s]` states to symbol s
            dtable->symbols[pos] = s;
            // "A position is skipped if already occupied, typically by a "less
            // than 1" probability symbol."
            do {
                pos = (pos + step) & mask;
            } while (pos >=
                     high_threshold);
            // Note: no other collision checking is necessary as `step` is
            // coprime to `size`, so the cycle will visit each position exactly
            // once
        }
    }
    // The walk must have returned to its starting position; otherwise the
    // table is rejected
    if (pos != 0) {
        CORRUPTION();
    }

    // Now we can fill baseline and num bits
    for (size_t i = 0; i < size; i++) {
        u8 symbol = dtable->symbols[i];
        // Take the current counter for this symbol, then bump it for the
        // symbol's next cell
        u16 next_state_desc = state_desc[symbol]++;
        // Fills in the table appropriately, next_state_desc increases by symbol
        // over time, decreasing number of bits
        dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
        // Baseline increases until the bit threshold is passed, at which point
        // it resets to 0
        dtable->new_state_base[i] =
            ((u16)next_state_desc << dtable->num_bits[i]) - size;
    }
}
2153 
2154 /// Decode an FSE header as defined in the Zstandard format specification and
2155 /// use the decoded frequencies to initialize a decoding table.
2156 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2157                                 const int max_accuracy_log) {
2158     // "An FSE distribution table describes the probabilities of all symbols
2159     // from 0 to the last present one (included) on a normalized scale of 1 <<
2160     // Accuracy_Log .
2161     //
2162     // It's a bitstream which is read forward, in little-endian fashion. It's
2163     // not necessary to know its exact size, since it will be discovered and
2164     // reported by the decoding process.
2165     if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2166         ERROR("FSE accuracy too large");
2167     }
2168 
2169     // The bitstream starts by reporting on which scale it operates.
2170     // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2171     // and match lengths is 9, and for offsets is 8. Higher values are
2172     // considered errors."
2173     const int accuracy_log = 5 + IO_read_bits(in, 4);
2174     if (accuracy_log > max_accuracy_log) {
2175         ERROR("FSE accuracy too large");
2176     }
2177 
2178     // "Then follows each symbol value, from 0 to last present one. The number
2179     // of bits used by each field is variable. It depends on :
2180     //
2181     // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2182     // and presuming 100 probabilities points have already been distributed, the
2183     // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2184     // Therefore, it must read log2sup(156) == 8 bits.
2185     //
2186     // Value decoded : small values use 1 less bit : example : Presuming values
2187     // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2188     // in an 8-bits field. They are used this way : first 99 values (hence from
2189     // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2190 
2191     i32 remaining = 1 << accuracy_log;
2192     i16 frequencies[FSE_MAX_SYMBS];
2193 
2194     int symb = 0;
2195     while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2196         // Log of the number of possible values we could read
2197         int bits = highest_set_bit(remaining + 1) + 1;
2198 
2199         u16 val = IO_read_bits(in, bits);
2200 
2201         // Try to mask out the lower bits to see if it qualifies for the "small
2202         // value" threshold
2203         const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2204         const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2205 
2206         if ((val & lower_mask) < threshold) {
2207             IO_rewind_bits(in, 1);
2208             val = val & lower_mask;
2209         } else if (val > lower_mask) {
2210             val = val - threshold;
2211         }
2212 
2213         // "Probability is obtained from Value decoded by following formula :
2214         // Proba = value - 1"
2215         const i16 proba = (i16)val - 1;
2216 
2217         // "It means value 0 becomes negative probability -1. -1 is a special
2218         // probability, which means "less than 1". Its effect on distribution
2219         // table is described in next paragraph. For the purpose of calculating
2220         // cumulated distribution, it counts as one."
2221         remaining -= proba < 0 ? -proba : proba;
2222 
2223         frequencies[symb] = proba;
2224         symb++;
2225 
2226         // "When a symbol has a probability of zero, it is followed by a 2-bits
2227         // repeat flag. This repeat flag tells how many probabilities of zeroes
2228         // follow the current one. It provides a number ranging from 0 to 3. If
2229         // it is a 3, another 2-bits repeat flag follows, and so on."
2230         if (proba == 0) {
2231             // Read the next two bits to see how many more 0s
2232             int repeat = IO_read_bits(in, 2);
2233 
2234             while (1) {
2235                 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2236                     frequencies[symb++] = 0;
2237                 }
2238                 if (repeat == 3) {
2239                     repeat = IO_read_bits(in, 2);
2240                 } else {
2241                     break;
2242                 }
2243             }
2244         }
2245     }
2246     IO_align_stream(in);
2247 
2248     // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2249     // is complete. If the last symbol makes cumulated total go above 1 <<
2250     // Accuracy_Log, distribution is considered corrupted."
2251     if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2252         CORRUPTION();
2253     }
2254 
2255     // Initialize the decoding table using the determined weights
2256     FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2257 }
2258 
2259 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
2260     dtable->symbols = malloc(sizeof(u8));
2261     dtable->num_bits = malloc(sizeof(u8));
2262     dtable->new_state_base = malloc(sizeof(u16));
2263 
2264     if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2265         BAD_ALLOC();
2266     }
2267 
2268     // This setup will always have a state of 0, always return symbol `symb`,
2269     // and never consume any bits
2270     dtable->symbols[0] = symb;
2271     dtable->num_bits[0] = 0;
2272     dtable->new_state_base[0] = 0;
2273     dtable->accuracy_log = 0;
2274 }
2275 
2276 static void FSE_free_dtable(FSE_dtable *const dtable) {
2277     free(dtable->symbols);
2278     free(dtable->num_bits);
2279     free(dtable->new_state_base);
2280     memset(dtable, 0, sizeof(FSE_dtable));
2281 }
2282 
2283 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
2284     if (src->accuracy_log == 0) {
2285         memset(dst, 0, sizeof(FSE_dtable));
2286         return;
2287     }
2288 
2289     size_t size = (size_t)1 << src->accuracy_log;
2290     dst->accuracy_log = src->accuracy_log;
2291 
2292     dst->symbols = malloc(size);
2293     dst->num_bits = malloc(size);
2294     dst->new_state_base = malloc(size * sizeof(u16));
2295     if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
2296         BAD_ALLOC();
2297     }
2298 
2299     memcpy(dst->symbols, src->symbols, size);
2300     memcpy(dst->num_bits, src->num_bits, size);
2301     memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
2302 }
2303 /******* END FSE PRIMITIVES ***************************************************/
2304