]> git.proxmox.com Git - ceph.git/blob - ceph/src/zstd/doc/educational_decoder/zstd_decompress.c
update sources to ceph Nautilus 14.2.1
[ceph.git] / ceph / src / zstd / doc / educational_decoder / zstd_decompress.c
1 /*
2 * Copyright (c) 2017-present, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 */
9
10 /// Zstandard educational decoder implementation
11 /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
12
13 #include <stdint.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include "zstd_decompress.h"
18
/******* UTILITY MACROS AND TYPES *********************************************/
// The maximum decompressed size of a block is 128 KB, and a literals section
// can never be larger than the block that contains it, so 128 KB also bounds
// the regenerated size of any literals section.
#define MAX_LITERALS_SIZE ((size_t)128 * 1024)

// NOTE: these function-like macros evaluate their arguments more than once;
// never pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

/// This decoder calls exit(1) when it encounters an error, however a production
/// library should propagate error codes.
/// ERROR(s) prints "Error: <s>" to stderr and terminates the process; the
/// wrappers below are the specific error kinds used throughout the decoder.
#define ERROR(s)                                                               \
    do {                                                                       \
        fprintf(stderr, "Error: %s\n", s);                                     \
        exit(1);                                                               \
    } while (0)
#define INP_SIZE()                                                             \
    ERROR("Input buffer smaller than it should be or input is "                \
          "corrupted")
#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")

// Short fixed-width aliases used throughout the decoder
typedef uint8_t u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;

typedef int8_t i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
/******* END UTILITY MACROS AND TYPES *****************************************/
52
/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
/// The implementations for these functions can be found at the bottom of this
/// file. They implement low-level functionality needed for the higher level
/// decompression functions.

/*** IO STREAM OPERATIONS *************/

/// ostream_t/istream_t are used to wrap the pointers/length data passed into
/// ZSTD_decompress, so that all IO operations are safely bounds checked.
/// They are written/read forward, and reads are treated as little-endian.
/// They should be used opaquely to ensure safety.
typedef struct {
    // Current write position; advances as data is written (the caller
    // computes the number of bytes produced as `ptr - dst`)
    u8 *ptr;
    // Space available from `ptr` onward (presumably shrinks as `ptr`
    // advances -- see IO_get_write_ptr)
    size_t len;
} ostream_t;

typedef struct {
    // Current read position
    const u8 *ptr;
    // Bytes remaining from `ptr` onward (see IO_istream_len)
    size_t len;

    // Input often reads a few bits at a time, so maintain an internal bit
    // offset; it is only nonzero while the stream is not byte aligned
    int bit_offset;
} istream_t;

/// The following two functions are the only ones that allow the istream to be
/// non-byte aligned

/// Reads `num_bits` bits from a bitstream, returns them as a little-endian
/// unsigned integer, and updates the internal offset
static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
/// Backs up the stream by `num_bits` bits so they can be read again
static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
/// If the remaining bits in a byte will be unused, advance to the end of the
/// byte
static inline void IO_align_stream(istream_t *const in);

/// Write the given byte into the output stream
static inline void IO_write_byte(ostream_t *const out, u8 symb);

/// Returns the number of bytes left to be read in this stream. The stream must
/// be byte aligned.
static inline size_t IO_istream_len(const istream_t *const in);

/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
/// was skipped. The stream must be byte aligned.
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
/// was skipped so it can be written to.
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);

/// Advance the inner state by `len` bytes. The stream must be byte aligned.
static inline void IO_advance_input(istream_t *const in, size_t len);

/// Returns an `ostream_t` constructed from the given pointer and length.
static inline ostream_t IO_make_ostream(u8 *out, size_t len);
/// Returns an `istream_t` constructed from the given pointer and length.
static inline istream_t IO_make_istream(const u8 *in, size_t len);

/// Returns an `istream_t` with the same base as `in`, and length `len`.
/// Then, advance `in` to account for the consumed bytes.
/// `in` must be byte aligned.
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
/*** END IO STREAM OPERATIONS *********/
115
/*** BITSTREAM OPERATIONS *************/
/// Read `num_bits` bits (up to 64) from `src + offset`, where `offset` is in
/// bits, and return them interpreted as a little-endian unsigned integer.
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
                               const size_t offset);

/// Read bits from the end of a HUF or FSE bitstream. `*offset` is a bit
/// position: it is first decremented by `bits`, and then `bits` bits are read
/// starting at that position in `src`. If the offset becomes negative, the
/// extra bits at the bottom are filled in with `0` bits instead of reading
/// from before `src`.
static inline u64 STREAM_read_bits(const u8 *src, const int bits,
                                   i64 *const offset);
/*** END BITSTREAM OPERATIONS *********/

/*** BIT COUNTING OPERATIONS **********/
/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
static inline int highest_set_bit(const u64 num);
/*** END BIT COUNTING OPERATIONS ******/
134
/*** HUFFMAN PRIMITIVES ***************/
// The table decode method uses memory exponential in the code depth, so we
// need to limit the depth
#define HUF_MAX_BITS (16)

// Limit the maximum number of symbols to 256 so we can store a symbol in a byte
#define HUF_MAX_SYMBS (256)

/// Structure containing all tables necessary for efficient Huffman decoding
typedef struct {
    // Heap-allocated lookup tables (released by HUF_free_dtable); `symbols`
    // is NULL when no table has been decoded yet
    u8 *symbols;
    u8 *num_bits;
    int max_bits;
} HUF_dtable;

/// Decode a single symbol and read in enough bits to refresh the state
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset);
/// Read in a full state's worth of bits to initialize it
static inline void HUF_init_state(const HUF_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset);

/// Decompresses a single Huffman stream, returns the number of bytes decoded.
/// `in` must span exactly the Huffman-coded block.
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in);
/// Same as previous but decodes 4 streams, formatted as in the Zstandard
/// specification.
/// `in` must span exactly the Huffman-coded block.
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in);

/// Initialize a Huffman decoding table using the table of bit counts provided
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
                            const int num_symbs);
/// Initialize a Huffman decoding table using the table of weights provided
/// Weights follow the definition provided in the Zstandard specification
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
                                         const u8 *const weights,
                                         const int num_symbs);

/// Free the malloc'ed parts of a decoding table
static void HUF_free_dtable(HUF_dtable *const dtable);

/// Deep copy a decoding table, so that it can be used and free'd without
/// impacting the source table.
static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src);
/*** END HUFFMAN PRIMITIVES ***********/
184
/*** FSE PRIMITIVES *******************/
/// For more description of FSE see
/// https://github.com/Cyan4973/FiniteStateEntropy/

// FSE table decoding uses exponential memory, so limit the maximum accuracy
#define FSE_MAX_ACCURACY_LOG (15)
// Limit the maximum number of symbols so they can be stored in a single byte
#define FSE_MAX_SYMBS (256)

/// The tables needed to decode FSE encoded streams
typedef struct {
    // Heap-allocated per-state tables (released by FSE_free_dtable)
    u8 *symbols;
    u8 *num_bits;
    u16 *new_state_base;
    // Presumably log2 of the number of states in the table -- confirm
    // against FSE_init_dtable
    int accuracy_log;
} FSE_dtable;

/// Return the symbol for the current state
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
                                 const u16 state);
/// Read the number of bits necessary to update state, update, and shift offset
/// back to reflect the bits read
static inline void FSE_update_state(const FSE_dtable *const dtable,
                                    u16 *const state, const u8 *const src,
                                    i64 *const offset);

/// Combine peek and update: decode a symbol and update the state
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset);

/// Read bits from the stream to initialize the state and shift offset back
static inline void FSE_init_state(const FSE_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset);

/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
/// using an FSE decoding table. `in` must span exactly the encoded block.
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
                                          ostream_t *const out,
                                          istream_t *const in);

/// Initialize a decoding table using normalized frequencies.
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log);

/// Decode an FSE header as defined in the Zstandard format specification and
/// use the decoded frequencies to initialize a decoding table.
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
                              const int max_accuracy_log);

/// Initialize an FSE table that will always return the same symbol and consume
/// 0 bits per symbol, to be used for RLE mode in sequence commands
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);

/// Free the malloc'ed parts of a decoding table
static void FSE_free_dtable(FSE_dtable *const dtable);

/// Deep copy a decoding table, so that it can be used and free'd without
/// impacting the source table.
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src);
/*** END FSE PRIMITIVES ***************/

/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
251
/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/

/// A small structure that can be reused in various places that need to access
/// frame header information
typedef struct {
    // The size of window that we need to be able to contiguously store for
    // references
    size_t window_size;
    // The total output size of this compressed frame (0 when the header does
    // not declare it)
    size_t frame_content_size;

    // The dictionary id if this frame uses one (0 when absent)
    u32 dictionary_id;

    // Whether or not the content of this frame has a checksum
    int content_checksum_flag;
    // Whether or not the output for this frame is in a single segment
    int single_segment_flag;
} frame_header_t;

/// The context needed to decode blocks in a frame
typedef struct {
    frame_header_t header;

    // The total amount of data available for backreferences, used to
    // determine whether an offset is too large to be correct
    size_t current_total_output;

    // Dictionary content available for backreferences. Borrowed from the
    // dictionary (never freed with the context); NULL when none is applied.
    const u8 *dict_content;
    size_t dict_content_len;

    // Entropy encoding tables so they can be repeated by future blocks instead
    // of retransmitting
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;

    // The last 3 offsets for the special "repeat offsets".
    u64 previous_offsets[3];
} frame_context_t;

/// The decoded contents of a dictionary so that it doesn't have to be repeated
/// for each frame that uses it
struct dictionary_s {
    // Entropy tables
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;

    // Raw content for backreferences (NULL for an empty dictionary)
    u8 *content;
    size_t content_size;

    // Offset history to prepopulate the frame's history
    u64 previous_offsets[3];

