xref: /freebsd/sys/contrib/zstd/doc/educational_decoder/harness.c (revision 0c16b53773565120a8f80a31a0af2ef56ccd368e)
1*0c16b537SWarner Losh /*
2*0c16b537SWarner Losh  * Copyright (c) 2017-present, Facebook, Inc.
3*0c16b537SWarner Losh  * All rights reserved.
4*0c16b537SWarner Losh  *
5*0c16b537SWarner Losh  * This source code is licensed under both the BSD-style license (found in the
6*0c16b537SWarner Losh  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7*0c16b537SWarner Losh  * in the COPYING file in the root directory of this source tree).
8*0c16b537SWarner Losh  */
9*0c16b537SWarner Losh 
10*0c16b537SWarner Losh #include <stdio.h>
11*0c16b537SWarner Losh #include <stdlib.h>
12*0c16b537SWarner Losh 
13*0c16b537SWarner Losh #include "zstd_decompress.h"
14*0c16b537SWarner Losh 
// Raw byte type used for all file and decompression buffers.
typedef unsigned char u8;

// When the compressed data does not record its decompressed size, fall back
// on assuming the data expands by at most this factor.
#define MAX_COMPRESSION_RATIO 16

// Upper bound (1 GiB) on the output buffer, so that a bogus size cannot make
// the harness allocate an unreasonable amount of memory.
#define MAX_OUTPUT_SIZE ((size_t)1 << 30)

