xref: /linux/crypto/zstd.c (revision 962fad301c33dec69324dc2d9320fd84a119a24c)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Cryptographic API.
4  *
5  * Copyright (c) 2017-present, Facebook, Inc.
6  */
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
10 #include <linux/mm.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/vmalloc.h>
14 #include <linux/zstd.h>
15 #include <crypto/internal/scompress.h>
16 
17 
18 #define ZSTD_DEF_LEVEL	3
19 
20 struct zstd_ctx {
21 	ZSTD_CCtx *cctx;
22 	ZSTD_DCtx *dctx;
23 	void *cwksp;
24 	void *dwksp;
25 };
26 
27 static ZSTD_parameters zstd_params(void)
28 {
29 	return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
30 }
31 
32 static int zstd_comp_init(struct zstd_ctx *ctx)
33 {
34 	int ret = 0;
35 	const ZSTD_parameters params = zstd_params();
36 	const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
37 
38 	ctx->cwksp = vzalloc(wksp_size);
39 	if (!ctx->cwksp) {
40 		ret = -ENOMEM;
41 		goto out;
42 	}
43 
44 	ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
45 	if (!ctx->cctx) {
46 		ret = -EINVAL;
47 		goto out_free;
48 	}
49 out:
50 	return ret;
51 out_free:
52 	vfree(ctx->cwksp);
53 	goto out;
54 }
55 
56 static int zstd_decomp_init(struct zstd_ctx *ctx)
57 {
58 	int ret = 0;
59 	const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
60 
61 	ctx->dwksp = vzalloc(wksp_size);
62 	if (!ctx->dwksp) {
63 		ret = -ENOMEM;
64 		goto out;
65 	}
66 
67 	ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
68 	if (!ctx->dctx) {
69 		ret = -EINVAL;
70 		goto out_free;
71 	}
72 out:
73 	return ret;
74 out_free:
75 	vfree(ctx->dwksp);
76 	goto out;
77 }
78 
79 static void zstd_comp_exit(struct zstd_ctx *ctx)
80 {
81 	vfree(ctx->cwksp);
82 	ctx->cwksp = NULL;
83 	ctx->cctx = NULL;
84 }
85 
86 static void zstd_decomp_exit(struct zstd_ctx *ctx)
87 {
88 	vfree(ctx->dwksp);
89 	ctx->dwksp = NULL;
90 	ctx->dctx = NULL;
91 }
92 
93 static int __zstd_init(void *ctx)
94 {
95 	int ret;
96 
97 	ret = zstd_comp_init(ctx);
98 	if (ret)
99 		return ret;
100 	ret = zstd_decomp_init(ctx);
101 	if (ret)
102 		zstd_comp_exit(ctx);
103 	return ret;
104 }
105 
106 static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
107 {
108 	int ret;
109 	struct zstd_ctx *ctx;
110 
111 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
112 	if (!ctx)
113 		return ERR_PTR(-ENOMEM);
114 
115 	ret = __zstd_init(ctx);
116 	if (ret) {
117 		kfree(ctx);
118 		return ERR_PTR(ret);
119 	}
120 
121 	return ctx;
122 }
123 
124 static int zstd_init(struct crypto_tfm *tfm)
125 {
126 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
127 
128 	return __zstd_init(ctx);
129 }
130 
131 static void __zstd_exit(void *ctx)
132 {
133 	zstd_comp_exit(ctx);
134 	zstd_decomp_exit(ctx);
135 }
136 
137 static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
138 {
139 	__zstd_exit(ctx);
140 	kfree_sensitive(ctx);
141 }
142 
143 static void zstd_exit(struct crypto_tfm *tfm)
144 {
145 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
146 
147 	__zstd_exit(ctx);
148 }
149 
150 static int __zstd_compress(const u8 *src, unsigned int slen,
151 			   u8 *dst, unsigned int *dlen, void *ctx)
152 {
153 	size_t out_len;
154 	struct zstd_ctx *zctx = ctx;
155 	const ZSTD_parameters params = zstd_params();
156 
157 	out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params);
158 	if (ZSTD_isError(out_len))
159 		return -EINVAL;
160 	*dlen = out_len;
161 	return 0;
162 }
163 
164 static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
165 			 unsigned int slen, u8 *dst, unsigned int *dlen)
166 {
167 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
168 
169 	return __zstd_compress(src, slen, dst, dlen, ctx);
170 }
171 
172 static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
173 			  unsigned int slen, u8 *dst, unsigned int *dlen,
174 			  void *ctx)
175 {
176 	return __zstd_compress(src, slen, dst, dlen, ctx);
177 }
178 
179 static int __zstd_decompress(const u8 *src, unsigned int slen,
180 			     u8 *dst, unsigned int *dlen, void *ctx)
181 {
182 	size_t out_len;
183 	struct zstd_ctx *zctx = ctx;
184 
185 	out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen);
186 	if (ZSTD_isError(out_len))
187 		return -EINVAL;
188 	*dlen = out_len;
189 	return 0;
190 }
191 
192 static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
193 			   unsigned int slen, u8 *dst, unsigned int *dlen)
194 {
195 	struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
196 
197 	return __zstd_decompress(src, slen, dst, dlen, ctx);
198 }
199 
200 static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
201 			    unsigned int slen, u8 *dst, unsigned int *dlen,
202 			    void *ctx)
203 {
204 	return __zstd_decompress(src, slen, dst, dlen, ctx);
205 }
206 
207 static struct crypto_alg alg = {
208 	.cra_name		= "zstd",
209 	.cra_driver_name	= "zstd-generic",
210 	.cra_flags		= CRYPTO_ALG_TYPE_COMPRESS,
211 	.cra_ctxsize		= sizeof(struct zstd_ctx),
212 	.cra_module		= THIS_MODULE,
213 	.cra_init		= zstd_init,
214 	.cra_exit		= zstd_exit,
215 	.cra_u			= { .compress = {
216 	.coa_compress		= zstd_compress,
217 	.coa_decompress		= zstd_decompress } }
218 };
219 
220 static struct scomp_alg scomp = {
221 	.alloc_ctx		= zstd_alloc_ctx,
222 	.free_ctx		= zstd_free_ctx,
223 	.compress		= zstd_scompress,
224 	.decompress		= zstd_sdecompress,
225 	.base			= {
226 		.cra_name	= "zstd",
227 		.cra_driver_name = "zstd-scomp",
228 		.cra_module	 = THIS_MODULE,
229 	}
230 };
231 
232 static int __init zstd_mod_init(void)
233 {
234 	int ret;
235 
236 	ret = crypto_register_alg(&alg);
237 	if (ret)
238 		return ret;
239 
240 	ret = crypto_register_scomp(&scomp);
241 	if (ret)
242 		crypto_unregister_alg(&alg);
243 
244 	return ret;
245 }
246 
247 static void __exit zstd_mod_fini(void)
248 {
249 	crypto_unregister_alg(&alg);
250 	crypto_unregister_scomp(&scomp);
251 }
252 
253 subsys_initcall(zstd_mod_init);
254 module_exit(zstd_mod_fini);
255 
256 MODULE_LICENSE("GPL");
257 MODULE_DESCRIPTION("Zstd Compression Algorithm");
258 MODULE_ALIAS_CRYPTO("zstd");
259