1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Cryptographic API. 4 * 5 * Copyright (c) 2017-present, Facebook, Inc. 6 */ 7 #include <linux/crypto.h> 8 #include <linux/init.h> 9 #include <linux/interrupt.h> 10 #include <linux/mm.h> 11 #include <linux/module.h> 12 #include <linux/net.h> 13 #include <linux/overflow.h> 14 #include <linux/vmalloc.h> 15 #include <linux/zstd.h> 16 #include <crypto/internal/acompress.h> 17 #include <crypto/scatterwalk.h> 18 19 20 #define ZSTD_DEF_LEVEL 3 21 #define ZSTD_MAX_WINDOWLOG 18 22 #define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG) 23 24 struct zstd_ctx { 25 zstd_cctx *cctx; 26 zstd_dctx *dctx; 27 size_t wksp_size; 28 zstd_parameters params; 29 u8 wksp[] __aligned(8) __counted_by(wksp_size); 30 }; 31 32 static DEFINE_MUTEX(zstd_stream_lock); 33 34 static void *zstd_alloc_stream(void) 35 { 36 zstd_parameters params; 37 struct zstd_ctx *ctx; 38 size_t wksp_size; 39 40 params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE); 41 42 wksp_size = max(zstd_cstream_workspace_bound(¶ms.cParams), 43 zstd_dstream_workspace_bound(ZSTD_MAX_SIZE)); 44 if (!wksp_size) 45 return ERR_PTR(-EINVAL); 46 47 ctx = kvmalloc(struct_size(ctx, wksp, wksp_size), GFP_KERNEL); 48 if (!ctx) 49 return ERR_PTR(-ENOMEM); 50 51 ctx->params = params; 52 ctx->wksp_size = wksp_size; 53 54 return ctx; 55 } 56 57 static void zstd_free_stream(void *ctx) 58 { 59 kvfree(ctx); 60 } 61 62 static struct crypto_acomp_streams zstd_streams = { 63 .alloc_ctx = zstd_alloc_stream, 64 .free_ctx = zstd_free_stream, 65 }; 66 67 static int zstd_init(struct crypto_acomp *acomp_tfm) 68 { 69 int ret = 0; 70 71 mutex_lock(&zstd_stream_lock); 72 ret = crypto_acomp_alloc_streams(&zstd_streams); 73 mutex_unlock(&zstd_stream_lock); 74 75 return ret; 76 } 77 78 static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx, 79 const void *src, void *dst, unsigned int *dlen) 80 { 81 size_t out_len; 82 83 ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size); 84 if (!ctx->cctx) 85 return -EINVAL; 86 87 out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen, 88 &ctx->params); 89 if (zstd_is_error(out_len)) 90 return -EINVAL; 91 92 *dlen = out_len; 93 94 return 0; 95 } 96 97 static int zstd_compress(struct acomp_req *req) 98 { 99 struct crypto_acomp_stream *s; 100 unsigned int pos, scur, dcur; 101 unsigned int total_out = 0; 102 bool data_available = true; 103 zstd_out_buffer outbuf; 104 struct acomp_walk walk; 105 zstd_in_buffer inbuf; 106 struct zstd_ctx *ctx; 107 size_t pending_bytes; 108 size_t num_bytes; 109 int ret; 110 111 s = crypto_acomp_lock_stream_bh(&zstd_streams); 112 ctx = s->ctx; 113 114 ret = acomp_walk_virt(&walk, req, true); 115 if (ret) 116 goto out; 117 118 ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size); 119 if (!ctx->cctx) { 120 ret = -EINVAL; 121 goto out; 122 } 123 124 do { 125 dcur = acomp_walk_next_dst(&walk); 126 if (!dcur) { 127 ret = -ENOSPC; 128 goto out; 129 } 130 131 outbuf.pos = 0; 132 outbuf.dst = (u8 *)walk.dst.virt.addr; 133 outbuf.size = dcur; 134 135 do { 136 scur = acomp_walk_next_src(&walk); 137 if (dcur == req->dlen && scur == req->slen) { 138 ret = zstd_compress_one(req, ctx, walk.src.virt.addr, 139 walk.dst.virt.addr, &total_out); 140 acomp_walk_done_src(&walk, scur); 141 acomp_walk_done_dst(&walk, dcur); 142 goto out; 143 } 144 145 if (scur) { 146 inbuf.pos = 0; 147 inbuf.src = walk.src.virt.addr; 148 inbuf.size = scur; 149 } else { 150 data_available = false; 151 break; 152 } 153 154 num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf); 155 if (ZSTD_isError(num_bytes)) { 156 ret = -EIO; 157 goto out; 158 } 159 160 pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf); 161 if (ZSTD_isError(pending_bytes)) { 162 ret = -EIO; 163 goto out; 164 } 165 acomp_walk_done_src(&walk, inbuf.pos); 166 } while (dcur != outbuf.pos); 167 168 total_out += outbuf.pos; 169 acomp_walk_done_dst(&walk, dcur); 170 } while (data_available); 171 172 pos = outbuf.pos; 173 num_bytes = zstd_end_stream(ctx->cctx, &outbuf); 174 if (ZSTD_isError(num_bytes)) 175 ret = -EIO; 176 else 177 total_out += (outbuf.pos - pos); 178 179 out: 180 if (ret) 181 req->dlen = 0; 182 else 183 req->dlen = total_out; 184 185 crypto_acomp_unlock_stream_bh(s); 186 187 return ret; 188 } 189 190 static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx, 191 const void *src, void *dst, unsigned int *dlen) 192 { 193 size_t out_len; 194 195 ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size); 196 if (!ctx->dctx) 197 return -EINVAL; 198 199 out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen); 200 if (zstd_is_error(out_len)) 201 return -EINVAL; 202 203 *dlen = out_len; 204 205 return 0; 206 } 207 208 static int zstd_decompress(struct acomp_req *req) 209 { 210 struct crypto_acomp_stream *s; 211 unsigned int total_out = 0; 212 unsigned int scur, dcur; 213 zstd_out_buffer outbuf; 214 struct acomp_walk walk; 215 zstd_in_buffer inbuf; 216 struct zstd_ctx *ctx; 217 size_t pending_bytes; 218 int ret; 219 220 s = crypto_acomp_lock_stream_bh(&zstd_streams); 221 ctx = s->ctx; 222 223 ret = acomp_walk_virt(&walk, req, true); 224 if (ret) 225 goto out; 226 227 ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size); 228 if (!ctx->dctx) { 229 ret = -EINVAL; 230 goto out; 231 } 232 233 do { 234 scur = acomp_walk_next_src(&walk); 235 if (scur) { 236 inbuf.pos = 0; 237 inbuf.size = scur; 238 inbuf.src = walk.src.virt.addr; 239 } else { 240 break; 241 } 242 243 do { 244 dcur = acomp_walk_next_dst(&walk); 245 if (dcur == req->dlen && scur == req->slen) { 246 ret = zstd_decompress_one(req, ctx, walk.src.virt.addr, 247 walk.dst.virt.addr, &total_out); 248 acomp_walk_done_dst(&walk, dcur); 249 acomp_walk_done_src(&walk, scur); 250 goto out; 251 } 252 253 if (!dcur) { 254 ret = -ENOSPC; 255 goto out; 256 } 257 258 outbuf.pos = 0; 259 outbuf.dst = (u8 *)walk.dst.virt.addr; 260 outbuf.size = dcur; 261 262 pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf); 263 if (ZSTD_isError(pending_bytes)) { 264 ret = -EIO; 265 goto out; 266 } 267 268 total_out += outbuf.pos; 269 270 acomp_walk_done_dst(&walk, outbuf.pos); 271 } while (inbuf.pos != scur); 272 273 acomp_walk_done_src(&walk, scur); 274 } while (ret == 0); 275 276 out: 277 if (ret) 278 req->dlen = 0; 279 else 280 req->dlen = total_out; 281 282 crypto_acomp_unlock_stream_bh(s); 283 284 return ret; 285 } 286 287 static struct acomp_alg zstd_acomp = { 288 .base = { 289 .cra_name = "zstd", 290 .cra_driver_name = "zstd-generic", 291 .cra_flags = CRYPTO_ALG_REQ_VIRT, 292 .cra_module = THIS_MODULE, 293 }, 294 .init = zstd_init, 295 .compress = zstd_compress, 296 .decompress = zstd_decompress, 297 }; 298 299 static int __init zstd_mod_init(void) 300 { 301 return crypto_register_acomp(&zstd_acomp); 302 } 303 304 static void __exit zstd_mod_fini(void) 305 { 306 crypto_unregister_acomp(&zstd_acomp); 307 crypto_acomp_free_streams(&zstd_streams); 308 } 309 310 module_init(zstd_mod_init); 311 module_exit(zstd_mod_fini); 312 313 MODULE_LICENSE("GPL"); 314 MODULE_DESCRIPTION("Zstd Compression Algorithm"); 315 MODULE_ALIAS_CRYPTO("zstd"); 316