xref: /linux/drivers/block/zram/backend_zstd.c (revision f2bac7ad187d77e2a053d3cd04b158a06b683d26)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 
3 #include <linux/kernel.h>
4 #include <linux/slab.h>
5 #include <linux/vmalloc.h>
6 #include <linux/zstd.h>
7 
8 #include "backend_zstd.h"
9 
10 struct zstd_ctx {
11 	zstd_cctx *cctx;
12 	zstd_dctx *dctx;
13 	void *cctx_mem;
14 	void *dctx_mem;
15 	s32 level;
16 };
17 
18 static void zstd_destroy(void *ctx)
19 {
20 	struct zstd_ctx *zctx = ctx;
21 
22 	vfree(zctx->cctx_mem);
23 	vfree(zctx->dctx_mem);
24 	kfree(zctx);
25 }
26 
27 static void *zstd_create(struct zcomp_params *params)
28 {
29 	zstd_parameters prm;
30 	struct zstd_ctx *ctx;
31 	size_t sz;
32 
33 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
34 	if (!ctx)
35 		return NULL;
36 
37 	if (params->level != ZCOMP_PARAM_NO_LEVEL)
38 		ctx->level = params->level;
39 	else
40 		ctx->level = zstd_default_clevel();
41 
42 	prm = zstd_get_params(ctx->level, PAGE_SIZE);
43 	sz = zstd_cctx_workspace_bound(&prm.cParams);
44 	ctx->cctx_mem = vzalloc(sz);
45 	if (!ctx->cctx_mem)
46 		goto error;
47 
48 	ctx->cctx = zstd_init_cctx(ctx->cctx_mem, sz);
49 	if (!ctx->cctx)
50 		goto error;
51 
52 	sz = zstd_dctx_workspace_bound();
53 	ctx->dctx_mem = vzalloc(sz);
54 	if (!ctx->dctx_mem)
55 		goto error;
56 
57 	ctx->dctx = zstd_init_dctx(ctx->dctx_mem, sz);
58 	if (!ctx->dctx)
59 		goto error;
60 
61 	return ctx;
62 
63 error:
64 	zstd_destroy(ctx);
65 	return NULL;
66 }
67 
68 static int zstd_compress(void *ctx, const unsigned char *src, size_t src_len,
69 			 unsigned char *dst, size_t *dst_len)
70 {
71 	struct zstd_ctx *zctx = ctx;
72 	const zstd_parameters prm = zstd_get_params(zctx->level, PAGE_SIZE);
73 	size_t ret;
74 
75 	ret = zstd_compress_cctx(zctx->cctx, dst, *dst_len,
76 				 src, src_len, &prm);
77 	if (zstd_is_error(ret))
78 		return -EINVAL;
79 	*dst_len = ret;
80 	return 0;
81 }
82 
83 static int zstd_decompress(void *ctx, const unsigned char *src, size_t src_len,
84 			   unsigned char *dst, size_t dst_len)
85 {
86 	struct zstd_ctx *zctx = ctx;
87 	size_t ret;
88 
89 	ret = zstd_decompress_dctx(zctx->dctx, dst, dst_len, src, src_len);
90 	if (zstd_is_error(ret))
91 		return -EINVAL;
92 	return 0;
93 }
94 
95 const struct zcomp_ops backend_zstd = {
96 	.compress	= zstd_compress,
97 	.decompress	= zstd_decompress,
98 	.create_ctx	= zstd_create,
99 	.destroy_ctx	= zstd_destroy,
100 	.name		= "zstd",
101 };
102