xref: /linux/crypto/zstd.c (revision a619fe35ab41fded440d3762d4fbad84ff86a4d4)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Cryptographic API.
4  *
5  * Copyright (c) 2017-present, Facebook, Inc.
6  */
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/overflow.h>
14 #include <linux/vmalloc.h>
15 #include <linux/zstd.h>
16 #include <crypto/internal/acompress.h>
17 #include <crypto/scatterwalk.h>
18 
19 
/* Compression level handed to zstd_get_params() when a stream is allocated. */
#define ZSTD_DEF_LEVEL		3
/* log2 of ZSTD_MAX_SIZE. */
#define ZSTD_MAX_WINDOWLOG	18
/*
 * Size limit (BIT(18) bytes) used both as the size argument of
 * zstd_get_params() and as the bound passed to the decompression-stream
 * workspace/initialisation helpers.
 */
#define ZSTD_MAX_SIZE		BIT(ZSTD_MAX_WINDOWLOG)
23 
24 struct zstd_ctx {
25 	zstd_cctx *cctx;
26 	zstd_dctx *dctx;
27 	size_t wksp_size;
28 	zstd_parameters params;
29 	u8 wksp[] __aligned(8) __counted_by(wksp_size);
30 };
31 
/* Serialises the crypto_acomp_alloc_streams() call made from zstd_init(). */
static DEFINE_MUTEX(zstd_stream_lock);
33 
zstd_alloc_stream(void)34 static void *zstd_alloc_stream(void)
35 {
36 	zstd_parameters params;
37 	struct zstd_ctx *ctx;
38 	size_t wksp_size;
39 
40 	params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
41 
42 	wksp_size = max(zstd_cstream_workspace_bound(&params.cParams),
43 			zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
44 	if (!wksp_size)
45 		return ERR_PTR(-EINVAL);
46 
47 	ctx = kvmalloc(struct_size(ctx, wksp, wksp_size), GFP_KERNEL);
48 	if (!ctx)
49 		return ERR_PTR(-ENOMEM);
50 
51 	ctx->params = params;
52 	ctx->wksp_size = wksp_size;
53 
54 	return ctx;
55 }
56 
/*
 * Release a stream context (the ->free_ctx hook of zstd_streams).
 * Pairs with the kvmalloc() in zstd_alloc_stream().
 */
static void zstd_free_stream(void *ctx)
{
	kvfree(ctx);
}
61 
62 static struct crypto_acomp_streams zstd_streams = {
63 	.alloc_ctx = zstd_alloc_stream,
64 	.free_ctx = zstd_free_stream,
65 };
66 
zstd_init(struct crypto_acomp * acomp_tfm)67 static int zstd_init(struct crypto_acomp *acomp_tfm)
68 {
69 	int ret = 0;
70 
71 	mutex_lock(&zstd_stream_lock);
72 	ret = crypto_acomp_alloc_streams(&zstd_streams);
73 	mutex_unlock(&zstd_stream_lock);
74 
75 	return ret;
76 }
77 
zstd_compress_one(struct acomp_req * req,struct zstd_ctx * ctx,const void * src,void * dst,unsigned int * dlen)78 static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
79 			     const void *src, void *dst, unsigned int *dlen)
80 {
81 	size_t out_len;
82 
83 	ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
84 	if (!ctx->cctx)
85 		return -EINVAL;
86 
87 	out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
88 				     &ctx->params);
89 	if (zstd_is_error(out_len))
90 		return -EINVAL;
91 
92 	*dlen = out_len;
93 
94 	return 0;
95 }
96 
zstd_compress(struct acomp_req * req)97 static int zstd_compress(struct acomp_req *req)
98 {
99 	struct crypto_acomp_stream *s;
100 	unsigned int pos, scur, dcur;
101 	unsigned int total_out = 0;
102 	bool data_available = true;
103 	zstd_out_buffer outbuf;
104 	struct acomp_walk walk;
105 	zstd_in_buffer inbuf;
106 	struct zstd_ctx *ctx;
107 	size_t pending_bytes;
108 	size_t num_bytes;
109 	int ret;
110 
111 	s = crypto_acomp_lock_stream_bh(&zstd_streams);
112 	ctx = s->ctx;
113 
114 	ret = acomp_walk_virt(&walk, req, true);
115 	if (ret)
116 		goto out;
117 
118 	ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
119 	if (!ctx->cctx) {
120 		ret = -EINVAL;
121 		goto out;
122 	}
123 
124 	do {
125 		dcur = acomp_walk_next_dst(&walk);
126 		if (!dcur) {
127 			ret = -ENOSPC;
128 			goto out;
129 		}
130 
131 		outbuf.pos = 0;
132 		outbuf.dst = (u8 *)walk.dst.virt.addr;
133 		outbuf.size = dcur;
134 
135 		do {
136 			scur = acomp_walk_next_src(&walk);
137 			if (dcur == req->dlen && scur == req->slen) {
138 				ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
139 							walk.dst.virt.addr, &total_out);
140 				acomp_walk_done_src(&walk, scur);
141 				acomp_walk_done_dst(&walk, dcur);
142 				goto out;
143 			}
144 
145 			if (scur) {
146 				inbuf.pos = 0;
147 				inbuf.src = walk.src.virt.addr;
148 				inbuf.size = scur;
149 			} else {
150 				data_available = false;
151 				break;
152 			}
153 
154 			num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
155 			if (ZSTD_isError(num_bytes)) {
156 				ret = -EIO;
157 				goto out;
158 			}
159 
160 			pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
161 			if (ZSTD_isError(pending_bytes)) {
162 				ret = -EIO;
163 				goto out;
164 			}
165 			acomp_walk_done_src(&walk, inbuf.pos);
166 		} while (dcur != outbuf.pos);
167 
168 		total_out += outbuf.pos;
169 		acomp_walk_done_dst(&walk, dcur);
170 	} while (data_available);
171 
172 	pos = outbuf.pos;
173 	num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
174 	if (ZSTD_isError(num_bytes))
175 		ret = -EIO;
176 	else
177 		total_out += (outbuf.pos - pos);
178 
179 out:
180 	if (ret)
181 		req->dlen = 0;
182 	else
183 		req->dlen = total_out;
184 
185 	crypto_acomp_unlock_stream_bh(s);
186 
187 	return ret;
188 }
189 
zstd_decompress_one(struct acomp_req * req,struct zstd_ctx * ctx,const void * src,void * dst,unsigned int * dlen)190 static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
191 			       const void *src, void *dst, unsigned int *dlen)
192 {
193 	size_t out_len;
194 
195 	ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
196 	if (!ctx->dctx)
197 		return -EINVAL;
198 
199 	out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
200 	if (zstd_is_error(out_len))
201 		return -EINVAL;
202 
203 	*dlen = out_len;
204 
205 	return 0;
206 }
207 
zstd_decompress(struct acomp_req * req)208 static int zstd_decompress(struct acomp_req *req)
209 {
210 	struct crypto_acomp_stream *s;
211 	unsigned int total_out = 0;
212 	unsigned int scur, dcur;
213 	zstd_out_buffer outbuf;
214 	struct acomp_walk walk;
215 	zstd_in_buffer inbuf;
216 	struct zstd_ctx *ctx;
217 	size_t pending_bytes;
218 	int ret;
219 
220 	s = crypto_acomp_lock_stream_bh(&zstd_streams);
221 	ctx = s->ctx;
222 
223 	ret = acomp_walk_virt(&walk, req, true);
224 	if (ret)
225 		goto out;
226 
227 	ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
228 	if (!ctx->dctx) {
229 		ret = -EINVAL;
230 		goto out;
231 	}
232 
233 	do {
234 		scur = acomp_walk_next_src(&walk);
235 		if (scur) {
236 			inbuf.pos = 0;
237 			inbuf.size = scur;
238 			inbuf.src = walk.src.virt.addr;
239 		} else {
240 			break;
241 		}
242 
243 		do {
244 			dcur = acomp_walk_next_dst(&walk);
245 			if (dcur == req->dlen && scur == req->slen) {
246 				ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
247 							  walk.dst.virt.addr, &total_out);
248 				acomp_walk_done_dst(&walk, dcur);
249 				acomp_walk_done_src(&walk, scur);
250 				goto out;
251 			}
252 
253 			if (!dcur) {
254 				ret = -ENOSPC;
255 				goto out;
256 			}
257 
258 			outbuf.pos = 0;
259 			outbuf.dst = (u8 *)walk.dst.virt.addr;
260 			outbuf.size = dcur;
261 
262 			pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
263 			if (ZSTD_isError(pending_bytes)) {
264 				ret = -EIO;
265 				goto out;
266 			}
267 
268 			total_out += outbuf.pos;
269 
270 			acomp_walk_done_dst(&walk, outbuf.pos);
271 		} while (inbuf.pos != scur);
272 
273 		acomp_walk_done_src(&walk, scur);
274 	} while (ret == 0);
275 
276 out:
277 	if (ret)
278 		req->dlen = 0;
279 	else
280 		req->dlen = total_out;
281 
282 	crypto_acomp_unlock_stream_bh(s);
283 
284 	return ret;
285 }
286 
287 static struct acomp_alg zstd_acomp = {
288 	.base = {
289 		.cra_name = "zstd",
290 		.cra_driver_name = "zstd-generic",
291 		.cra_flags = CRYPTO_ALG_REQ_VIRT,
292 		.cra_module = THIS_MODULE,
293 	},
294 	.init = zstd_init,
295 	.compress = zstd_compress,
296 	.decompress = zstd_decompress,
297 };
298 
zstd_mod_init(void)299 static int __init zstd_mod_init(void)
300 {
301 	return crypto_register_acomp(&zstd_acomp);
302 }
303 
zstd_mod_fini(void)304 static void __exit zstd_mod_fini(void)
305 {
306 	crypto_unregister_acomp(&zstd_acomp);
307 	crypto_acomp_free_streams(&zstd_streams);
308 }
309 
310 module_init(zstd_mod_init);
311 module_exit(zstd_mod_fini);
312 
313 MODULE_LICENSE("GPL");
314 MODULE_DESCRIPTION("Zstd Compression Algorithm");
315 MODULE_ALIAS_CRYPTO("zstd");
316