xref: /linux/crypto/zstd.c (revision 22c55fb9eb92395d999b8404d73e58540d11bdd8)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Cryptographic API.
4  *
5  * Copyright (c) 2017-present, Facebook, Inc.
6  */
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/vmalloc.h>
14 #include <linux/zstd.h>
15 #include <crypto/internal/acompress.h>
16 #include <crypto/scatterwalk.h>
17 
18 
19 #define ZSTD_DEF_LEVEL		3
20 #define ZSTD_MAX_WINDOWLOG	18
21 #define ZSTD_MAX_SIZE		BIT(ZSTD_MAX_WINDOWLOG)
22 
23 struct zstd_ctx {
24 	zstd_cctx *cctx;
25 	zstd_dctx *dctx;
26 	size_t wksp_size;
27 	zstd_parameters params;
28 	u8 wksp[] __aligned(8);
29 };
30 
31 static DEFINE_MUTEX(zstd_stream_lock);
32 
33 static void *zstd_alloc_stream(void)
34 {
35 	zstd_parameters params;
36 	struct zstd_ctx *ctx;
37 	size_t wksp_size;
38 
39 	params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
40 
41 	wksp_size = max_t(size_t,
42 			  zstd_cstream_workspace_bound(&params.cParams),
43 			  zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
44 	if (!wksp_size)
45 		return ERR_PTR(-EINVAL);
46 
47 	ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL);
48 	if (!ctx)
49 		return ERR_PTR(-ENOMEM);
50 
51 	ctx->params = params;
52 	ctx->wksp_size = wksp_size;
53 
54 	return ctx;
55 }
56 
/* Release a context previously returned by zstd_alloc_stream(). */
static void zstd_free_stream(void *stream)
{
	struct zstd_ctx *ctx = stream;

	kvfree(ctx);
}
61 
62 static struct crypto_acomp_streams zstd_streams = {
63 	.alloc_ctx = zstd_alloc_stream,
64 	.free_ctx = zstd_free_stream,
65 };
66 
67 static int zstd_init(struct crypto_acomp *acomp_tfm)
68 {
69 	int ret = 0;
70 
71 	mutex_lock(&zstd_stream_lock);
72 	ret = crypto_acomp_alloc_streams(&zstd_streams);
73 	mutex_unlock(&zstd_stream_lock);
74 
75 	return ret;
76 }
77 
78 static void zstd_exit(struct crypto_acomp *acomp_tfm)
79 {
80 	crypto_acomp_free_streams(&zstd_streams);
81 }
82 
83 static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
84 			     const void *src, void *dst, unsigned int *dlen)
85 {
86 	unsigned int out_len;
87 
88 	ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
89 	if (!ctx->cctx)
90 		return -EINVAL;
91 
92 	out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
93 				     &ctx->params);
94 	if (zstd_is_error(out_len))
95 		return -EINVAL;
96 
97 	*dlen = out_len;
98 
99 	return 0;
100 }
101 
102 static int zstd_compress(struct acomp_req *req)
103 {
104 	struct crypto_acomp_stream *s;
105 	unsigned int pos, scur, dcur;
106 	unsigned int total_out = 0;
107 	bool data_available = true;
108 	zstd_out_buffer outbuf;
109 	struct acomp_walk walk;
110 	zstd_in_buffer inbuf;
111 	struct zstd_ctx *ctx;
112 	size_t pending_bytes;
113 	size_t num_bytes;
114 	int ret;
115 
116 	s = crypto_acomp_lock_stream_bh(&zstd_streams);
117 	ctx = s->ctx;
118 
119 	ret = acomp_walk_virt(&walk, req, true);
120 	if (ret)
121 		goto out;
122 
123 	ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
124 	if (!ctx->cctx) {
125 		ret = -EINVAL;
126 		goto out;
127 	}
128 
129 	do {
130 		dcur = acomp_walk_next_dst(&walk);
131 		if (!dcur) {
132 			ret = -ENOSPC;
133 			goto out;
134 		}
135 
136 		outbuf.pos = 0;
137 		outbuf.dst = (u8 *)walk.dst.virt.addr;
138 		outbuf.size = dcur;
139 
140 		do {
141 			scur = acomp_walk_next_src(&walk);
142 			if (dcur == req->dlen && scur == req->slen) {
143 				ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
144 							walk.dst.virt.addr, &total_out);
145 				acomp_walk_done_src(&walk, scur);
146 				acomp_walk_done_dst(&walk, dcur);
147 				goto out;
148 			}
149 
150 			if (scur) {
151 				inbuf.pos = 0;
152 				inbuf.src = walk.src.virt.addr;
153 				inbuf.size = scur;
154 			} else {
155 				data_available = false;
156 				break;
157 			}
158 
159 			num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
160 			if (ZSTD_isError(num_bytes)) {
161 				ret = -EIO;
162 				goto out;
163 			}
164 
165 			pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
166 			if (ZSTD_isError(pending_bytes)) {
167 				ret = -EIO;
168 				goto out;
169 			}
170 			acomp_walk_done_src(&walk, inbuf.pos);
171 		} while (dcur != outbuf.pos);
172 
173 		total_out += outbuf.pos;
174 		acomp_walk_done_dst(&walk, dcur);
175 	} while (data_available);
176 
177 	pos = outbuf.pos;
178 	num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
179 	if (ZSTD_isError(num_bytes))
180 		ret = -EIO;
181 	else
182 		total_out += (outbuf.pos - pos);
183 
184 out:
185 	if (ret)
186 		req->dlen = 0;
187 	else
188 		req->dlen = total_out;
189 
190 	crypto_acomp_unlock_stream_bh(s);
191 
192 	return ret;
193 }
194 
195 static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
196 			       const void *src, void *dst, unsigned int *dlen)
197 {
198 	size_t out_len;
199 
200 	ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
201 	if (!ctx->dctx)
202 		return -EINVAL;
203 
204 	out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
205 	if (zstd_is_error(out_len))
206 		return -EINVAL;
207 
208 	*dlen = out_len;
209 
210 	return 0;
211 }
212 
213 static int zstd_decompress(struct acomp_req *req)
214 {
215 	struct crypto_acomp_stream *s;
216 	unsigned int total_out = 0;
217 	unsigned int scur, dcur;
218 	zstd_out_buffer outbuf;
219 	struct acomp_walk walk;
220 	zstd_in_buffer inbuf;
221 	struct zstd_ctx *ctx;
222 	size_t pending_bytes;
223 	int ret;
224 
225 	s = crypto_acomp_lock_stream_bh(&zstd_streams);
226 	ctx = s->ctx;
227 
228 	ret = acomp_walk_virt(&walk, req, true);
229 	if (ret)
230 		goto out;
231 
232 	ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
233 	if (!ctx->dctx) {
234 		ret = -EINVAL;
235 		goto out;
236 	}
237 
238 	do {
239 		scur = acomp_walk_next_src(&walk);
240 		if (scur) {
241 			inbuf.pos = 0;
242 			inbuf.size = scur;
243 			inbuf.src = walk.src.virt.addr;
244 		} else {
245 			break;
246 		}
247 
248 		do {
249 			dcur = acomp_walk_next_dst(&walk);
250 			if (dcur == req->dlen && scur == req->slen) {
251 				ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
252 							  walk.dst.virt.addr, &total_out);
253 				acomp_walk_done_dst(&walk, dcur);
254 				acomp_walk_done_src(&walk, scur);
255 				goto out;
256 			}
257 
258 			if (!dcur) {
259 				ret = -ENOSPC;
260 				goto out;
261 			}
262 
263 			outbuf.pos = 0;
264 			outbuf.dst = (u8 *)walk.dst.virt.addr;
265 			outbuf.size = dcur;
266 
267 			pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
268 			if (ZSTD_isError(pending_bytes)) {
269 				ret = -EIO;
270 				goto out;
271 			}
272 
273 			total_out += outbuf.pos;
274 
275 			acomp_walk_done_dst(&walk, outbuf.pos);
276 		} while (inbuf.pos != scur);
277 
278 		acomp_walk_done_src(&walk, scur);
279 	} while (ret == 0);
280 
281 out:
282 	if (ret)
283 		req->dlen = 0;
284 	else
285 		req->dlen = total_out;
286 
287 	crypto_acomp_unlock_stream_bh(s);
288 
289 	return ret;
290 }
291 
292 static struct acomp_alg zstd_acomp = {
293 	.base = {
294 		.cra_name = "zstd",
295 		.cra_driver_name = "zstd-generic",
296 		.cra_flags = CRYPTO_ALG_REQ_VIRT,
297 		.cra_module = THIS_MODULE,
298 	},
299 	.init = zstd_init,
300 	.exit = zstd_exit,
301 	.compress = zstd_compress,
302 	.decompress = zstd_decompress,
303 };
304 
305 static int __init zstd_mod_init(void)
306 {
307 	return crypto_register_acomp(&zstd_acomp);
308 }
309 
310 static void __exit zstd_mod_fini(void)
311 {
312 	crypto_unregister_acomp(&zstd_acomp);
313 }
314 
315 module_init(zstd_mod_init);
316 module_exit(zstd_mod_fini);
317 
318 MODULE_LICENSE("GPL");
319 MODULE_DESCRIPTION("Zstd Compression Algorithm");
320 MODULE_ALIAS_CRYPTO("zstd");
321