1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Cryptographic API.
4 *
5 * Copyright (c) 2017-present, Facebook, Inc.
6 */
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/overflow.h>
14 #include <linux/vmalloc.h>
15 #include <linux/zstd.h>
16 #include <crypto/internal/acompress.h>
17 #include <crypto/scatterwalk.h>
18
19
/* Compression level used to derive the parameters for every stream context. */
#define ZSTD_DEF_LEVEL 3
/*
 * log2 of ZSTD_MAX_SIZE.  ZSTD_MAX_SIZE (256 KiB) is passed as the
 * source-size hint to zstd_get_params() and as the maximum window size to
 * zstd_init_dstream() / zstd_dstream_workspace_bound().
 */
#define ZSTD_MAX_WINDOWLOG 18
#define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG)
23
/*
 * Per-stream state handed out by zstd_streams.  Both the compression and
 * the decompression context are (re)built inside the same wksp[] on every
 * request, so only the most recently initialised one is valid at a time.
 */
struct zstd_ctx {
	zstd_cctx *cctx;		/* compression context living in wksp[] */
	zstd_dctx *dctx;		/* decompression context living in wksp[] */
	size_t wksp_size;		/* size of wksp[] in bytes */
	zstd_parameters params;		/* parameters computed at allocation time */
	u8 wksp[] __aligned(8) __counted_by(wksp_size);
};
31
/* Serialises the crypto_acomp_alloc_streams() calls made from zstd_init(). */
static DEFINE_MUTEX(zstd_stream_lock);
33
zstd_alloc_stream(void)34 static void *zstd_alloc_stream(void)
35 {
36 zstd_parameters params;
37 struct zstd_ctx *ctx;
38 size_t wksp_size;
39
40 params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
41
42 wksp_size = max(zstd_cstream_workspace_bound(¶ms.cParams),
43 zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
44 if (!wksp_size)
45 return ERR_PTR(-EINVAL);
46
47 ctx = kvmalloc(struct_size(ctx, wksp, wksp_size), GFP_KERNEL);
48 if (!ctx)
49 return ERR_PTR(-ENOMEM);
50
51 ctx->params = params;
52 ctx->wksp_size = wksp_size;
53
54 return ctx;
55 }
56
/* .free_ctx callback for zstd_streams: release a zstd_alloc_stream() context. */
static void zstd_free_stream(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	kvfree(zctx);
}
61
62 static struct crypto_acomp_streams zstd_streams = {
63 .alloc_ctx = zstd_alloc_stream,
64 .free_ctx = zstd_free_stream,
65 };
66
zstd_init(struct crypto_acomp * acomp_tfm)67 static int zstd_init(struct crypto_acomp *acomp_tfm)
68 {
69 int ret = 0;
70
71 mutex_lock(&zstd_stream_lock);
72 ret = crypto_acomp_alloc_streams(&zstd_streams);
73 mutex_unlock(&zstd_stream_lock);
74
75 return ret;
76 }
77
zstd_compress_one(struct acomp_req * req,struct zstd_ctx * ctx,const void * src,void * dst,unsigned int * dlen)78 static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
79 const void *src, void *dst, unsigned int *dlen)
80 {
81 size_t out_len;
82
83 ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
84 if (!ctx->cctx)
85 return -EINVAL;
86
87 out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
88 &ctx->params);
89 if (zstd_is_error(out_len))
90 return -EINVAL;
91
92 *dlen = out_len;
93
94 return 0;
95 }
96
zstd_compress(struct acomp_req * req)97 static int zstd_compress(struct acomp_req *req)
98 {
99 struct crypto_acomp_stream *s;
100 unsigned int pos, scur, dcur;
101 unsigned int total_out = 0;
102 bool data_available = true;
103 zstd_out_buffer outbuf;
104 struct acomp_walk walk;
105 zstd_in_buffer inbuf;
106 struct zstd_ctx *ctx;
107 size_t pending_bytes;
108 size_t num_bytes;
109 int ret;
110
111 s = crypto_acomp_lock_stream_bh(&zstd_streams);
112 ctx = s->ctx;
113
114 ret = acomp_walk_virt(&walk, req, true);
115 if (ret)
116 goto out;
117
118 ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
119 if (!ctx->cctx) {
120 ret = -EINVAL;
121 goto out;
122 }
123
124 do {
125 dcur = acomp_walk_next_dst(&walk);
126 if (!dcur) {
127 ret = -ENOSPC;
128 goto out;
129 }
130
131 outbuf.pos = 0;
132 outbuf.dst = (u8 *)walk.dst.virt.addr;
133 outbuf.size = dcur;
134
135 do {
136 scur = acomp_walk_next_src(&walk);
137 if (dcur == req->dlen && scur == req->slen) {
138 ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
139 walk.dst.virt.addr, &total_out);
140 acomp_walk_done_src(&walk, scur);
141 acomp_walk_done_dst(&walk, dcur);
142 goto out;
143 }
144
145 if (scur) {
146 inbuf.pos = 0;
147 inbuf.src = walk.src.virt.addr;
148 inbuf.size = scur;
149 } else {
150 data_available = false;
151 break;
152 }
153
154 num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
155 if (ZSTD_isError(num_bytes)) {
156 ret = -EIO;
157 goto out;
158 }
159
160 pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
161 if (ZSTD_isError(pending_bytes)) {
162 ret = -EIO;
163 goto out;
164 }
165 acomp_walk_done_src(&walk, inbuf.pos);
166 } while (dcur != outbuf.pos);
167
168 total_out += outbuf.pos;
169 acomp_walk_done_dst(&walk, dcur);
170 } while (data_available);
171
172 pos = outbuf.pos;
173 num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
174 if (ZSTD_isError(num_bytes))
175 ret = -EIO;
176 else
177 total_out += (outbuf.pos - pos);
178
179 out:
180 if (ret)
181 req->dlen = 0;
182 else
183 req->dlen = total_out;
184
185 crypto_acomp_unlock_stream_bh(s);
186
187 return ret;
188 }
189
zstd_decompress_one(struct acomp_req * req,struct zstd_ctx * ctx,const void * src,void * dst,unsigned int * dlen)190 static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
191 const void *src, void *dst, unsigned int *dlen)
192 {
193 size_t out_len;
194
195 ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
196 if (!ctx->dctx)
197 return -EINVAL;
198
199 out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
200 if (zstd_is_error(out_len))
201 return -EINVAL;
202
203 *dlen = out_len;
204
205 return 0;
206 }
207
zstd_decompress(struct acomp_req * req)208 static int zstd_decompress(struct acomp_req *req)
209 {
210 struct crypto_acomp_stream *s;
211 unsigned int total_out = 0;
212 unsigned int scur, dcur;
213 zstd_out_buffer outbuf;
214 struct acomp_walk walk;
215 zstd_in_buffer inbuf;
216 struct zstd_ctx *ctx;
217 size_t pending_bytes;
218 int ret;
219
220 s = crypto_acomp_lock_stream_bh(&zstd_streams);
221 ctx = s->ctx;
222
223 ret = acomp_walk_virt(&walk, req, true);
224 if (ret)
225 goto out;
226
227 ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
228 if (!ctx->dctx) {
229 ret = -EINVAL;
230 goto out;
231 }
232
233 do {
234 scur = acomp_walk_next_src(&walk);
235 if (scur) {
236 inbuf.pos = 0;
237 inbuf.size = scur;
238 inbuf.src = walk.src.virt.addr;
239 } else {
240 break;
241 }
242
243 do {
244 dcur = acomp_walk_next_dst(&walk);
245 if (dcur == req->dlen && scur == req->slen) {
246 ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
247 walk.dst.virt.addr, &total_out);
248 acomp_walk_done_dst(&walk, dcur);
249 acomp_walk_done_src(&walk, scur);
250 goto out;
251 }
252
253 if (!dcur) {
254 ret = -ENOSPC;
255 goto out;
256 }
257
258 outbuf.pos = 0;
259 outbuf.dst = (u8 *)walk.dst.virt.addr;
260 outbuf.size = dcur;
261
262 pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
263 if (ZSTD_isError(pending_bytes)) {
264 ret = -EIO;
265 goto out;
266 }
267
268 total_out += outbuf.pos;
269
270 acomp_walk_done_dst(&walk, outbuf.pos);
271 } while (inbuf.pos != scur);
272
273 acomp_walk_done_src(&walk, scur);
274 } while (ret == 0);
275
276 out:
277 if (ret)
278 req->dlen = 0;
279 else
280 req->dlen = total_out;
281
282 crypto_acomp_unlock_stream_bh(s);
283
284 return ret;
285 }
286
287 static struct acomp_alg zstd_acomp = {
288 .base = {
289 .cra_name = "zstd",
290 .cra_driver_name = "zstd-generic",
291 .cra_flags = CRYPTO_ALG_REQ_VIRT,
292 .cra_module = THIS_MODULE,
293 },
294 .init = zstd_init,
295 .compress = zstd_compress,
296 .decompress = zstd_decompress,
297 };
298
zstd_mod_init(void)299 static int __init zstd_mod_init(void)
300 {
301 return crypto_register_acomp(&zstd_acomp);
302 }
303
/*
 * Module exit: unregister the algorithm first, then release the shared
 * zstd_streams contexts that zstd_init() allocated.
 */
static void __exit zstd_mod_fini(void)
{
	crypto_unregister_acomp(&zstd_acomp);
	crypto_acomp_free_streams(&zstd_streams);
}
309
/* Module entry/exit points and metadata. */
module_init(zstd_mod_init);
module_exit(zstd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Zstd Compression Algorithm");
MODULE_ALIAS_CRYPTO("zstd");
316