1 /* 2 * Cryptographic API. 3 * 4 * Copyright (c) 2017-present, Facebook, Inc. 5 * 6 * This program is free software; you can redistribute it and/or modify it 7 * under the terms of the GNU General Public License version 2 as published by 8 * the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, but WITHOUT 11 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 12 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for 13 * more details. 14 */ 15 #include <linux/crypto.h> 16 #include <linux/init.h> 17 #include <linux/interrupt.h> 18 #include <linux/mm.h> 19 #include <linux/module.h> 20 #include <linux/net.h> 21 #include <linux/vmalloc.h> 22 #include <linux/zstd.h> 23 #include <crypto/internal/scompress.h> 24 25 26 #define ZSTD_DEF_LEVEL 3 27 28 struct zstd_ctx { 29 ZSTD_CCtx *cctx; 30 ZSTD_DCtx *dctx; 31 void *cwksp; 32 void *dwksp; 33 }; 34 35 static ZSTD_parameters zstd_params(void) 36 { 37 return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0); 38 } 39 40 static int zstd_comp_init(struct zstd_ctx *ctx) 41 { 42 int ret = 0; 43 const ZSTD_parameters params = zstd_params(); 44 const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams); 45 46 ctx->cwksp = vzalloc(wksp_size); 47 if (!ctx->cwksp) { 48 ret = -ENOMEM; 49 goto out; 50 } 51 52 ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size); 53 if (!ctx->cctx) { 54 ret = -EINVAL; 55 goto out_free; 56 } 57 out: 58 return ret; 59 out_free: 60 vfree(ctx->cwksp); 61 goto out; 62 } 63 64 static int zstd_decomp_init(struct zstd_ctx *ctx) 65 { 66 int ret = 0; 67 const size_t wksp_size = ZSTD_DCtxWorkspaceBound(); 68 69 ctx->dwksp = vzalloc(wksp_size); 70 if (!ctx->dwksp) { 71 ret = -ENOMEM; 72 goto out; 73 } 74 75 ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size); 76 if (!ctx->dctx) { 77 ret = -EINVAL; 78 goto out_free; 79 } 80 out: 81 return ret; 82 out_free: 83 vfree(ctx->dwksp); 84 goto out; 85 } 86 87 static void zstd_comp_exit(struct zstd_ctx *ctx) 88 { 89 vfree(ctx->cwksp); 90 ctx->cwksp = NULL; 91 ctx->cctx = NULL; 92 } 93 94 static void zstd_decomp_exit(struct zstd_ctx *ctx) 95 { 96 vfree(ctx->dwksp); 97 ctx->dwksp = NULL; 98 ctx->dctx = NULL; 99 } 100 101 static int __zstd_init(void *ctx) 102 { 103 int ret; 104 105 ret = zstd_comp_init(ctx); 106 if (ret) 107 return ret; 108 ret = zstd_decomp_init(ctx); 109 if (ret) 110 zstd_comp_exit(ctx); 111 return ret; 112 } 113 114 static void *zstd_alloc_ctx(struct crypto_scomp *tfm) 115 { 116 int ret; 117 struct zstd_ctx *ctx; 118 119 ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); 120 if (!ctx) 121 return ERR_PTR(-ENOMEM); 122 123 ret = __zstd_init(ctx); 124 if (ret) { 125 kfree(ctx); 126 return ERR_PTR(ret); 127 } 128 129 return ctx; 130 } 131 132 static int zstd_init(struct crypto_tfm *tfm) 133 { 134 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); 135 136 return __zstd_init(ctx); 137 } 138 139 static void __zstd_exit(void *ctx) 140 { 141 zstd_comp_exit(ctx); 142 zstd_decomp_exit(ctx); 143 } 144 145 static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx) 146 { 147 __zstd_exit(ctx); 148 kzfree(ctx); 149 } 150 151 static void zstd_exit(struct crypto_tfm *tfm) 152 { 153 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); 154 155 __zstd_exit(ctx); 156 } 157 158 static int __zstd_compress(const u8 *src, unsigned int slen, 159 u8 *dst, unsigned int *dlen, void *ctx) 160 { 161 size_t out_len; 162 struct zstd_ctx *zctx = ctx; 163 const ZSTD_parameters params = zstd_params(); 164 165 out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params); 166 if (ZSTD_isError(out_len)) 167 return -EINVAL; 168 *dlen = out_len; 169 return 0; 170 } 171 172 static int zstd_compress(struct crypto_tfm *tfm, const u8 *src, 173 unsigned int slen, u8 *dst, unsigned int *dlen) 174 { 175 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); 176 177 return __zstd_compress(src, slen, dst, dlen, ctx); 178 } 179 180 static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src, 181 unsigned int slen, u8 *dst, unsigned int *dlen, 182 void *ctx) 183 { 184 return __zstd_compress(src, slen, dst, dlen, ctx); 185 } 186 187 static int __zstd_decompress(const u8 *src, unsigned int slen, 188 u8 *dst, unsigned int *dlen, void *ctx) 189 { 190 size_t out_len; 191 struct zstd_ctx *zctx = ctx; 192 193 out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen); 194 if (ZSTD_isError(out_len)) 195 return -EINVAL; 196 *dlen = out_len; 197 return 0; 198 } 199 200 static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src, 201 unsigned int slen, u8 *dst, unsigned int *dlen) 202 { 203 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm); 204 205 return __zstd_decompress(src, slen, dst, dlen, ctx); 206 } 207 208 static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src, 209 unsigned int slen, u8 *dst, unsigned int *dlen, 210 void *ctx) 211 { 212 return __zstd_decompress(src, slen, dst, dlen, ctx); 213 } 214 215 static struct crypto_alg alg = { 216 .cra_name = "zstd", 217 .cra_flags = CRYPTO_ALG_TYPE_COMPRESS, 218 .cra_ctxsize = sizeof(struct zstd_ctx), 219 .cra_module = THIS_MODULE, 220 .cra_init = zstd_init, 221 .cra_exit = zstd_exit, 222 .cra_u = { .compress = { 223 .coa_compress = zstd_compress, 224 .coa_decompress = zstd_decompress } } 225 }; 226 227 static struct scomp_alg scomp = { 228 .alloc_ctx = zstd_alloc_ctx, 229 .free_ctx = zstd_free_ctx, 230 .compress = zstd_scompress, 231 .decompress = zstd_sdecompress, 232 .base = { 233 .cra_name = "zstd", 234 .cra_driver_name = "zstd-scomp", 235 .cra_module = THIS_MODULE, 236 } 237 }; 238 239 static int __init zstd_mod_init(void) 240 { 241 int ret; 242 243 ret = crypto_register_alg(&alg); 244 if (ret) 245 return ret; 246 247 ret = crypto_register_scomp(&scomp); 248 if (ret) 249 crypto_unregister_alg(&alg); 250 251 return ret; 252 } 253 254 static void __exit zstd_mod_fini(void) 255 { 256 crypto_unregister_alg(&alg); 257 crypto_unregister_scomp(&scomp); 258 } 259 260 subsys_initcall(zstd_mod_init); 261 module_exit(zstd_mod_fini); 262 263 MODULE_LICENSE("GPL"); 264 MODULE_DESCRIPTION("Zstd Compression Algorithm"); 265 MODULE_ALIAS_CRYPTO("zstd"); 266