xref: /linux/tools/perf/util/zstd.c (revision 06d07429858317ded2db7986113a9e0129cd599b)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <string.h>
4 
5 #include "util/compress.h"
6 #include "util/debug.h"
7 
zstd_init(struct zstd_data * data,int level)8 int zstd_init(struct zstd_data *data, int level)
9 {
10 	data->comp_level = level;
11 	data->dstream = NULL;
12 	data->cstream = NULL;
13 	return 0;
14 }
15 
zstd_fini(struct zstd_data * data)16 int zstd_fini(struct zstd_data *data)
17 {
18 	if (data->dstream) {
19 		ZSTD_freeDStream(data->dstream);
20 		data->dstream = NULL;
21 	}
22 
23 	if (data->cstream) {
24 		ZSTD_freeCStream(data->cstream);
25 		data->cstream = NULL;
26 	}
27 
28 	return 0;
29 }
30 
zstd_compress_stream_to_records(struct zstd_data * data,void * dst,size_t dst_size,void * src,size_t src_size,size_t max_record_size,size_t process_header (void * record,size_t increment))31 ssize_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
32 				       void *src, size_t src_size, size_t max_record_size,
33 				       size_t process_header(void *record, size_t increment))
34 {
35 	size_t ret, size, compressed = 0;
36 	ZSTD_inBuffer input = { src, src_size, 0 };
37 	ZSTD_outBuffer output;
38 	void *record;
39 
40 	if (!data->cstream) {
41 		data->cstream = ZSTD_createCStream();
42 		if (data->cstream == NULL) {
43 			pr_err("Couldn't create compression stream.\n");
44 			return -1;
45 		}
46 
47 		ret = ZSTD_initCStream(data->cstream, data->comp_level);
48 		if (ZSTD_isError(ret)) {
49 			pr_err("Failed to initialize compression stream: %s\n",
50 				ZSTD_getErrorName(ret));
51 			return -1;
52 		}
53 	}
54 
55 	while (input.pos < input.size) {
56 		record = dst;
57 		size = process_header(record, 0);
58 		compressed += size;
59 		dst += size;
60 		dst_size -= size;
61 		output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
62 						max_record_size : dst_size, 0 };
63 		ret = ZSTD_compressStream(data->cstream, &output, &input);
64 		ZSTD_flushStream(data->cstream, &output);
65 		if (ZSTD_isError(ret)) {
66 			pr_err("failed to compress %ld bytes: %s\n",
67 				(long)src_size, ZSTD_getErrorName(ret));
68 			memcpy(dst, src, src_size);
69 			return src_size;
70 		}
71 		size = output.pos;
72 		size = process_header(record, size);
73 		compressed += size;
74 		dst += size;
75 		dst_size -= size;
76 	}
77 
78 	return compressed;
79 }
80 
zstd_decompress_stream(struct zstd_data * data,void * src,size_t src_size,void * dst,size_t dst_size)81 size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
82 			      void *dst, size_t dst_size)
83 {
84 	size_t ret;
85 	ZSTD_inBuffer input = { src, src_size, 0 };
86 	ZSTD_outBuffer output = { dst, dst_size, 0 };
87 
88 	if (!data->dstream) {
89 		data->dstream = ZSTD_createDStream();
90 		if (data->dstream == NULL) {
91 			pr_err("Couldn't create decompression stream.\n");
92 			return 0;
93 		}
94 
95 		ret = ZSTD_initDStream(data->dstream);
96 		if (ZSTD_isError(ret)) {
97 			pr_err("Failed to initialize decompression stream: %s\n",
98 				ZSTD_getErrorName(ret));
99 			return 0;
100 		}
101 	}
102 	while (input.pos < input.size) {
103 		ret = ZSTD_decompressStream(data->dstream, &output, &input);
104 		if (ZSTD_isError(ret)) {
105 			pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
106 			       src_size, output.size, dst_size, ZSTD_getErrorName(ret));
107 			break;
108 		}
109 		output.dst  = dst + output.pos;
110 		output.size = dst_size - output.pos;
111 	}
112 
113 	return output.pos;
114 }
115