1*0c16b537SWarner Losh /* 2*0c16b537SWarner Losh * Copyright (c) 2017-present, Facebook, Inc. 3*0c16b537SWarner Losh * All rights reserved. 4*0c16b537SWarner Losh * 5*0c16b537SWarner Losh * This source code is licensed under both the BSD-style license (found in the 6*0c16b537SWarner Losh * LICENSE file in the root directory of this source tree) and the GPLv2 (found 7*0c16b537SWarner Losh * in the COPYING file in the root directory of this source tree). 8*0c16b537SWarner Losh */ 9*0c16b537SWarner Losh 10*0c16b537SWarner Losh #include <stdio.h> 11*0c16b537SWarner Losh #include <stdlib.h> 12*0c16b537SWarner Losh 13*0c16b537SWarner Losh #include "zstd_decompress.h" 14*0c16b537SWarner Losh 15*0c16b537SWarner Losh typedef unsigned char u8; 16*0c16b537SWarner Losh 17*0c16b537SWarner Losh // If the data doesn't have decompressed size with it, fallback on assuming the 18*0c16b537SWarner Losh // compression ratio is at most 16 19*0c16b537SWarner Losh #define MAX_COMPRESSION_RATIO (16) 20*0c16b537SWarner Losh 21*0c16b537SWarner Losh // Protect against allocating too much memory for output 22*0c16b537SWarner Losh #define MAX_OUTPUT_SIZE ((size_t)1024 * 1024 * 1024) 23*0c16b537SWarner Losh 24*0c16b537SWarner Losh u8 *input; 25*0c16b537SWarner Losh u8 *output; 26*0c16b537SWarner Losh u8 *dict; 27*0c16b537SWarner Losh 28*0c16b537SWarner Losh size_t read_file(const char *path, u8 **ptr) { 29*0c16b537SWarner Losh FILE *f = fopen(path, "rb"); 30*0c16b537SWarner Losh if (!f) { 31*0c16b537SWarner Losh fprintf(stderr, "failed to open file %s\n", path); 32*0c16b537SWarner Losh exit(1); 33*0c16b537SWarner Losh } 34*0c16b537SWarner Losh 35*0c16b537SWarner Losh fseek(f, 0L, SEEK_END); 36*0c16b537SWarner Losh size_t size = ftell(f); 37*0c16b537SWarner Losh rewind(f); 38*0c16b537SWarner Losh 39*0c16b537SWarner Losh *ptr = malloc(size); 40*0c16b537SWarner Losh if (!ptr) { 41*0c16b537SWarner Losh fprintf(stderr, "failed to allocate memory to hold %s\n", path); 42*0c16b537SWarner Losh exit(1); 43*0c16b537SWarner Losh } 44*0c16b537SWarner Losh 45*0c16b537SWarner Losh size_t pos = 0; 46*0c16b537SWarner Losh while (!feof(f)) { 47*0c16b537SWarner Losh size_t read = fread(&(*ptr)[pos], 1, size, f); 48*0c16b537SWarner Losh if (ferror(f)) { 49*0c16b537SWarner Losh fprintf(stderr, "error while reading file %s\n", path); 50*0c16b537SWarner Losh exit(1); 51*0c16b537SWarner Losh } 52*0c16b537SWarner Losh pos += read; 53*0c16b537SWarner Losh } 54*0c16b537SWarner Losh 55*0c16b537SWarner Losh fclose(f); 56*0c16b537SWarner Losh 57*0c16b537SWarner Losh return pos; 58*0c16b537SWarner Losh } 59*0c16b537SWarner Losh 60*0c16b537SWarner Losh void write_file(const char *path, const u8 *ptr, size_t size) { 61*0c16b537SWarner Losh FILE *f = fopen(path, "wb"); 62*0c16b537SWarner Losh 63*0c16b537SWarner Losh size_t written = 0; 64*0c16b537SWarner Losh while (written < size) { 65*0c16b537SWarner Losh written += fwrite(&ptr[written], 1, size, f); 66*0c16b537SWarner Losh if (ferror(f)) { 67*0c16b537SWarner Losh fprintf(stderr, "error while writing file %s\n", path); 68*0c16b537SWarner Losh exit(1); 69*0c16b537SWarner Losh } 70*0c16b537SWarner Losh } 71*0c16b537SWarner Losh 72*0c16b537SWarner Losh fclose(f); 73*0c16b537SWarner Losh } 74*0c16b537SWarner Losh 75*0c16b537SWarner Losh int main(int argc, char **argv) { 76*0c16b537SWarner Losh if (argc < 3) { 77*0c16b537SWarner Losh fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n", 78*0c16b537SWarner Losh argv[0]); 79*0c16b537SWarner Losh 80*0c16b537SWarner Losh return 1; 81*0c16b537SWarner Losh } 82*0c16b537SWarner Losh 83*0c16b537SWarner Losh size_t input_size = read_file(argv[1], &input); 84*0c16b537SWarner Losh size_t dict_size = 0; 85*0c16b537SWarner Losh if (argc >= 4) { 86*0c16b537SWarner Losh dict_size = read_file(argv[3], &dict); 87*0c16b537SWarner Losh } 88*0c16b537SWarner Losh 89*0c16b537SWarner Losh size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size); 90*0c16b537SWarner Losh if (decompressed_size == (size_t)-1) { 91*0c16b537SWarner Losh decompressed_size = MAX_COMPRESSION_RATIO * input_size; 92*0c16b537SWarner Losh fprintf(stderr, "WARNING: Compressed data does not contain " 93*0c16b537SWarner Losh "decompressed size, going to assume the compression " 94*0c16b537SWarner Losh "ratio is at most %d (decompressed size of at most " 95*0c16b537SWarner Losh "%zu)\n", 96*0c16b537SWarner Losh MAX_COMPRESSION_RATIO, decompressed_size); 97*0c16b537SWarner Losh } 98*0c16b537SWarner Losh if (decompressed_size > MAX_OUTPUT_SIZE) { 99*0c16b537SWarner Losh fprintf(stderr, 100*0c16b537SWarner Losh "Required output size too large for this implementation\n"); 101*0c16b537SWarner Losh return 1; 102*0c16b537SWarner Losh } 103*0c16b537SWarner Losh output = malloc(decompressed_size); 104*0c16b537SWarner Losh if (!output) { 105*0c16b537SWarner Losh fprintf(stderr, "failed to allocate memory\n"); 106*0c16b537SWarner Losh return 1; 107*0c16b537SWarner Losh } 108*0c16b537SWarner Losh 109*0c16b537SWarner Losh dictionary_t* const parsed_dict = create_dictionary(); 110*0c16b537SWarner Losh if (dict) { 111*0c16b537SWarner Losh parse_dictionary(parsed_dict, dict, dict_size); 112*0c16b537SWarner Losh } 113*0c16b537SWarner Losh size_t decompressed = 114*0c16b537SWarner Losh ZSTD_decompress_with_dict(output, decompressed_size, 115*0c16b537SWarner Losh input, input_size, parsed_dict); 116*0c16b537SWarner Losh 117*0c16b537SWarner Losh free_dictionary(parsed_dict); 118*0c16b537SWarner Losh 119*0c16b537SWarner Losh write_file(argv[2], output, decompressed); 120*0c16b537SWarner Losh 121*0c16b537SWarner Losh free(input); 122*0c16b537SWarner Losh free(output); 123*0c16b537SWarner Losh free(dict); 124*0c16b537SWarner Losh input = output = dict = NULL; 125*0c16b537SWarner Losh } 126