    // 0 means the tables and offsets above are not used
    u32 dictionary_id;
};

/// A tuple containing the parts necessary to decode and execute a ZSTD sequence
/// command
typedef struct {
    u32 literal_length;
    u32 match_length;
    u32 offset;
} sequence_command_t;
320
/// The decoder works top-down, starting at the high level like Zstd frames, and
/// working down to lower more technical levels such as blocks, literals, and
/// sequences. The high-level functions roughly follow the outline of the
/// format specification:
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md

/// Before the implementation of each high-level function declared here, the
/// prototypes for their helper functions are defined and explained

/// Decode a single Zstd frame, or error if the input is not a valid frame.
/// Accepts a dict argument, which may be NULL indicating no dictionary.
/// See
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict);

// Decode data in a compressed block
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in);

// Decode the literals section of a block. `*literals` receives a
// heap-allocated buffer the caller must free; returns the number of literals.
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals);

// Decode the sequences part of a block. `*sequences` receives a heap-allocated
// array the caller must free; returns the number of sequences.
static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
                               sequence_command_t **const sequences);

// Execute the decoded sequences on the literals block
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences);

// Copies literals and returns the total literal length that was copied
static u32 copy_literals(const size_t seq, istream_t *litstream,
                         ostream_t *const out);

// Given an offset code from a sequence command (either an actual offset value
// or an index for previous offset), computes the correct offset and updates
// the offset history
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);

// Given an offset, match length, and total output, as well as the frame
// context for the dictionary, determines if the dictionary is used and
// executes the copy operation
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
                               size_t match_length, size_t total_output,
                               ostream_t *const out);

/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
373
374 size_t ZSTD_decompress(void *const dst, const size_t dst_len,
375 const void *const src, const size_t src_len) {
376 dictionary_t* uninit_dict = create_dictionary();
377 size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
378 src_len, uninit_dict);
379 free_dictionary(uninit_dict);
380 return decomp_size;
381 }
382
383 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
384 const void *const src, const size_t src_len,
385 dictionary_t* parsed_dict) {
386
387 istream_t in = IO_make_istream(src, src_len);
388 ostream_t out = IO_make_ostream(dst, dst_len);
389
390 // "A content compressed by Zstandard is transformed into a Zstandard frame.
391 // Multiple frames can be appended into a single file or stream. A frame is
392 // totally independent, has a defined beginning and end, and a set of
393 // parameters which tells the decoder how to decompress it."
394
395 /* this decoder assumes decompression of a single frame */
396 decode_frame(&out, &in, parsed_dict);
397
398 return out.ptr - (u8 *)dst;
399 }
400
/******* FRAME DECODING ******************************************************/

// Decode a frame that carries compressed data (after its magic number)
static void decode_data_frame(ostream_t *const out, istream_t *const in,
                              const dictionary_t *const dict);
// Parse the frame header from `in` and set up the per-frame state
static void init_frame_context(frame_context_t *const context,
                               istream_t *const in,
                               const dictionary_t *const dict);
// Release the decoding tables owned by a frame context
static void free_frame_context(frame_context_t *const context);
// Parse the Frame_Header fields into `header`
static void parse_frame_header(frame_header_t *const header,
                               istream_t *const in);
// Seed the frame context from a dictionary (which may be NULL)
static void frame_context_apply_dict(frame_context_t *const ctx,
                                     const dictionary_t *const dict);

// Decode every block of the frame, then skip any trailing checksum
static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
                            istream_t *const in);
416
417 static void decode_frame(ostream_t *const out, istream_t *const in,
418 const dictionary_t *const dict) {
419 const u32 magic_number = IO_read_bits(in, 32);
420 // Zstandard frame
421 //
422 // "Magic_Number
423 //
424 // 4 Bytes, little-endian format. Value : 0xFD2FB528"
425 if (magic_number == 0xFD2FB528U) {
426 // ZSTD frame
427 decode_data_frame(out, in, dict);
428
429 return;
430 }
431
432 // not a real frame or a skippable frame
433 ERROR("Tried to decode non-ZSTD frame");
434 }
435
436 /// Decode a frame that contains compressed data. Not all frames do as there
437 /// are skippable frames.
438 /// See
439 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
440 static void decode_data_frame(ostream_t *const out, istream_t *const in,
441 const dictionary_t *const dict) {
442 frame_context_t ctx;
443
444 // Initialize the context that needs to be carried from block to block
445 init_frame_context(&ctx, in, dict);
446
447 if (ctx.header.frame_content_size != 0 &&
448 ctx.header.frame_content_size > out->len) {
449 OUT_SIZE();
450 }
451
452 decompress_data(&ctx, out, in);
453
454 free_frame_context(&ctx);
455 }
456
457 /// Takes the information provided in the header and dictionary, and initializes
458 /// the context for this frame
459 static void init_frame_context(frame_context_t *const context,
460 istream_t *const in,
461 const dictionary_t *const dict) {
462 // Most fields in context are correct when initialized to 0
463 memset(context, 0, sizeof(frame_context_t));
464
465 // Parse data from the frame header
466 parse_frame_header(&context->header, in);
467
468 // Set up the offset history for the repeat offset commands
469 context->previous_offsets[0] = 1;
470 context->previous_offsets[1] = 4;
471 context->previous_offsets[2] = 8;
472
473 // Apply details from the dict if it exists
474 frame_context_apply_dict(context, dict);
475 }
476
477 static void free_frame_context(frame_context_t *const context) {
478 HUF_free_dtable(&context->literals_dtable);
479
480 FSE_free_dtable(&context->ll_dtable);
481 FSE_free_dtable(&context->ml_dtable);
482 FSE_free_dtable(&context->of_dtable);
483
484 memset(context, 0, sizeof(frame_context_t));
485 }
486
487 static void parse_frame_header(frame_header_t *const header,
488 istream_t *const in) {
489 // "The first header's byte is called the Frame_Header_Descriptor. It tells
490 // which other fields are present. Decoding this byte is enough to tell the
491 // size of Frame_Header.
492 //
493 // Bit number Field name
494 // 7-6 Frame_Content_Size_flag
495 // 5 Single_Segment_flag
496 // 4 Unused_bit
497 // 3 Reserved_bit
498 // 2 Content_Checksum_flag
499 // 1-0 Dictionary_ID_flag"
500 const u8 descriptor = IO_read_bits(in, 8);
501
502 // decode frame header descriptor into flags
503 const u8 frame_content_size_flag = descriptor >> 6;
504 const u8 single_segment_flag = (descriptor >> 5) & 1;
505 const u8 reserved_bit = (descriptor >> 3) & 1;
506 const u8 content_checksum_flag = (descriptor >> 2) & 1;
507 const u8 dictionary_id_flag = descriptor & 3;
508
509 if (reserved_bit != 0) {
510 CORRUPTION();
511 }
512
513 header->single_segment_flag = single_segment_flag;
514 header->content_checksum_flag = content_checksum_flag;
515
516 // decode window size
517 if (!single_segment_flag) {
518 // "Provides guarantees on maximum back-reference distance that will be
519 // used within compressed data. This information is important for
520 // decoders to allocate enough memory.
521 //
522 // Bit numbers 7-3 2-0
523 // Field name Exponent Mantissa"
524 u8 window_descriptor = IO_read_bits(in, 8);
525 u8 exponent = window_descriptor >> 3;
526 u8 mantissa = window_descriptor & 7;
527
528 // Use the algorithm from the specification to compute window size
529 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
530 size_t window_base = (size_t)1 << (10 + exponent);
531 size_t window_add = (window_base / 8) * mantissa;
532 header->window_size = window_base + window_add;
533 }
534
535 // decode dictionary id if it exists
536 if (dictionary_id_flag) {
537 // "This is a variable size field, which contains the ID of the
538 // dictionary required to properly decode the frame. Note that this
539 // field is optional. When it's not present, it's up to the caller to
540 // make sure it uses the correct dictionary. Format is little-endian."
541 const int bytes_array[] = {0, 1, 2, 4};
542 const int bytes = bytes_array[dictionary_id_flag];
543
544 header->dictionary_id = IO_read_bits(in, bytes * 8);
545 } else {
546 header->dictionary_id = 0;
547 }
548
549 // decode frame content size if it exists
550 if (single_segment_flag || frame_content_size_flag) {
551 // "This is the original (uncompressed) size. This information is
552 // optional. The Field_Size is provided according to value of
553 // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
554 // present), 1, 2, 4 or 8 bytes. Format is little-endian."
555 //
556 // if frame_content_size_flag == 0 but single_segment_flag is set, we
557 // still have a 1 byte field
558 const int bytes_array[] = {1, 2, 4, 8};
559 const int bytes = bytes_array[frame_content_size_flag];
560
561 header->frame_content_size = IO_read_bits(in, bytes * 8);
562 if (bytes == 2) {
563 // "When Field_Size is 2, the offset of 256 is added."
564 header->frame_content_size += 256;
565 }
566 } else {
567 header->frame_content_size = 0;
568 }
569
570 if (single_segment_flag) {
571 // "The Window_Descriptor byte is optional. It is absent when
572 // Single_Segment_flag is set. In this case, the maximum back-reference
573 // distance is the content size itself, which can be any value from 1 to
574 // 2^64-1 bytes (16 EB)."
575 header->window_size = header->frame_content_size;
576 }
577 }
578
579 /// A dictionary acts as initializing values for the frame context before
580 /// decompression, so we implement it by applying it's predetermined
581 /// tables and content to the context before beginning decompression
582 static void frame_context_apply_dict(frame_context_t *const ctx,
583 const dictionary_t *const dict) {
584 // If the content pointer is NULL then it must be an empty dict
585 if (!dict || !dict->content)
586 return;
587
588 // If the requested dictionary_id is non-zero, the correct dictionary must
589 // be present
590 if (ctx->header.dictionary_id != 0 &&
591 ctx->header.dictionary_id != dict->dictionary_id) {
592 ERROR("Wrong dictionary provided");
593 }
594
595 // Copy the dict content to the context for references during sequence
596 // execution
597 ctx->dict_content = dict->content;
598 ctx->dict_content_len = dict->content_size;
599
600 // If it's a formatted dict copy the precomputed tables in so they can
601 // be used in the table repeat modes
602 if (dict->dictionary_id != 0) {
603 // Deep copy the entropy tables so they can be freed independently of
604 // the dictionary struct
605 HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
606 FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
607 FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
608 FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
609
610 // Copy the repeated offsets
611 memcpy(ctx->previous_offsets, dict->previous_offsets,
612 sizeof(ctx->previous_offsets));
613 }
614 }
615
616 /// Decompress the data from a frame block by block
617 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
618 istream_t *const in) {
619 // "A frame encapsulates one or multiple blocks. Each block can be
620 // compressed or not, and has a guaranteed maximum content size, which
621 // depends on frame parameters. Unlike frames, each block depends on
622 // previous blocks for proper decoding. However, each block can be
623 // decompressed without waiting for its successor, allowing streaming
624 // operations."
625 int last_block = 0;
626 do {
627 // "Last_Block
628 //
629 // The lowest bit signals if this block is the last one. Frame ends
630 // right after this block.
631 //
632 // Block_Type and Block_Size
633 //
634 // The next 2 bits represent the Block_Type, while the remaining 21 bits
635 // represent the Block_Size. Format is little-endian."
636 last_block = IO_read_bits(in, 1);
637 const int block_type = IO_read_bits(in, 2);
638 const size_t block_len = IO_read_bits(in, 21);
639
640 switch (block_type) {
641 case 0: {
642 // "Raw_Block - this is an uncompressed block. Block_Size is the
643 // number of bytes to read and copy."
644 const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
645 u8 *const write_ptr = IO_get_write_ptr(out, block_len);
646
647 // Copy the raw data into the output
648 memcpy(write_ptr, read_ptr, block_len);
649
650 ctx->current_total_output += block_len;
651 break;
652 }
653 case 1: {
654 // "RLE_Block - this is a single byte, repeated N times. In which
655 // case, Block_Size is the size to regenerate, while the
656 // "compressed" block is just 1 byte (the byte to repeat)."
657 const u8 *const read_ptr = IO_get_read_ptr(in, 1);
658 u8 *const write_ptr = IO_get_write_ptr(out, block_len);
659
660 // Copy `block_len` copies of `read_ptr[0]` to the output
661 memset(write_ptr, read_ptr[0], block_len);
662
663 ctx->current_total_output += block_len;
664 break;
665 }
666 case 2: {
667 // "Compressed_Block - this is a Zstandard compressed block,
668 // detailed in another section of this specification. Block_Size is
669 // the compressed size.
670
671 // Create a sub-stream for the block
672 istream_t block_stream = IO_make_sub_istream(in, block_len);
673 decompress_block(ctx, out, &block_stream);
674 break;
675 }
676 case 3:
677 // "Reserved - this is not a block. This value cannot be used with
678 // current version of this specification."
679 CORRUPTION();
680 break;
681 default:
682 IMPOSSIBLE();
683 }
684 } while (!last_block);
685
686 if (ctx->header.content_checksum_flag) {
687 // This program does not support checking the checksum, so skip over it
688 // if it's present
689 IO_advance_input(in, 4);
690 }
691 }
692 /******* END FRAME DECODING ***************************************************/
693
694 /******* BLOCK DECOMPRESSION **************************************************/
695 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
696 istream_t *const in) {
697 // "A compressed block consists of 2 sections :
698 //
699 // Literals_Section
700 // Sequences_Section"
701
702
703 // Part 1: decode the literals block
704 u8 *literals = NULL;
705 const size_t literals_size = decode_literals(ctx, in, &literals);
706
707 // Part 2: decode the sequences block
708 sequence_command_t *sequences = NULL;
709 const size_t num_sequences =
710 decode_sequences(ctx, in, &sequences);
711
712 // Part 3: combine literals and sequence commands to generate output
713 execute_sequences(ctx, out, literals, literals_size, sequences,
714 num_sequences);
715 free(literals);
716 free(sequences);
717 }
718 /******* END BLOCK DECOMPRESSION **********************************************/
719
/******* LITERALS DECODING ****************************************************/
// Decode a Raw (block_type 0) or RLE (block_type 1) literals section
static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
                                     const int block_type,
                                     const int size_format);
