1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Cryptographic API. 4 * 5 * Copyright (c) 2017-present, Facebook, Inc. 6 */ 7 #include <linux/crypto.h> 8 #include <linux/init.h> 9 #include <linux/interrupt.h> 10 #include <linux/mm.h> 11 #include <linux/module.h> 12 #include <linux/net.h> 13 #include <linux/vmalloc.h> 14 #include <linux/zstd.h> 15 #include <crypto/internal/acompress.h> 16 #include <crypto/scatterwalk.h> 17 18 19 #define ZSTD_DEF_LEVEL 3 20 #define ZSTD_MAX_WINDOWLOG 18 21 #define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG) 22 23 struct zstd_ctx { 24 zstd_cctx *cctx; 25 zstd_dctx *dctx; 26 size_t wksp_size; 27 zstd_parameters params; 28 u8 wksp[] __aligned(8); 29 }; 30 31 static DEFINE_MUTEX(zstd_stream_lock); 32 33 static void *zstd_alloc_stream(void) 34 { 35 zstd_parameters params; 36 struct zstd_ctx *ctx; 37 size_t wksp_size; 38 39 params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE); 40 41 wksp_size = max_t(size_t, 42 zstd_cstream_workspace_bound(¶ms.cParams), 43 zstd_dstream_workspace_bound(ZSTD_MAX_SIZE)); 44 if (!wksp_size) 45 return ERR_PTR(-EINVAL); 46 47 ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL); 48 if (!ctx) 49 return ERR_PTR(-ENOMEM); 50 51 ctx->params = params; 52 ctx->wksp_size = wksp_size; 53 54 return ctx; 55 } 56 57 static void zstd_free_stream(void *ctx) 58 { 59 kvfree(ctx); 60 } 61 62 static struct crypto_acomp_streams zstd_streams = { 63 .alloc_ctx = zstd_alloc_stream, 64 .free_ctx = zstd_free_stream, 65 }; 66 67 static int zstd_init(struct crypto_acomp *acomp_tfm) 68 { 69 int ret = 0; 70 71 mutex_lock(&zstd_stream_lock); 72 ret = crypto_acomp_alloc_streams(&zstd_streams); 73 mutex_unlock(&zstd_stream_lock); 74 75 return ret; 76 } 77 78 static void zstd_exit(struct crypto_acomp *acomp_tfm) 79 { 80 crypto_acomp_free_streams(&zstd_streams); 81 } 82 83 static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx, 84 const void *src, void *dst, unsigned int *dlen) 85 { 86 unsigned int out_len; 87 88 ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size); 89 if (!ctx->cctx) 90 return -EINVAL; 91 92 out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen, 93 &ctx->params); 94 if (zstd_is_error(out_len)) 95 return -EINVAL; 96 97 *dlen = out_len; 98 99 return 0; 100 } 101 102 static int zstd_compress(struct acomp_req *req) 103 { 104 struct crypto_acomp_stream *s; 105 unsigned int pos, scur, dcur; 106 unsigned int total_out = 0; 107 bool data_available = true; 108 zstd_out_buffer outbuf; 109 struct acomp_walk walk; 110 zstd_in_buffer inbuf; 111 struct zstd_ctx *ctx; 112 size_t pending_bytes; 113 size_t num_bytes; 114 int ret; 115 116 s = crypto_acomp_lock_stream_bh(&zstd_streams); 117 ctx = s->ctx; 118 119 ret = acomp_walk_virt(&walk, req, true); 120 if (ret) 121 goto out; 122 123 ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size); 124 if (!ctx->cctx) { 125 ret = -EINVAL; 126 goto out; 127 } 128 129 do { 130 dcur = acomp_walk_next_dst(&walk); 131 if (!dcur) { 132 ret = -ENOSPC; 133 goto out; 134 } 135 136 outbuf.pos = 0; 137 outbuf.dst = (u8 *)walk.dst.virt.addr; 138 outbuf.size = dcur; 139 140 do { 141 scur = acomp_walk_next_src(&walk); 142 if (dcur == req->dlen && scur == req->slen) { 143 ret = zstd_compress_one(req, ctx, walk.src.virt.addr, 144 walk.dst.virt.addr, &total_out); 145 acomp_walk_done_src(&walk, scur); 146 acomp_walk_done_dst(&walk, dcur); 147 goto out; 148 } 149 150 if (scur) { 151 inbuf.pos = 0; 152 inbuf.src = walk.src.virt.addr; 153 inbuf.size = scur; 154 } else { 155 data_available = false; 156 break; 157 } 158 159 num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf); 160 if (ZSTD_isError(num_bytes)) { 161 ret = -EIO; 162 goto out; 163 } 164 165 pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf); 166 if (ZSTD_isError(pending_bytes)) { 167 ret = -EIO; 168 goto out; 169 } 170 acomp_walk_done_src(&walk, inbuf.pos); 171 } while (dcur != outbuf.pos); 172 173 total_out += outbuf.pos; 174 acomp_walk_done_dst(&walk, dcur); 175 } while (data_available); 176 177 pos = outbuf.pos; 178 num_bytes = zstd_end_stream(ctx->cctx, &outbuf); 179 if (ZSTD_isError(num_bytes)) 180 ret = -EIO; 181 else 182 total_out += (outbuf.pos - pos); 183 184 out: 185 if (ret) 186 req->dlen = 0; 187 else 188 req->dlen = total_out; 189 190 crypto_acomp_unlock_stream_bh(s); 191 192 return ret; 193 } 194 195 static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx, 196 const void *src, void *dst, unsigned int *dlen) 197 { 198 size_t out_len; 199 200 ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size); 201 if (!ctx->dctx) 202 return -EINVAL; 203 204 out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen); 205 if (zstd_is_error(out_len)) 206 return -EINVAL; 207 208 *dlen = out_len; 209 210 return 0; 211 } 212 213 static int zstd_decompress(struct acomp_req *req) 214 { 215 struct crypto_acomp_stream *s; 216 unsigned int total_out = 0; 217 unsigned int scur, dcur; 218 zstd_out_buffer outbuf; 219 struct acomp_walk walk; 220 zstd_in_buffer inbuf; 221 struct zstd_ctx *ctx; 222 size_t pending_bytes; 223 int ret; 224 225 s = crypto_acomp_lock_stream_bh(&zstd_streams); 226 ctx = s->ctx; 227 228 ret = acomp_walk_virt(&walk, req, true); 229 if (ret) 230 goto out; 231 232 ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size); 233 if (!ctx->dctx) { 234 ret = -EINVAL; 235 goto out; 236 } 237 238 do { 239 scur = acomp_walk_next_src(&walk); 240 if (scur) { 241 inbuf.pos = 0; 242 inbuf.size = scur; 243 inbuf.src = walk.src.virt.addr; 244 } else { 245 break; 246 } 247 248 do { 249 dcur = acomp_walk_next_dst(&walk); 250 if (dcur == req->dlen && scur == req->slen) { 251 ret = zstd_decompress_one(req, ctx, walk.src.virt.addr, 252 walk.dst.virt.addr, &total_out); 253 acomp_walk_done_dst(&walk, dcur); 254 acomp_walk_done_src(&walk, scur); 255 goto out; 256 } 257 258 if (!dcur) { 259 ret = -ENOSPC; 260 goto out; 261 } 262 263 outbuf.pos = 0; 264 outbuf.dst = (u8 *)walk.dst.virt.addr; 265 outbuf.size = dcur; 266 267 pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf); 268 if (ZSTD_isError(pending_bytes)) { 269 ret = -EIO; 270 goto out; 271 } 272 273 total_out += outbuf.pos; 274 275 acomp_walk_done_dst(&walk, outbuf.pos); 276 } while (inbuf.pos != scur); 277 278 acomp_walk_done_src(&walk, scur); 279 } while (ret == 0); 280 281 out: 282 if (ret) 283 req->dlen = 0; 284 else 285 req->dlen = total_out; 286 287 crypto_acomp_unlock_stream_bh(s); 288 289 return ret; 290 } 291 292 static struct acomp_alg zstd_acomp = { 293 .base = { 294 .cra_name = "zstd", 295 .cra_driver_name = "zstd-generic", 296 .cra_flags = CRYPTO_ALG_REQ_VIRT, 297 .cra_module = THIS_MODULE, 298 }, 299 .init = zstd_init, 300 .exit = zstd_exit, 301 .compress = zstd_compress, 302 .decompress = zstd_decompress, 303 }; 304 305 static int __init zstd_mod_init(void) 306 { 307 return crypto_register_acomp(&zstd_acomp); 308 } 309 310 static void __exit zstd_mod_fini(void) 311 { 312 crypto_unregister_acomp(&zstd_acomp); 313 } 314 315 module_init(zstd_mod_init); 316 module_exit(zstd_mod_fini); 317 318 MODULE_LICENSE("GPL"); 319 MODULE_DESCRIPTION("Zstd Compression Algorithm"); 320 MODULE_ALIAS_CRYPTO("zstd"); 321