// Buffers shared across the harness: compressed input, decompressed output
// and the optional raw dictionary. All start out as null pointers.
u8 *input, *output, *dict;
27*0c16b537SWarner Losh 
28*0c16b537SWarner Losh size_t read_file(const char *path, u8 **ptr) {
29*0c16b537SWarner Losh     FILE *f = fopen(path, "rb");
30*0c16b537SWarner Losh     if (!f) {
31*0c16b537SWarner Losh         fprintf(stderr, "failed to open file %s\n", path);
32*0c16b537SWarner Losh         exit(1);
33*0c16b537SWarner Losh     }
34*0c16b537SWarner Losh 
35*0c16b537SWarner Losh     fseek(f, 0L, SEEK_END);
36*0c16b537SWarner Losh     size_t size = ftell(f);
37*0c16b537SWarner Losh     rewind(f);
38*0c16b537SWarner Losh 
39*0c16b537SWarner Losh     *ptr = malloc(size);
40*0c16b537SWarner Losh     if (!ptr) {
41*0c16b537SWarner Losh         fprintf(stderr, "failed to allocate memory to hold %s\n", path);
42*0c16b537SWarner Losh         exit(1);
43*0c16b537SWarner Losh     }
44*0c16b537SWarner Losh 
45*0c16b537SWarner Losh     size_t pos = 0;
46*0c16b537SWarner Losh     while (!feof(f)) {
47*0c16b537SWarner Losh         size_t read = fread(&(*ptr)[pos], 1, size, f);
48*0c16b537SWarner Losh         if (ferror(f)) {
49*0c16b537SWarner Losh             fprintf(stderr, "error while reading file %s\n", path);
50*0c16b537SWarner Losh             exit(1);
51*0c16b537SWarner Losh         }
52*0c16b537SWarner Losh         pos += read;
53*0c16b537SWarner Losh     }
54*0c16b537SWarner Losh 
55*0c16b537SWarner Losh     fclose(f);
56*0c16b537SWarner Losh 
57*0c16b537SWarner Losh     return pos;
58*0c16b537SWarner Losh }
59*0c16b537SWarner Losh 
60*0c16b537SWarner Losh void write_file(const char *path, const u8 *ptr, size_t size) {
61*0c16b537SWarner Losh     FILE *f = fopen(path, "wb");
62*0c16b537SWarner Losh 
63*0c16b537SWarner Losh     size_t written = 0;
64*0c16b537SWarner Losh     while (written < size) {
65*0c16b537SWarner Losh         written += fwrite(&ptr[written], 1, size, f);
66*0c16b537SWarner Losh         if (ferror(f)) {
67*0c16b537SWarner Losh             fprintf(stderr, "error while writing file %s\n", path);
68*0c16b537SWarner Losh             exit(1);
69*0c16b537SWarner Losh         }
70*0c16b537SWarner Losh     }
71*0c16b537SWarner Losh 
72*0c16b537SWarner Losh     fclose(f);
73*0c16b537SWarner Losh }
74*0c16b537SWarner Losh 
75*0c16b537SWarner Losh int main(int argc, char **argv) {
76*0c16b537SWarner Losh     if (argc < 3) {
77*0c16b537SWarner Losh         fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n",
78*0c16b537SWarner Losh                 argv[0]);
79*0c16b537SWarner Losh 
80*0c16b537SWarner Losh         return 1;
81*0c16b537SWarner Losh     }
82*0c16b537SWarner Losh 
83*0c16b537SWarner Losh     size_t input_size = read_file(argv[1], &input);
84*0c16b537SWarner Losh     size_t dict_size = 0;
85*0c16b537SWarner Losh     if (argc >= 4) {
86*0c16b537SWarner Losh         dict_size = read_file(argv[3], &dict);
87*0c16b537SWarner Losh     }
88*0c16b537SWarner Losh 
89*0c16b537SWarner Losh     size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size);
90*0c16b537SWarner Losh     if (decompressed_size == (size_t)-1) {
91*0c16b537SWarner Losh         decompressed_size = MAX_COMPRESSION_RATIO * input_size;
92*0c16b537SWarner Losh         fprintf(stderr, "WARNING: Compressed data does not contain "
93*0c16b537SWarner Losh                         "decompressed size, going to assume the compression "
94*0c16b537SWarner Losh                         "ratio is at most %d (decompressed size of at most "
95*0c16b537SWarner Losh                         "%zu)\n",
96*0c16b537SWarner Losh                 MAX_COMPRESSION_RATIO, decompressed_size);
97*0c16b537SWarner Losh     }
98*0c16b537SWarner Losh     if (decompressed_size > MAX_OUTPUT_SIZE) {
99*0c16b537SWarner Losh         fprintf(stderr,
100*0c16b537SWarner Losh                 "Required output size too large for this implementation\n");
101*0c16b537SWarner Losh         return 1;
102*0c16b537SWarner Losh     }
103*0c16b537SWarner Losh     output = malloc(decompressed_size);
104*0c16b537SWarner Losh     if (!output) {
105*0c16b537SWarner Losh         fprintf(stderr, "failed to allocate memory\n");
106*0c16b537SWarner Losh         return 1;
107*0c16b537SWarner Losh     }
108*0c16b537SWarner Losh 
109*0c16b537SWarner Losh     dictionary_t* const parsed_dict = create_dictionary();
110*0c16b537SWarner Losh     if (dict) {
111*0c16b537SWarner Losh         parse_dictionary(parsed_dict, dict, dict_size);
112*0c16b537SWarner Losh     }
113*0c16b537SWarner Losh     size_t decompressed =
114*0c16b537SWarner Losh         ZSTD_decompress_with_dict(output, decompressed_size,
115*0c16b537SWarner Losh                                   input, input_size, parsed_dict);
116*0c16b537SWarner Losh 
117*0c16b537SWarner Losh     free_dictionary(parsed_dict);
118*0c16b537SWarner Losh 
119*0c16b537SWarner Losh     write_file(argv[2], output, decompressed);
120*0c16b537SWarner Losh 
121*0c16b537SWarner Losh     free(input);
122*0c16b537SWarner Losh     free(output);
123*0c16b537SWarner Losh     free(dict);
124*0c16b537SWarner Losh     input = output = dict = NULL;
125*0c16b537SWarner Losh }
126