// Decode a Huffman-compressed literals section (block_type 2 carries its own
// table; otherwise the previous table is reused)
static size_t decode_literals_compressed(frame_context_t *const ctx,
                                         istream_t *const in,
                                         u8 **const literals,
                                         const int block_type,
                                         const int size_format);
// Decode the Huffman table description provided in `in` into `dtable`
static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
// NOTE(review): presumably FSE-decodes compressed Huffman weights from `in`
// into `weights`, reporting the symbol count through `num_symbs` -- the
// definition is not in this section; confirm there.
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
                                  int *const num_symbs);
732
733 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
734 u8 **const literals) {
735 // "Literals can be stored uncompressed or compressed using Huffman prefix
736 // codes. When compressed, an optional tree description can be present,
737 // followed by 1 or 4 streams."
738 //
739 // "Literals_Section_Header
740 //
741 // Header is in charge of describing how literals are packed. It's a
742 // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
743 // little-endian convention."
744 //
745 // "Literals_Block_Type
746 //
747 // This field uses 2 lowest bits of first byte, describing 4 different block
748 // types"
749 //
750 // size_format takes between 1 and 2 bits
751 int block_type = IO_read_bits(in, 2);
752 int size_format = IO_read_bits(in, 2);
753
754 if (block_type <= 1) {
755 // Raw or RLE literals block
756 return decode_literals_simple(in, literals, block_type,
757 size_format);
758 } else {
759 // Huffman compressed literals
760 return decode_literals_compressed(ctx, in, literals, block_type,
761 size_format);
762 }
763 }
764
765 /// Decodes literals blocks in raw or RLE form
766 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
767 const int block_type,
768 const int size_format) {
769 size_t size;
770 switch (size_format) {
771 // These cases are in the form ?0
772 // In this case, the ? bit is actually part of the size field
773 case 0:
774 case 2:
775 // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
776 IO_rewind_bits(in, 1);
777 size = IO_read_bits(in, 5);
778 break;
779 case 1:
780 // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
781 size = IO_read_bits(in, 12);
782 break;
783 case 3:
784 // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
785 size = IO_read_bits(in, 20);
786 break;
787 default:
788 // Size format is in range 0-3
789 IMPOSSIBLE();
790 }
791
792 if (size > MAX_LITERALS_SIZE) {
793 CORRUPTION();
794 }
795
796 *literals = malloc(size);
797 if (!*literals) {
798 BAD_ALLOC();
799 }
800
801 switch (block_type) {
802 case 0: {
803 // "Raw_Literals_Block - Literals are stored uncompressed."
804 const u8 *const read_ptr = IO_get_read_ptr(in, size);
805 memcpy(*literals, read_ptr, size);
806 break;
807 }
808 case 1: {
809 // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
810 const u8 *const read_ptr = IO_get_read_ptr(in, 1);
811 memset(*literals, read_ptr[0], size);
812 break;
813 }
814 default:
815 IMPOSSIBLE();
816 }
817
818 return size;
819 }
820
821 /// Decodes Huffman compressed literals
822 static size_t decode_literals_compressed(frame_context_t *const ctx,
823 istream_t *const in,
824 u8 **const literals,
825 const int block_type,
826 const int size_format) {
827 size_t regenerated_size, compressed_size;
828 // Only size_format=0 has 1 stream, so default to 4
829 int num_streams = 4;
830 switch (size_format) {
831 case 0:
832 // "A single stream. Both Compressed_Size and Regenerated_Size use 10
833 // bits (0-1023)."
834 num_streams = 1;
835 // Fall through as it has the same size format
836 case 1:
837 // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
838 // (0-1023)."
839 regenerated_size = IO_read_bits(in, 10);
840 compressed_size = IO_read_bits(in, 10);
841 break;
842 case 2:
843 // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
844 // (0-16383)."
845 regenerated_size = IO_read_bits(in, 14);
846 compressed_size = IO_read_bits(in, 14);
847 break;
848 case 3:
849 // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
850 // (0-262143)."
851 regenerated_size = IO_read_bits(in, 18);
852 compressed_size = IO_read_bits(in, 18);
853 break;
854 default:
855 // Impossible
856 IMPOSSIBLE();
857 }
858 if (regenerated_size > MAX_LITERALS_SIZE ||
859 compressed_size >= regenerated_size) {
860 CORRUPTION();
861 }
862
863 *literals = malloc(regenerated_size);
864 if (!*literals) {
865 BAD_ALLOC();
866 }
867
868 ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
869 istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
870
871 if (block_type == 2) {
872 // Decode the provided Huffman table
873 // "This section is only present when Literals_Block_Type type is
874 // Compressed_Literals_Block (2)."
875
876 HUF_free_dtable(&ctx->literals_dtable);
877 decode_huf_table(&ctx->literals_dtable, &huf_stream);
878 } else {
879 // If the previous Huffman table is being repeated, ensure it exists
880 if (!ctx->literals_dtable.symbols) {
881 CORRUPTION();
882 }
883 }
884
885 size_t symbols_decoded;
886 if (num_streams == 1) {
887 symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
888 } else {
889 symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
890 }
891
892 if (symbols_decoded != regenerated_size) {
893 CORRUPTION();
894 }
895
896 return regenerated_size;
897 }
898
899 // Decode the Huffman table description
900 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
901 // "All literal values from zero (included) to last present one (excluded)
902 // are represented by Weight with values from 0 to Max_Number_of_Bits."
903
904 // "This is a single byte value (0-255), which describes how to decode the list of weights."
905 const u8 header = IO_read_bits(in, 8);
906
907 u8 weights[HUF_MAX_SYMBS];
908 memset(weights, 0, sizeof(weights));
909
910 int num_symbs;
911
912 if (header >= 128) {
913 // "This is a direct representation, where each Weight is written
914 // directly as a 4 bits field (0-15). The full representation occupies
915 // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
916 // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
917 // 127"
918 num_symbs = header - 127;
919 const size_t bytes = (num_symbs + 1) / 2;
920
921 const u8 *const weight_src = IO_get_read_ptr(in, bytes);
922
923 for (int i = 0; i < num_symbs; i++) {
924 // "They are encoded forward, 2
925 // weights to a byte with the first weight taking the top four bits
926 // and the second taking the bottom four (e.g. the following
927 // operations could be used to read the weights: Weight[0] =
928 // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
929 if (i % 2 == 0) {
930 weights[i] = weight_src[i / 2] >> 4;
931 } else {
932 weights[i] = weight_src[i / 2] & 0xf;
933 }
934 }
935 } else {
936 // The weights are FSE encoded, decode them before we can construct the
937 // table
938 istream_t fse_stream = IO_make_sub_istream(in, header);
939 ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
940 fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
941 }
942
943 // Construct the table using the decoded weights
944 HUF_init_dtable_usingweights(dtable, weights, num_symbs);
945 }
946
947 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
948 int *const num_symbs) {
949 const int MAX_ACCURACY_LOG = 7;
950
951 FSE_dtable dtable;
952
953 // "An FSE bitstream starts by a header, describing probabilities
954 // distribution. It will create a Decoding Table. For a list of Huffman
955 // weights, maximum accuracy is 7 bits."
956 FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
957
958 // Decode the weights
959 *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
960
961 FSE_free_dtable(&dtable);
962 }
963 /******* END LITERALS DECODING ************************************************/
964
965 /******* SEQUENCE DECODING ****************************************************/
966 /// The combination of FSE states needed to decode sequences
967 typedef struct {
968 FSE_dtable ll_table;
969 FSE_dtable of_table;
970 FSE_dtable ml_table;
971
972 u16 ll_state;
973 u16 of_state;
974 u16 ml_state;
975 } sequence_states_t;
976
977 /// Different modes to signal to decode_seq_tables what to do
978 typedef enum {
979 seq_literal_length = 0,
980 seq_offset = 1,
981 seq_match_length = 2,
982 } seq_part_t;
983
984 typedef enum {
985 seq_predefined = 0,
986 seq_rle = 1,
987 seq_fse = 2,
988 seq_repeat = 3,
989 } seq_mode_t;
990
991 /// The predefined FSE distribution tables for `seq_predefined` mode
992 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
993 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2,
994 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
995 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
996 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
997 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
998 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
999 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1000 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1001 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
1002
1003 /// The sequence decoding baseline and number of additional bits to read/add
1004 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
1005 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
1006 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
1007 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40,
1008 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65538};
1009 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
1010 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1011 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1012
1013 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
1014 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1015 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
1016 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83,
1017 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
1018 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
1019 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1020 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
1021 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1022
1023 /// Offset decoding is simpler so we just need a maximum code value
1024 static const u8 SEQ_MAX_CODES[3] = {35, -1, 52};
1025
1026 static void decompress_sequences(frame_context_t *const ctx,
1027 istream_t *const in,
1028 sequence_command_t *const sequences,
1029 const size_t num_sequences);
1030 static sequence_command_t decode_sequence(sequence_states_t *const state,
1031 const u8 *const src,
1032 i64 *const offset);
1033 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1034 const seq_part_t type, const seq_mode_t mode);
1035
1036 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
1037 sequence_command_t **const sequences) {
1038 // "A compressed block is a succession of sequences . A sequence is a
1039 // literal copy command, followed by a match copy command. A literal copy
1040 // command specifies a length. It is the number of bytes to be copied (or
1041 // extracted) from the literal section. A match copy command specifies an
1042 // offset and a length. The offset gives the position to copy from, which
1043 // can be within a previous block."
1044
1045 size_t num_sequences;
1046
1047 // "Number_of_Sequences
1048 //
1049 // This is a variable size field using between 1 and 3 bytes. Let's call its
1050 // first byte byte0."
1051 u8 header = IO_read_bits(in, 8);
1052 if (header == 0) {
1053 // "There are no sequences. The sequence section stops there.
1054 // Regenerated content is defined entirely by literals section."
1055 *sequences = NULL;
1056 return 0;
1057 } else if (header < 128) {
1058 // "Number_of_Sequences = byte0 . Uses 1 byte."
1059 num_sequences = header;
1060 } else if (header < 255) {
1061 // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
1062 num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
1063 } else {
1064 // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
1065 num_sequences = IO_read_bits(in, 16) + 0x7F00;
1066 }
1067
1068 *sequences = malloc(num_sequences * sizeof(sequence_command_t));
1069 if (!*sequences) {
1070 BAD_ALLOC();
1071 }
1072
1073 decompress_sequences(ctx, in, *sequences, num_sequences);
1074 return num_sequences;
1075 }
1076
1077 /// Decompress the FSE encoded sequence commands
1078 static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1079 sequence_command_t *const sequences,
1080 const size_t num_sequences) {
1081 // "The Sequences_Section regroup all symbols required to decode commands.
1082 // There are 3 symbol types : literals lengths, offsets and match lengths.
1083 // They are encoded together, interleaved, in a single bitstream."
1084
1085 // "Symbol compression modes
1086 //
1087 // This is a single byte, defining the compression mode of each symbol
1088 // type."
1089 //
1090 // Bit number : Field name
1091 // 7-6 : Literals_Lengths_Mode
1092 // 5-4 : Offsets_Mode
1093 // 3-2 : Match_Lengths_Mode
1094 // 1-0 : Reserved
1095 u8 compression_modes = IO_read_bits(in, 8);
1096
1097 if ((compression_modes & 3) != 0) {
1098 // Reserved bits set
1099 CORRUPTION();
1100 }
1101
1102 // "Following the header, up to 3 distribution tables can be described. When
1103 // present, they are in this order :
1104 //
1105 // Literals lengths
1106 // Offsets
1107 // Match Lengths"
1108 // Update the tables we have stored in the context
1109 decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1110 (compression_modes >> 6) & 3);
1111
1112 decode_seq_table(&ctx->of_dtable, in, seq_offset,
1113 (compression_modes >> 4) & 3);
1114
1115 decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1116 (compression_modes >> 2) & 3);
1117
1118
1119 sequence_states_t states;
1120
1121 // Initialize the decoding tables
1122 {
1123 states.ll_table = ctx->ll_dtable;
1124 states.of_table = ctx->of_dtable;
1125 states.ml_table = ctx->ml_dtable;
1126 }
1127
1128 const size_t len = IO_istream_len(in);
1129 const u8 *const src = IO_get_read_ptr(in, len);
1130
1131 // "After writing the last bit containing information, the compressor writes
1132 // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1133 const int padding = 8 - highest_set_bit(src[len - 1]);
1134 // The offset starts at the end because FSE streams are read backwards
1135 i64 bit_offset = len * 8 - padding;
1136
1137 // "The bitstream starts with initial state values, each using the required
1138 // number of bits in their respective accuracy, decoded previously from
1139 // their normalized distribution.
1140 //
1141 // It starts by Literals_Length_State, followed by Offset_State, and finally
1142 // Match_Length_State."
1143 FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1144 FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1145 FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1146
1147 for (size_t i = 0; i < num_sequences; i++) {
1148 // Decode sequences one by one
1149 sequences[i] = decode_sequence(&states, src, &bit_offset);
1150 }
1151
1152 if (bit_offset != 0) {
1153 CORRUPTION();
1154 }
1155 }
1156
1157 // Decode a single sequence and update the state
1158 static sequence_command_t decode_sequence(sequence_states_t *const states,
1159 const u8 *const src,
1160 i64 *const offset) {
1161 // "Each symbol is a code in its own context, which specifies Baseline and
1162 // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1163 // additional bits in the same bitstream."
1164
1165 // Decode symbols, but don't update states
1166 const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1167 const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1168 const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1169
1170 // Offset doesn't need a max value as it's not decoded using a table
1171 if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1172 ml_code > SEQ_MAX_CODES[seq_match_length]) {
1173 CORRUPTION();
1174 }
1175
1176 // Read the interleaved bits
1177 sequence_command_t seq;
1178 // "Decoding starts by reading the Number_of_Bits required to decode Offset.
1179 // It then does the same for Match_Length, and then for Literals_Length."
1180 seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1181
1182 seq.match_length =
1183 SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1184 STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1185
1186 seq.literal_length =
1187 SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1188 STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1189
1190 // "If it is not the last sequence in the block, the next operation is to
1191 // update states. Using the rules pre-calculated in the decoding tables,
1192 // Literals_Length_State is updated, followed by Match_Length_State, and
1193 // then Offset_State."
1194 // If the stream is complete don't read bits to update state
1195 if (*offset != 0) {
1196 FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1197 FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1198 FSE_update_state(&states->of_table, &states->of_state, src, offset);
1199 }
1200
1201 return seq;
1202 }
1203
1204 /// Given a sequence part and table mode, decode the FSE distribution
1205 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
1206 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1207 const seq_part_t type, const seq_mode_t mode) {
1208 // Constant arrays indexed by seq_part_t
1209 const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1210 SEQ_OFFSET_DEFAULT_DIST,
1211 SEQ_MATCH_LENGTH_DEFAULT_DIST};
1212 const size_t default_distribution_lengths[] = {36, 29, 53};
1213 const size_t default_distribution_accuracies[] = {6, 5, 6};
1214
1215 const size_t max_accuracies[] = {9, 8, 9};
1216
1217 if (mode != seq_repeat) {
1218 // Free old one before overwriting
1219 FSE_free_dtable(table);
1220 }
1221
1222 switch (mode) {
1223 case seq_predefined: {
1224 // "Predefined_Mode : uses a predefined distribution table."
1225 const i16 *distribution = default_distributions[type];
1226 const size_t symbs = default_distribution_lengths[type];
1227 const size_t accuracy_log = default_distribution_accuracies[type];
1228
1229 FSE_init_dtable(table, distribution, symbs, accuracy_log);
1230 break;
1231 }
1232 case seq_rle: {
1233 // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1234 const u8 symb = IO_get_read_ptr(in, 1)[0];
1235 FSE_init_dtable_rle(table, symb);
1236 break;
1237 }
1238 case seq_fse: {
1239 // "FSE_Compressed_Mode : standard FSE compression. A distribution table
1240 // will be present "
1241 FSE_decode_header(table, in, max_accuracies[type]);
1242 break;
1243 }
1244 case seq_repeat:
1245 // "Repeat_Mode : re-use distribution table from previous compressed
1246 // block."
1247 // Nothing to do here, table will be unchanged
1248 if (!table->symbols) {
1249 // This mode is invalid if we don't already have a table
1250 CORRUPTION();
1251 }
1252 break;
1253 default:
1254 // Impossible, as mode is from 0-3
1255 IMPOSSIBLE();
1256 break;
1257 }
1258
1259 }
1260 /******* END SEQUENCE DECODING ************************************************/
1261
1262 /******* SEQUENCE EXECUTION ***************************************************/
1263 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
1264 const u8 *const literals,
1265 const size_t literals_len,
1266 const sequence_command_t *const sequences,
1267 const size_t num_sequences) {
1268 istream_t litstream = IO_make_istream(literals, literals_len);
1269
1270 u64 *const offset_hist = ctx->previous_offsets;
1271 size_t total_output = ctx->current_total_output;
1272
1273 for (size_t i = 0; i < num_sequences; i++) {
1274 const sequence_command_t seq = sequences[i];
1275 {
1276 const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
1277 total_output += literals_size;
1278 }
1279
1280 size_t const offset = compute_offset(seq, offset_hist);
1281
1282 size_t const match_length = seq.match_length;
1283
1284 execute_match_copy(ctx, offset, match_length, total_output, out);
1285
1286 total_output += match_length;
1287 }
1288
1289 // Copy any leftover literals
1290 {
1291 size_t len = IO_istream_len(&litstream);
1292 copy_literals(len, &litstream, out);
1293 total_output += len;
1294 }
1295
1296 ctx->current_total_output = total_output;
1297 }
1298
1299 static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1300 ostream_t *const out) {
1301 // If the sequence asks for more literals than are left, the
1302 // sequence must be corrupted
1303 if (literal_length > IO_istream_len(litstream)) {
1304 CORRUPTION();
1305 }
1306
1307 u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1308 const u8 *const read_ptr =
1309 IO_get_read_ptr(litstream, literal_length);
1310 // Copy literals to output
1311 memcpy(write_ptr, read_ptr, literal_length);
1312
1313 return literal_length;
1314 }
1315
1316 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
1317 size_t offset;
1318 // Offsets are special, we need to handle the repeat offsets
1319 if (seq.offset <= 3) {
1320 // "The first 3 values define a repeated offset and we will call
1321 // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
1322 // They are sorted in recency order, with Repeated_Offset1 meaning
1323 // 'most recent one'".
1324
1325 // Use 0 indexing for the array
1326 u32 idx = seq.offset - 1;
1327 if (seq.literal_length == 0) {
1328 // "There is an exception though, when current sequence's
1329 // literals length is 0. In this case, repeated offsets are
1330 // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
1331 // Repeated_Offset2 becomes Repeated_Offset3, and
1332 // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
1333 idx++;
1334 }
1335
1336 if (idx == 0) {
1337 offset = offset_hist[0];
1338 } else {
1339 // If idx == 3 then literal length was 0 and the offset was 3,
1340 // as per the exception listed above
1341 offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
1342
1343 // If idx == 1 we don't need to modify offset_hist[2], since
1344 // we're using the second-most recent code
1345 if (idx > 1) {
1346 offset_hist[2] = offset_hist[1];
1347 }
1348 offset_hist[1] = offset_hist[0];
1349 offset_hist[0] = offset;
1350 }
1351 } else {
1352 // When it's not a repeat offset:
1353 // "if (Offset_Value > 3) offset = Offset_Value - 3;"
1354 offset = seq.offset - 3;
1355
1356 // Shift back history
1357 offset_hist[2] = offset_hist[1];
1358 offset_hist[1] = offset_hist[0];
1359 offset_hist[0] = offset;
1360 }
1361 return offset;
1362 }
1363
1364 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
1365 size_t match_length, size_t total_output,
1366 ostream_t *const out) {
1367 u8 *write_ptr = IO_get_write_ptr(out, match_length);
1368 if (total_output <= ctx->header.window_size) {
1369 // In this case offset might go back into the dictionary
1370 if (offset > total_output + ctx->dict_content_len) {
1371 // The offset goes beyond even the dictionary
1372 CORRUPTION();
1373 }
1374
1375 if (offset > total_output) {
1376 // "The rest of the dictionary is its content. The content act
1377 // as a "past" in front of data to compress or decompress, so it
1378 // can be referenced in sequence commands."
1379 const size_t dict_copy =
1380 MIN(offset - total_output, match_length);
1381 const size_t dict_offset =
1382 ctx->dict_content_len - (offset - total_output);
1383
1384 memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
1385 write_ptr += dict_copy;
1386 match_length -= dict_copy;
1387 }
1388 } else if (offset > ctx->header.window_size) {
1389 CORRUPTION();
1390 }
1391
1392 // We must copy byte by byte because the match length might be larger
1393 // than the offset
1394 // ex: if the output so far was "abc", a command with offset=3 and
1395 // match_length=6 would produce "abcabcabc" as the new output
1396 for (size_t j = 0; j < match_length; j++) {
1397 *write_ptr = *(write_ptr - offset);
1398 write_ptr++;
1399 }
1400 }
1401 /******* END SEQUENCE EXECUTION ***********************************************/
1402
1403 /******* OUTPUT SIZE COUNTING *************************************************/
1404 /// Get the decompressed size of an input stream so memory can be allocated in
1405 /// advance.
1406 /// This implementation assumes `src` points to a single ZSTD-compressed frame
1407 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1408 istream_t in = IO_make_istream(src, src_len);
1409
1410 // get decompressed size from ZSTD frame header
1411 {
1412 const u32 magic_number = IO_read_bits(&in, 32);
1413
1414 if (magic_number == 0xFD2FB528U) {
1415 // ZSTD frame
1416 frame_header_t header;
1417 parse_frame_header(&header, &in);
1418
1419 if (header.frame_content_size == 0 && !header.single_segment_flag) {
1420 // Content size not provided, we can't tell
1421 return -1;
1422 }
1423
1424 return header.frame_content_size;
1425 } else {
1426 // not a real frame or skippable frame
1427 ERROR("ZSTD frame magic number did not match");
1428 }
1429 }
1430 }
1431 /******* END OUTPUT SIZE COUNTING *********************************************/
1432
1433 /******* DICTIONARY PARSING ***************************************************/
1434 #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
1435 #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
1436
1437 dictionary_t* create_dictionary() {
1438 dictionary_t* dict = calloc(1, sizeof(dictionary_t));
1439 if (!dict) {
1440 BAD_ALLOC();
1441 }
1442 return dict;
1443 }
1444
1445 static void init_dictionary_content(dictionary_t *const dict,
1446 istream_t *const in);
1447
1448 void parse_dictionary(dictionary_t *const dict, const void *src,
1449 size_t src_len) {
1450 const u8 *byte_src = (const u8 *)src;
1451 memset(dict, 0, sizeof(dictionary_t));
1452 if (src == NULL) { /* cannot initialize dictionary with null src */
1453 NULL_SRC();
1454 }
1455 if (src_len < 8) {
1456 DICT_SIZE_ERROR();
1457 }
1458
1459 istream_t in = IO_make_istream(byte_src, src_len);
1460
1461 const u32 magic_number = IO_read_bits(&in, 32);
1462 if (magic_number != 0xEC30A437) {
1463 // raw content dict
1464 IO_rewind_bits(&in, 32);
1465 init_dictionary_content(dict, &in);
1466 return;
1467 }
1468
1469 dict->dictionary_id = IO_read_bits(&in, 32);
1470
1471 // "Entropy_Tables : following the same format as the tables in compressed
1472 // blocks. They are stored in following order : Huffman tables for literals,
1473 // FSE table for offsets, FSE table for match lengths, and FSE table for
1474 // literals lengths. It's finally followed by 3 offset values, populating
1475 // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
1476 // little-endian each, for a total of 12 bytes. Each recent offset must have
1477 // a value < dictionary size."
1478 decode_huf_table(&dict->literals_dtable, &in);
1479 decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
1480 decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
1481 decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
1482
1483 // Read in the previous offset history
1484 dict->previous_offsets[0] = IO_read_bits(&in, 32);
1485 dict->previous_offsets[1] = IO_read_bits(&in, 32);
1486 dict->previous_offsets[2] = IO_read_bits(&in, 32);
1487
1488 // Ensure the provided offsets aren't too large
1489 // "Each recent offset must have a value < dictionary size."
1490 for (int i = 0; i < 3; i++) {
1491 if (dict->previous_offsets[i] > src_len) {
1492 ERROR("Dictionary corrupted");
1493 }
1494 }
1495
1496 // "Content : The rest of the dictionary is its content. The content act as
1497 // a "past" in front of data to compress or decompress, so it can be
1498 // referenced in sequence commands."
1499 init_dictionary_content(dict, &in);
1500 }
1501
1502 static void init_dictionary_content(dictionary_t *const dict,
1503 istream_t *const in) {
1504 // Copy in the content
1505 dict->content_size = IO_istream_len(in);
1506 dict->content = malloc(dict->content_size);
1507 if (!dict->content) {
1508 BAD_ALLOC();
1509 }
1510
1511 const u8 *const content = IO_get_read_ptr(in, dict->content_size);
1512
1513 memcpy(dict->content, content, dict->content_size);
1514 }
1515
1516 /// Free an allocated dictionary
1517 void free_dictionary(dictionary_t *const dict) {
1518 HUF_free_dtable(&dict->literals_dtable);
1519 FSE_free_dtable(&dict->ll_dtable);
1520 FSE_free_dtable(&dict->of_dtable);
1521 FSE_free_dtable(&dict->ml_dtable);
1522
1523 free(dict->content);
1524
1525 memset(dict, 0, sizeof(dictionary_t));
1526
1527 free(dict);
1528 }
1529 /******* END DICTIONARY PARSING ***********************************************/
1530
1531 /******* IO STREAM OPERATIONS *************************************************/
1532 #define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream")
1533 /// Reads `num` bits from a bitstream, and updates the internal offset
1534 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1535 if (num_bits > 64 || num_bits <= 0) {
1536 ERROR("Attempt to read an invalid number of bits");
1537 }
1538
1539 const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1540 const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1541 if (bytes > in->len) {
1542 INP_SIZE();
1543 }
1544
1545 const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1546
1547 in->bit_offset = (num_bits + in->bit_offset) % 8;
1548 in->ptr += full_bytes;
1549 in->len -= full_bytes;
1550
1551 return result;
1552 }
1553
1554 /// If a non-zero number of bits have been read from the current byte, advance
1555 /// the offset to the next byte
1556 static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1557 if (num_bits < 0) {
1558 ERROR("Attempting to rewind stream by a negative number of bits");
1559 }
1560
1561 // move the offset back by `num_bits` bits
1562 const int new_offset = in->bit_offset - num_bits;
1563 // determine the number of whole bytes we have to rewind, rounding up to an
1564 // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1565 const i64 bytes = -(new_offset - 7) / 8;
1566
1567 in->ptr -= bytes;
1568 in->len += bytes;
1569 // make sure the resulting `bit_offset` is positive, as mod in C does not
1570 // convert numbers from negative to positive (e.g. -22 % 8 == -6)
1571 in->bit_offset = ((new_offset % 8) + 8) % 8;
1572 }
1573
1574 /// If the remaining bits in a byte will be unused, advance to the end of the
1575 /// byte
1576 static inline void IO_align_stream(istream_t *const in) {
1577 if (in->bit_offset != 0) {
1578 if (in->len == 0) {
1579 INP_SIZE();
1580 }
1581 in->ptr++;
1582 in->len--;
1583 in->bit_offset = 0;
1584 }
1585 }
1586
1587 /// Write the given byte into the output stream
1588 static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1589 if (out->len == 0) {
1590 OUT_SIZE();
1591 }
1592
1593 out->ptr[0] = symb;
1594 out->ptr++;
1595 out->len--;
1596 }
1597
1598 /// Returns the number of bytes left to be read in this stream. The stream must
1599 /// be byte aligned.
1600 static inline size_t IO_istream_len(const istream_t *const in) {
1601 return in->len;
1602 }
1603
1604 /// Returns a pointer where `len` bytes can be read, and advances the internal
1605 /// state. The stream must be byte aligned.
1606 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1607 if (len > in->len) {
1608 INP_SIZE();
1609 }
1610 if (in->bit_offset != 0) {
1611 UNALIGNED();
1612 }
1613 const u8 *const ptr = in->ptr;
1614 in->ptr += len;
1615 in->len -= len;
1616
1617 return ptr;
1618 }
1619 /// Returns a pointer to write `len` bytes to, and advances the internal state
1620 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1621 if (len > out->len) {
1622 OUT_SIZE();
1623 }
1624 u8 *const ptr = out->ptr;
1625 out->ptr += len;
1626 out->len -= len;
1627
1628 return ptr;
1629 }
1630
1631 /// Advance the inner state by `len` bytes
1632 static inline void IO_advance_input(istream_t *const in, size_t len) {
1633 if (len > in->len) {
1634 INP_SIZE();
1635 }
1636 if (in->bit_offset != 0) {
1637 UNALIGNED();
1638 }
1639
1640 in->ptr += len;
1641 in->len -= len;
1642 }
1643
1644 /// Returns an `ostream_t` constructed from the given pointer and length
1645 static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1646 return (ostream_t) { out, len };
1647 }
1648
1649 /// Returns an `istream_t` constructed from the given pointer and length
1650 static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1651 return (istream_t) { in, len, 0 };
1652 }
1653
1654 /// Returns an `istream_t` with the same base as `in`, and length `len`
1655 /// Then, advance `in` to account for the consumed bytes
1656 /// `in` must be byte aligned
1657 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1658 // Consume `len` bytes of the parent stream
1659 const u8 *const ptr = IO_get_read_ptr(in, len);
1660
1661 // Make a substream using the pointer to those `len` bytes
1662 return IO_make_istream(ptr, len);
1663 }
1664 /******* END IO STREAM OPERATIONS *********************************************/
1665
1666 /******* BITSTREAM OPERATIONS *************************************************/
1667 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
1668 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1669 const size_t offset) {
1670 if (num_bits > 64) {
1671 ERROR("Attempt to read an invalid number of bits");
1672 }
1673
1674 // Skip over bytes that aren't in range
1675 src += offset / 8;
1676 size_t bit_offset = offset % 8;
1677 u64 res = 0;
1678
1679 int shift = 0;
1680 int left = num_bits;
1681 while (left > 0) {
1682 u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1683 // Read the next byte, shift it to account for the offset, and then mask
1684 // out the top part if we don't need all the bits
1685 res += (((u64)*src++ >> bit_offset) & mask) << shift;
1686 shift += 8 - bit_offset;
1687 left -= 8 - bit_offset;
1688 bit_offset = 0;
1689 }
1690
1691 return res;
1692 }
1693
1694 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
1695 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1696 /// `src + offset`. If the offset becomes negative, the extra bits at the
1697 /// bottom are filled in with `0` bits instead of reading from before `src`.
1698 static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1699 i64 *const offset) {
1700 *offset = *offset - bits;
1701 size_t actual_off = *offset;
1702 size_t actual_bits = bits;
1703 // Don't actually read bits from before the start of src, so if `*offset <
1704 // 0` fix actual_off and actual_bits to reflect the quantity to read
1705 if (*offset < 0) {
1706 actual_bits += *offset;
1707 actual_off = 0;
1708 }
1709 u64 res = read_bits_LE(src, actual_bits, actual_off);
1710
1711 if (*offset < 0) {
1712 // Fill in the bottom "overflowed" bits with 0's
1713 res = -*offset >= 64 ? 0 : (res << -*offset);
1714 }
1715 return res;
1716 }
1717 /******* END BITSTREAM OPERATIONS *********************************************/
1718
1719 /******* BIT COUNTING OPERATIONS **********************************************/
1720 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1721 /// `num`, or `-1` if `num == 0`.
1722 static inline int highest_set_bit(const u64 num) {
1723 for (int i = 63; i >= 0; i--) {
1724 if (((u64)1 << i) <= num) {
1725 return i;
1726 }
1727 }
1728 return -1;
1729 }
1730 /******* END BIT COUNTING OPERATIONS ******************************************/
1731
1732 /******* HUFFMAN PRIMITIVES ***************************************************/
1733 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1734 u16 *const state, const u8 *const src,
1735 i64 *const offset) {
1736 // Look up the symbol and number of bits to read
1737 const u8 symb = dtable->symbols[*state];
1738 const u8 bits = dtable->num_bits[*state];
1739 const u16 rest = STREAM_read_bits(src, bits, offset);
1740 // Shift `bits` bits out of the state, keeping the low order bits that
1741 // weren't necessary to determine this symbol. Then add in the new bits
1742 // read from the stream.
1743 *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
1744
1745 return symb;
1746 }
1747
1748 static inline void HUF_init_state(const HUF_dtable *const dtable,
1749 u16 *const state, const u8 *const src,
1750 i64 *const offset) {
1751 // Read in a full `dtable->max_bits` bits to initialize the state
1752 const u8 bits = dtable->max_bits;
1753 *state = STREAM_read_bits(src, bits, offset);
1754 }
1755
1756 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1757 ostream_t *const out,
1758 istream_t *const in) {
1759 const size_t len = IO_istream_len(in);
1760 if (len == 0) {
1761 INP_SIZE();
1762 }
1763 const u8 *const src = IO_get_read_ptr(in, len);
1764
1765 // "Each bitstream must be read backward, that is starting from the end down
1766 // to the beginning. Therefore it's necessary to know the size of each
1767 // bitstream.
1768 //
1769 // It's also necessary to know exactly which bit is the latest. This is
1770 // detected by a final bit flag : the highest bit of latest byte is a
1771 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
1772 // final-bit-flag itself is not part of the useful bitstream. Hence, the
1773 // last byte contains between 0 and 7 useful bits."
1774 const int padding = 8 - highest_set_bit(src[len - 1]);
1775
1776 // Offset starts at the end because HUF streams are read backwards
1777 i64 bit_offset = len * 8 - padding;
1778 u16 state;
1779
1780 HUF_init_state(dtable, &state, src, &bit_offset);
1781
1782 size_t symbols_written = 0;
1783 while (bit_offset > -dtable->max_bits) {
1784 // Iterate over the stream, decoding one symbol at a time
1785 IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
1786 symbols_written++;
1787 }
1788 // "The process continues up to reading the required number of symbols per
1789 // stream. If a bitstream is not entirely and exactly consumed, hence
1790 // reaching exactly its beginning position with all bits consumed, the
1791 // decoding process is considered faulty."
1792
1793 // When all symbols have been decoded, the final state value shouldn't have
1794 // any data from the stream, so it should have "read" dtable->max_bits from
1795 // before the start of `src`
1796 // Therefore `offset`, the edge to start reading new bits at, should be
1797 // dtable->max_bits before the start of the stream
1798 if (bit_offset != -dtable->max_bits) {
1799 CORRUPTION();
1800 }
1801
1802 return symbols_written;
1803 }
1804
1805 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1806 ostream_t *const out, istream_t *const in) {
1807 // "Compressed size is provided explicitly : in the 4-streams variant,
1808 // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
1809 // value represents the compressed size of one stream, in order. The last
1810 // stream size is deducted from total compressed size and from previously
1811 // decoded stream sizes"
1812 const size_t csize1 = IO_read_bits(in, 16);
1813 const size_t csize2 = IO_read_bits(in, 16);
1814 const size_t csize3 = IO_read_bits(in, 16);
1815
1816 istream_t in1 = IO_make_sub_istream(in, csize1);
1817 istream_t in2 = IO_make_sub_istream(in, csize2);
1818 istream_t in3 = IO_make_sub_istream(in, csize3);
1819 istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
1820
1821 size_t total_output = 0;
1822 // Decode each stream independently for simplicity
1823 // If we wanted to we could decode all 4 at the same time for speed,
1824 // utilizing more execution units
1825 total_output += HUF_decompress_1stream(dtable, out, &in1);
1826 total_output += HUF_decompress_1stream(dtable, out, &in2);
1827 total_output += HUF_decompress_1stream(dtable, out, &in3);
1828 total_output += HUF_decompress_1stream(dtable, out, &in4);
1829
1830 return total_output;
1831 }
1832
1833 /// Initializes a Huffman table using canonical Huffman codes
1834 /// For more explanation on canonical Huffman codes see
1835 /// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1836 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get
1837 /// earlier codes)
1838 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1839 const int num_symbs) {
1840 memset(table, 0, sizeof(HUF_dtable));
1841 if (num_symbs > HUF_MAX_SYMBS) {
1842 ERROR("Too many symbols for Huffman");
1843 }
1844
1845 u8 max_bits = 0;
1846 u16 rank_count[HUF_MAX_BITS + 1];
1847 memset(rank_count, 0, sizeof(rank_count));
1848
1849 // Count the number of symbols for each number of bits, and determine the
1850 // depth of the tree
1851 for (int i = 0; i < num_symbs; i++) {
1852 if (bits[i] > HUF_MAX_BITS) {
1853 ERROR("Huffman table depth too large");
1854 }
1855 max_bits = MAX(max_bits, bits[i]);
1856 rank_count[bits[i]]++;
1857 }
1858
1859 const size_t table_size = 1 << max_bits;
1860 table->max_bits = max_bits;
1861 table->symbols = malloc(table_size);
1862 table->num_bits = malloc(table_size);
1863
1864 if (!table->symbols || !table->num_bits) {
1865 free(table->symbols);
1866 free(table->num_bits);
1867 BAD_ALLOC();
1868 }
1869
1870 // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1871 // order. Symbols with a Weight of zero are removed. Then, starting from
1872 // lowest weight, prefix codes are distributed in order."
1873
1874 u32 rank_idx[HUF_MAX_BITS + 1];
1875 // Initialize the starting codes for each rank (number of bits)
1876 rank_idx[max_bits] = 0;
1877 for (int i = max_bits; i >= 1; i--) {
1878 rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1879 // The entire range takes the same number of bits so we can memset it
1880 memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
1881 }
1882
1883 if (rank_idx[0] != table_size) {
1884 CORRUPTION();
1885 }
1886
1887 // Allocate codes and fill in the table
1888 for (int i = 0; i < num_symbs; i++) {
1889 if (bits[i] != 0) {
1890 // Allocate a code for this symbol and set its range in the table
1891 const u16 code = rank_idx[bits[i]];
1892 // Since the code doesn't care about the bottom `max_bits - bits[i]`
1893 // bits of state, it gets a range that spans all possible values of
1894 // the lower bits
1895 const u16 len = 1 << (max_bits - bits[i]);
1896 memset(&table->symbols[code], i, len);
1897 rank_idx[bits[i]] += len;
1898 }
1899 }
1900 }
1901
1902 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1903 const u8 *const weights,
1904 const int num_symbs) {
1905 // +1 because the last weight is not transmitted in the header
1906 if (num_symbs + 1 > HUF_MAX_SYMBS) {
1907 ERROR("Too many symbols for Huffman");
1908 }
1909
1910 u8 bits[HUF_MAX_SYMBS];
1911
1912 u64 weight_sum = 0;
1913 for (int i = 0; i < num_symbs; i++) {
1914 // Weights are in the same range as bit count
1915 if (weights[i] > HUF_MAX_BITS) {
1916 CORRUPTION();
1917 }
1918 weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
1919 }
1920
1921 // Find the first power of 2 larger than the sum
1922 const int max_bits = highest_set_bit(weight_sum) + 1;
1923 const u64 left_over = ((u64)1 << max_bits) - weight_sum;
1924 // If the left over isn't a power of 2, the weights are invalid
1925 if (left_over & (left_over - 1)) {
1926 CORRUPTION();
1927 }
1928
1929 // left_over is used to find the last weight as it's not transmitted
1930 // by inverting 2^(weight - 1) we can determine the value of last_weight
1931 const int last_weight = highest_set_bit(left_over) + 1;
1932
1933 for (int i = 0; i < num_symbs; i++) {
1934 // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
1935 bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
1936 }
1937 bits[num_symbs] =
1938 max_bits + 1 - last_weight; // Last weight is always non-zero
1939
1940 HUF_init_dtable(table, bits, num_symbs + 1);
1941 }
1942
1943 static void HUF_free_dtable(HUF_dtable *const dtable) {
1944 free(dtable->symbols);
1945 free(dtable->num_bits);
1946 memset(dtable, 0, sizeof(HUF_dtable));
1947 }
1948
1949 static void HUF_copy_dtable(HUF_dtable *const dst,
1950 const HUF_dtable *const src) {
1951 if (src->max_bits == 0) {
1952 memset(dst, 0, sizeof(HUF_dtable));
1953 return;
1954 }
1955
1956 const size_t size = (size_t)1 << src->max_bits;
1957 dst->max_bits = src->max_bits;
1958
1959 dst->symbols = malloc(size);
1960 dst->num_bits = malloc(size);
1961 if (!dst->symbols || !dst->num_bits) {
1962 BAD_ALLOC();
1963 }
1964
1965 memcpy(dst->symbols, src->symbols, size);
1966 memcpy(dst->num_bits, src->num_bits, size);
1967 }
1968 /******* END HUFFMAN PRIMITIVES ***********************************************/
1969
1970 /******* FSE PRIMITIVES *******************************************************/
1971 /// For more description of FSE see
1972 /// https://github.com/Cyan4973/FiniteStateEntropy/
1973
1974 /// Allow a symbol to be decoded without updating state
1975 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
1976 const u16 state) {
1977 return dtable->symbols[state];
1978 }
1979
1980 /// Consumes bits from the input and uses the current state to determine the
1981 /// next state
1982 static inline void FSE_update_state(const FSE_dtable *const dtable,
1983 u16 *const state, const u8 *const src,
1984 i64 *const offset) {
1985 const u8 bits = dtable->num_bits[*state];
1986 const u16 rest = STREAM_read_bits(src, bits, offset);
1987 *state = dtable->new_state_base[*state] + rest;
1988 }
1989
1990 /// Decodes a single FSE symbol and updates the offset
1991 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
1992 u16 *const state, const u8 *const src,
1993 i64 *const offset) {
1994 const u8 symb = FSE_peek_symbol(dtable, *state);
1995 FSE_update_state(dtable, state, src, offset);
1996 return symb;
1997 }
1998
1999 static inline void FSE_init_state(const FSE_dtable *const dtable,
2000 u16 *const state, const u8 *const src,
2001 i64 *const offset) {
2002 // Read in a full `accuracy_log` bits to initialize the state
2003 const u8 bits = dtable->accuracy_log;
2004 *state = STREAM_read_bits(src, bits, offset);
2005 }
2006
2007 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2008 ostream_t *const out,
2009 istream_t *const in) {
2010 const size_t len = IO_istream_len(in);
2011 if (len == 0) {
2012 INP_SIZE();
2013 }
2014 const u8 *const src = IO_get_read_ptr(in, len);
2015
2016 // "Each bitstream must be read backward, that is starting from the end down
2017 // to the beginning. Therefore it's necessary to know the size of each
2018 // bitstream.
2019 //
2020 // It's also necessary to know exactly which bit is the latest. This is
2021 // detected by a final bit flag : the highest bit of latest byte is a
2022 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
2023 // final-bit-flag itself is not part of the useful bitstream. Hence, the
2024 // last byte contains between 0 and 7 useful bits."
2025 const int padding = 8 - highest_set_bit(src[len - 1]);
2026 i64 offset = len * 8 - padding;
2027
2028 u16 state1, state2;
2029 // "The first state (State1) encodes the even indexed symbols, and the
2030 // second (State2) encodes the odd indexes. State1 is initialized first, and
2031 // then State2, and they take turns decoding a single symbol and updating
2032 // their state."
2033 FSE_init_state(dtable, &state1, src, &offset);
2034 FSE_init_state(dtable, &state2, src, &offset);
2035
2036 // Decode until we overflow the stream
2037 // Since we decode in reverse order, overflowing the stream is offset going
2038 // negative
2039 size_t symbols_written = 0;
2040 while (1) {
2041 // "The number of symbols to decode is determined by tracking bitStream
2042 // overflow condition: If updating state after decoding a symbol would
2043 // require more bits than remain in the stream, it is assumed the extra
2044 // bits are 0. Then, the symbols for each of the final states are
2045 // decoded and the process is complete."
2046 IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
2047 symbols_written++;
2048 if (offset < 0) {
2049 // There's still a symbol to decode in state2
2050 IO_write_byte(out, FSE_peek_symbol(dtable, state2));
2051 symbols_written++;
2052 break;
2053 }
2054
2055 IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
2056 symbols_written++;
2057 if (offset < 0) {
2058 // There's still a symbol to decode in state1
2059 IO_write_byte(out, FSE_peek_symbol(dtable, state1));
2060 symbols_written++;
2061 break;
2062 }
2063 }
2064
2065 return symbols_written;
2066 }
2067
2068 static void FSE_init_dtable(FSE_dtable *const dtable,
2069 const i16 *const norm_freqs, const int num_symbs,
2070 const int accuracy_log) {
2071 if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
2072 ERROR("FSE accuracy too large");
2073 }
2074 if (num_symbs > FSE_MAX_SYMBS) {
2075 ERROR("Too many symbols for FSE");
2076 }
2077
2078 dtable->accuracy_log = accuracy_log;
2079
2080 const size_t size = (size_t)1 << accuracy_log;
2081 dtable->symbols = malloc(size * sizeof(u8));
2082 dtable->num_bits = malloc(size * sizeof(u8));
2083 dtable->new_state_base = malloc(size * sizeof(u16));
2084
2085 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2086 BAD_ALLOC();
2087 }
2088
2089 // Used to determine how many bits need to be read for each state,
2090 // and where the destination range should start
2091 // Needs to be u16 because max value is 2 * max number of symbols,
2092 // which can be larger than a byte can store
2093 u16 state_desc[FSE_MAX_SYMBS];
2094
2095 // "Symbols are scanned in their natural order for "less than 1"
2096 // probabilities. Symbols with this probability are being attributed a
2097 // single cell, starting from the end of the table. These symbols define a
2098 // full state reset, reading Accuracy_Log bits."
2099 int high_threshold = size;
2100 for (int s = 0; s < num_symbs; s++) {
2101 // Scan for low probability symbols to put at the top
2102 if (norm_freqs[s] == -1) {
2103 dtable->symbols[--high_threshold] = s;
2104 state_desc[s] = 1;
2105 }
2106 }
2107
2108 // "All remaining symbols are sorted in their natural order. Starting from
2109 // symbol 0 and table position 0, each symbol gets attributed as many cells
2110 // as its probability. Cell allocation is spreaded, not linear."
2111 // Place the rest in the table
2112 const u16 step = (size >> 1) + (size >> 3) + 3;
2113 const u16 mask = size - 1;
2114 u16 pos = 0;
2115 for (int s = 0; s < num_symbs; s++) {
2116 if (norm_freqs[s] <= 0) {
2117 continue;
2118 }
2119
2120 state_desc[s] = norm_freqs[s];
2121
2122 for (int i = 0; i < norm_freqs[s]; i++) {
2123 // Give `norm_freqs[s]` states to symbol s
2124 dtable->symbols[pos] = s;
2125 // "A position is skipped if already occupied, typically by a "less
2126 // than 1" probability symbol."
2127 do {
2128 pos = (pos + step) & mask;
2129 } while (pos >=
2130 high_threshold);
2131 // Note: no other collision checking is necessary as `step` is
2132 // coprime to `size`, so the cycle will visit each position exactly
2133 // once
2134 }
2135 }
2136 if (pos != 0) {
2137 CORRUPTION();
2138 }
2139
2140 // Now we can fill baseline and num bits
2141 for (size_t i = 0; i < size; i++) {
2142 u8 symbol = dtable->symbols[i];
2143 u16 next_state_desc = state_desc[symbol]++;
2144 // Fills in the table appropriately, next_state_desc increases by symbol
2145 // over time, decreasing number of bits
2146 dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
2147 // Baseline increases until the bit threshold is passed, at which point
2148 // it resets to 0
2149 dtable->new_state_base[i] =
2150 ((u16)next_state_desc << dtable->num_bits[i]) - size;
2151 }
2152 }
2153
2154 /// Decode an FSE header as defined in the Zstandard format specification and
2155 /// use the decoded frequencies to initialize a decoding table.
2156 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2157 const int max_accuracy_log) {
2158 // "An FSE distribution table describes the probabilities of all symbols
2159 // from 0 to the last present one (included) on a normalized scale of 1 <<
2160 // Accuracy_Log .
2161 //
2162 // It's a bitstream which is read forward, in little-endian fashion. It's
2163 // not necessary to know its exact size, since it will be discovered and
2164 // reported by the decoding process.
2165 if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2166 ERROR("FSE accuracy too large");
2167 }
2168
2169 // The bitstream starts by reporting on which scale it operates.
2170 // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2171 // and match lengths is 9, and for offsets is 8. Higher values are
2172 // considered errors."
2173 const int accuracy_log = 5 + IO_read_bits(in, 4);
2174 if (accuracy_log > max_accuracy_log) {
2175 ERROR("FSE accuracy too large");
2176 }
2177
2178 // "Then follows each symbol value, from 0 to last present one. The number
2179 // of bits used by each field is variable. It depends on :
2180 //
2181 // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2182 // and presuming 100 probabilities points have already been distributed, the
2183 // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2184 // Therefore, it must read log2sup(156) == 8 bits.
2185 //
2186 // Value decoded : small values use 1 less bit : example : Presuming values
2187 // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2188 // in an 8-bits field. They are used this way : first 99 values (hence from
2189 // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2190
2191 i32 remaining = 1 << accuracy_log;
2192 i16 frequencies[FSE_MAX_SYMBS];
2193
2194 int symb = 0;
2195 while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2196 // Log of the number of possible values we could read
2197 int bits = highest_set_bit(remaining + 1) + 1;
2198
2199 u16 val = IO_read_bits(in, bits);
2200
2201 // Try to mask out the lower bits to see if it qualifies for the "small
2202 // value" threshold
2203 const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2204 const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2205
2206 if ((val & lower_mask) < threshold) {
2207 IO_rewind_bits(in, 1);
2208 val = val & lower_mask;
2209 } else if (val > lower_mask) {
2210 val = val - threshold;
2211 }
2212
2213 // "Probability is obtained from Value decoded by following formula :
2214 // Proba = value - 1"
2215 const i16 proba = (i16)val - 1;
2216
2217 // "It means value 0 becomes negative probability -1. -1 is a special
2218 // probability, which means "less than 1". Its effect on distribution
2219 // table is described in next paragraph. For the purpose of calculating
2220 // cumulated distribution, it counts as one."
2221 remaining -= proba < 0 ? -proba : proba;
2222
2223 frequencies[symb] = proba;
2224 symb++;
2225
2226 // "When a symbol has a probability of zero, it is followed by a 2-bits
2227 // repeat flag. This repeat flag tells how many probabilities of zeroes
2228 // follow the current one. It provides a number ranging from 0 to 3. If
2229 // it is a 3, another 2-bits repeat flag follows, and so on."
2230 if (proba == 0) {
2231 // Read the next two bits to see how many more 0s
2232 int repeat = IO_read_bits(in, 2);
2233
2234 while (1) {
2235 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2236 frequencies[symb++] = 0;
2237 }
2238 if (repeat == 3) {
2239 repeat = IO_read_bits(in, 2);
2240 } else {
2241 break;
2242 }
2243 }
2244 }
2245 }
2246 IO_align_stream(in);
2247
2248 // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2249 // is complete. If the last symbol makes cumulated total go above 1 <<
2250 // Accuracy_Log, distribution is considered corrupted."
2251 if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2252 CORRUPTION();
2253 }
2254
2255 // Initialize the decoding table using the determined weights
2256 FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2257 }
2258
2259 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
2260 dtable->symbols = malloc(sizeof(u8));
2261 dtable->num_bits = malloc(sizeof(u8));
2262 dtable->new_state_base = malloc(sizeof(u16));
2263
2264 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2265 BAD_ALLOC();
2266 }
2267
2268 // This setup will always have a state of 0, always return symbol `symb`,
2269 // and never consume any bits
2270 dtable->symbols[0] = symb;
2271 dtable->num_bits[0] = 0;
2272 dtable->new_state_base[0] = 0;
2273 dtable->accuracy_log = 0;
2274 }
2275
2276 static void FSE_free_dtable(FSE_dtable *const dtable) {
2277 free(dtable->symbols);
2278 free(dtable->num_bits);
2279 free(dtable->new_state_base);
2280 memset(dtable, 0, sizeof(FSE_dtable));
2281 }
2282
2283 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
2284 if (src->accuracy_log == 0) {
2285 memset(dst, 0, sizeof(FSE_dtable));
2286 return;
2287 }
2288
2289 size_t size = (size_t)1 << src->accuracy_log;
2290 dst->accuracy_log = src->accuracy_log;
2291
2292 dst->symbols = malloc(size);
2293 dst->num_bits = malloc(size);
2294 dst->new_state_base = malloc(size * sizeof(u16));
2295 if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
2296 BAD_ALLOC();
2297 }
2298
2299 memcpy(dst->symbols, src->symbols, size);
2300 memcpy(dst->num_bits, src->num_bits, size);
2301 memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
2302 }
2303 /******* END FSE PRIMITIVES ***************************************************/