1 // SPDX-License-Identifier: GPL-2.0 2 3 #include <string.h> 4 5 #include "util/compress.h" 6 #include "util/debug.h" 7 8 int zstd_init(struct zstd_data *data, int level) 9 { 10 size_t ret; 11 12 data->dstream = ZSTD_createDStream(); 13 if (data->dstream == NULL) { 14 pr_err("Couldn't create decompression stream.\n"); 15 return -1; 16 } 17 18 ret = ZSTD_initDStream(data->dstream); 19 if (ZSTD_isError(ret)) { 20 pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret)); 21 return -1; 22 } 23 24 if (!level) 25 return 0; 26 27 data->cstream = ZSTD_createCStream(); 28 if (data->cstream == NULL) { 29 pr_err("Couldn't create compression stream.\n"); 30 return -1; 31 } 32 33 ret = ZSTD_initCStream(data->cstream, level); 34 if (ZSTD_isError(ret)) { 35 pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret)); 36 return -1; 37 } 38 39 return 0; 40 } 41 42 int zstd_fini(struct zstd_data *data) 43 { 44 if (data->dstream) { 45 ZSTD_freeDStream(data->dstream); 46 data->dstream = NULL; 47 } 48 49 if (data->cstream) { 50 ZSTD_freeCStream(data->cstream); 51 data->cstream = NULL; 52 } 53 54 return 0; 55 } 56 57 size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size, 58 void *src, size_t src_size, size_t max_record_size, 59 size_t process_header(void *record, size_t increment)) 60 { 61 size_t ret, size, compressed = 0; 62 ZSTD_inBuffer input = { src, src_size, 0 }; 63 ZSTD_outBuffer output; 64 void *record; 65 66 while (input.pos < input.size) { 67 record = dst; 68 size = process_header(record, 0); 69 compressed += size; 70 dst += size; 71 dst_size -= size; 72 output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ? 73 max_record_size : dst_size, 0 }; 74 ret = ZSTD_compressStream(data->cstream, &output, &input); 75 ZSTD_flushStream(data->cstream, &output); 76 if (ZSTD_isError(ret)) { 77 pr_err("failed to compress %ld bytes: %s\n", 78 (long)src_size, ZSTD_getErrorName(ret)); 79 memcpy(dst, src, src_size); 80 return src_size; 81 } 82 size = output.pos; 83 size = process_header(record, size); 84 compressed += size; 85 dst += size; 86 dst_size -= size; 87 } 88 89 return compressed; 90 } 91 92 size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size, 93 void *dst, size_t dst_size) 94 { 95 size_t ret; 96 ZSTD_inBuffer input = { src, src_size, 0 }; 97 ZSTD_outBuffer output = { dst, dst_size, 0 }; 98 99 while (input.pos < input.size) { 100 ret = ZSTD_decompressStream(data->dstream, &output, &input); 101 if (ZSTD_isError(ret)) { 102 pr_err("failed to decompress (B): %ld -> %ld, dst_size %ld : %s\n", 103 src_size, output.size, dst_size, ZSTD_getErrorName(ret)); 104 break; 105 } 106 output.dst = dst + output.pos; 107 output.size = dst_size - output.pos; 108 } 109 110 return output.pos; 111 } 112