xref: /linux/drivers/crypto/intel/qat/qat_common/qat_comp_algs.c (revision fbf5df34a4dbcd09d433dd4f0916bf9b2ddb16de)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright(c) 2022 Intel Corporation */
3 #include <linux/crypto.h>
4 #include <crypto/acompress.h>
5 #include <crypto/internal/acompress.h>
6 #include <crypto/scatterwalk.h>
7 #include <linux/dma-mapping.h>
8 #include <linux/workqueue.h>
9 #include <linux/zstd.h>
10 #include "adf_accel_devices.h"
11 #include "adf_common_drv.h"
12 #include "adf_dc.h"
13 #include "qat_bl.h"
14 #include "qat_comp_req.h"
15 #include "qat_compression.h"
16 #include "qat_algs_send.h"
17 #include "qat_comp_zstd_utils.h"
18 
/*
 * Size in bytes of each per-stream scratch buffer (LZ4s copy / ZSTD output
 * and extracted literals). Also used as the source-size hint when sizing the
 * ZSTD compression context and as the literal buffer capacity given to zstd.
 */
#define QAT_ZSTD_SCRATCH_SIZE		524288
/* ZSTD frames whose window size exceeds this are decompressed in software. */
#define QAT_ZSTD_MAX_BLOCK_SIZE		65535
/* ZSTD frames whose frameContentSize is >= this are decompressed in software. */
#define QAT_ZSTD_MAX_CONTENT_SIZE	4096
/* Lower bound on both slen and dlen for the hardware LZ4s->ZSTD path. */
#define QAT_LZ4S_MIN_INPUT_SIZE		8192
/* Upper bound on both slen and dlen for the hardware LZ4s->ZSTD path. */
#define QAT_LZ4S_MAX_OUTPUT_SIZE	QAT_ZSTD_SCRATCH_SIZE
/* Capacity, in ZSTD_Sequence entries, of the per-stream sequence buffer. */
#define QAT_MAX_SEQUENCES		(128 * 1024)

/* Serializes algorithm (un)registration and the device counters below. */
static DEFINE_MUTEX(algs_lock);
/* Number of devices currently holding a registration of each algorithm. */
static unsigned int active_devs_deflate;
static unsigned int active_devs_lz4s;
static unsigned int active_devs_zstd;
30 
/*
 * Per-stream scratch state used to turn the device's LZ4s output into a
 * ZSTD frame. Created by qat_zstd_alloc_scratch(), released by
 * qat_zstd_free_scratch().
 */
struct qat_zstd_scratch {
	size_t		cctx_buffer_size;	/* size of @workspace in bytes */
	void		*lz4s;			/* LZ4s copy when dst is not a single SG entry; also the ZSTD output buffer then */
	void		*literals;		/* literal bytes handed to zstd */
	void		*out_seqs;		/* QAT_MAX_SEQUENCES ZSTD_Sequence entries */
	void		*workspace;		/* backing memory for @ctx */
	ZSTD_CCtx	*ctx;			/* ZSTD compression context living in @workspace */
};
39 
40 static void *qat_zstd_alloc_scratch(void)
41 {
42 	struct qat_zstd_scratch *scratch;
43 	ZSTD_parameters params;
44 	size_t cctx_size;
45 	ZSTD_CCtx *ctx;
46 	size_t zret;
47 	int ret;
48 
49 	ret = -ENOMEM;
50 	scratch = kzalloc_obj(*scratch);
51 	if (!scratch)
52 		return ERR_PTR(ret);
53 
54 	scratch->lz4s = kvmalloc(QAT_ZSTD_SCRATCH_SIZE, GFP_KERNEL);
55 	if (!scratch->lz4s)
56 		goto error;
57 
58 	scratch->literals = kvmalloc(QAT_ZSTD_SCRATCH_SIZE, GFP_KERNEL);
59 	if (!scratch->literals)
60 		goto error;
61 
62 	scratch->out_seqs = kvcalloc(QAT_MAX_SEQUENCES, sizeof(ZSTD_Sequence),
63 				     GFP_KERNEL);
64 	if (!scratch->out_seqs)
65 		goto error;
66 
67 	params = zstd_get_params(zstd_max_clevel(), QAT_ZSTD_SCRATCH_SIZE);
68 	cctx_size = zstd_cctx_workspace_bound(&params.cParams);
69 
70 	scratch->workspace = kvmalloc(cctx_size, GFP_KERNEL | __GFP_ZERO);
71 	if (!scratch->workspace)
72 		goto error;
73 
74 	ret = -EINVAL;
75 	ctx = zstd_init_cctx(scratch->workspace, cctx_size);
76 	if (!ctx)
77 		goto error;
78 
79 	scratch->ctx = ctx;
80 	scratch->cctx_buffer_size = cctx_size;
81 
82 	zret = zstd_cctx_set_param(ctx, ZSTD_c_blockDelimiters, ZSTD_sf_explicitBlockDelimiters);
83 	if (zstd_is_error(zret))
84 		goto error;
85 
86 	return scratch;
87 
88 error:
89 	kvfree(scratch->lz4s);
90 	kvfree(scratch->literals);
91 	kvfree(scratch->out_seqs);
92 	kvfree(scratch->workspace);
93 	kfree(scratch);
94 	return ERR_PTR(ret);
95 }
96 
97 static void qat_zstd_free_scratch(void *ctx)
98 {
99 	struct qat_zstd_scratch *scratch = ctx;
100 
101 	if (!scratch)
102 		return;
103 
104 	kvfree(scratch->lz4s);
105 	kvfree(scratch->literals);
106 	kvfree(scratch->out_seqs);
107 	kvfree(scratch->workspace);
108 	kfree(scratch);
109 }
110 
111 static struct crypto_acomp_streams qat_zstd_streams = {
112 	.alloc_ctx = qat_zstd_alloc_scratch,
113 	.free_ctx = qat_zstd_free_scratch,
114 };
115 
/* Direction of a request, recorded in struct qat_compression_req::dir. */
enum direction {
	DECOMPRESSION = 0,
	COMPRESSION = 1,
};
120 
struct qat_compression_req;

/* Values passed from the generic completion handler to qat_comp_callback(). */
struct qat_callback_params {
	unsigned int produced;	/* bytes produced by the device */
	unsigned int dlen;	/* areq->dlen as submitted, before it is overwritten */
	bool plain;		/* compression only: response's uncompressed flag */
};

/* Transform context shared by all requests on a QAT compression tfm. */
struct qat_compression_ctx {
	/* Firmware context built by qat_comp_build_ctx(), used to create requests */
	u8 comp_ctx[QAT_COMP_CTX_SIZE];
	struct qat_compression_instance *inst;	/* device instance serving this tfm */
	/* Optional post-processing run after a successful device completion */
	int (*qat_comp_callback)(struct qat_compression_req *qat_req, void *resp,
				 struct qat_callback_params *params);
	struct crypto_acomp *ftfm;		/* software fallback (ZSTD variants only) */
};

/* Per-request state, stored in the acomp request context. */
struct qat_compression_req {
	u8 req[QAT_COMP_REQ_SIZE];		/* firmware request message */
	struct qat_compression_ctx *qat_compression_ctx;
	struct acomp_req *acompress_req;	/* originating acomp request */
	struct qat_request_buffs buf;		/* buffer lists built from src/dst */
	enum direction dir;
	int actual_dlen;			/* caller's dlen, excluding the overflow buffer */
	struct qat_alg_req alg_req;		/* ring submission / backlog state */
};
146 
147 static int qat_alg_send_dc_message(struct qat_compression_req *qat_req,
148 				   struct qat_compression_instance *inst,
149 				   struct crypto_async_request *base)
150 {
151 	struct qat_alg_req *alg_req = &qat_req->alg_req;
152 
153 	alg_req->fw_req = (u32 *)&qat_req->req;
154 	alg_req->tx_ring = inst->dc_tx;
155 	alg_req->base = base;
156 	alg_req->backlog = &inst->backlog;
157 
158 	return qat_alg_send_message(alg_req);
159 }
160 
/*
 * Common completion handler for QAT compression requests.
 *
 * Decodes the firmware response @resp for @qat_req and completes the
 * originating acomp request. On success areq->dlen is set to the number of
 * bytes the device produced and, if the tfm installed one, qat_comp_callback()
 * runs and supplies the final result. On failure areq->dlen is set to 0 and
 * the request completes with:
 *   -E2BIG    the device reported an overflow, or (compression) it produced
 *             more than actual_dlen, i.e. spilled into the overflow buffer;
 *   -EBADMSG  any other device error, or (compression) the CNV flag is
 *             clear in the response.
 * In every case the request's buffer lists are freed and
 * qat_alg_send_backlog() is called on the instance backlog.
 */
static void qat_comp_generic_callback(struct qat_compression_req *qat_req,
				      void *resp)
{
	struct acomp_req *areq = qat_req->acompress_req;
	struct qat_compression_ctx *ctx = qat_req->qat_compression_ctx;
	struct adf_accel_dev *accel_dev = ctx->inst->accel_dev;
	struct crypto_acomp *tfm = crypto_acomp_reqtfm(areq);
	struct qat_compression_instance *inst = ctx->inst;
	struct qat_callback_params params = { };
	int consumed, produced;
	s8 cmp_err, xlt_err;
	int res = -EBADMSG;
	int status;
	u8 cnv;

	/* OR-combined status of the compression and translation stages */
	status = qat_comp_get_cmp_status(resp);
	status |= qat_comp_get_xlt_status(resp);
	cmp_err = qat_comp_get_cmp_err(resp);
	xlt_err = qat_comp_get_xlt_err(resp);

	consumed = qat_comp_get_consumed_ctr(resp);
	produced = qat_comp_get_produced_ctr(resp);

	/* Cache parameters for algorithm specific callback */
	params.produced = produced;
	params.dlen = areq->dlen;

	dev_dbg(&GET_DEV(accel_dev),
		"[%s][%s][%s] slen = %8d dlen = %8d consumed = %8d produced = %8d cmp_err = %3d xlt_err = %3d",
		crypto_tfm_alg_driver_name(crypto_acomp_tfm(tfm)),
		qat_req->dir == COMPRESSION ? "comp  " : "decomp",
		status ? "ERR" : "OK ",
		areq->slen, areq->dlen, consumed, produced, cmp_err, xlt_err);

	if (unlikely(status != ICP_QAT_FW_COMN_STATUS_FLAG_OK)) {
		/* Overflow in either stage is reported as -E2BIG, the rest as -EBADMSG */
		if (cmp_err == ERR_CODE_OVERFLOW_ERROR || xlt_err == ERR_CODE_OVERFLOW_ERROR)
			res = -E2BIG;

		areq->dlen = 0;
		goto end;
	}

	if (qat_req->dir == COMPRESSION) {
		cnv = qat_comp_get_cmp_cnv_flag(resp);
		if (unlikely(!cnv)) {
			dev_err(&GET_DEV(accel_dev),
				"Verified compression not supported\n");
			areq->dlen = 0;
			goto end;
		}

		/*
		 * The destination was extended by the instance overflow buffer at
		 * submission; output beyond the caller's buffer is scrubbed there.
		 */
		if (unlikely(produced > qat_req->actual_dlen)) {
			memset(inst->dc_data->ovf_buff, 0,
			       inst->dc_data->ovf_buff_sz);
			dev_dbg(&GET_DEV(accel_dev),
				"Actual buffer overflow: produced=%d, dlen=%d\n",
				produced, qat_req->actual_dlen);

			res = -E2BIG;
			areq->dlen = 0;
			goto end;
		}

		params.plain = !!qat_comp_get_cmp_uncomp_flag(resp);
	}

	res = 0;
	areq->dlen = produced;

	/* Algorithm-specific post-processing may rewrite dst/dlen and the result */
	if (ctx->qat_comp_callback)
		res = ctx->qat_comp_callback(qat_req, resp, &params);

end:
	qat_bl_free_bufl(accel_dev, &qat_req->buf);
	acomp_request_complete(areq, res);
	qat_alg_send_backlog(qat_req->alg_req.backlog);
}
238 
239 void qat_comp_alg_callback(void *resp)
240 {
241 	struct qat_compression_req *qat_req =
242 			(void *)(__force long)qat_comp_get_opaque(resp);
243 
244 	qat_comp_generic_callback(qat_req, resp);
245 }
246 
247 static int qat_comp_alg_init_tfm(struct crypto_acomp *acomp_tfm, int alg)
248 {
249 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
250 	struct crypto_tfm *tfm = crypto_acomp_tfm(acomp_tfm);
251 	struct qat_compression_instance *inst;
252 	int node, ret;
253 
254 	if (tfm->node == NUMA_NO_NODE)
255 		node = numa_node_id();
256 	else
257 		node = tfm->node;
258 
259 	memset(ctx, 0, sizeof(*ctx));
260 	inst = qat_compression_get_instance_node(node, alg);
261 	if (!inst)
262 		return -EINVAL;
263 	ctx->inst = inst;
264 
265 	ret = qat_comp_build_ctx(inst->accel_dev, ctx->comp_ctx, alg);
266 	if (ret) {
267 		qat_compression_put_instance(inst);
268 		memset(ctx, 0, sizeof(*ctx));
269 	}
270 
271 	return ret;
272 }
273 
274 static int qat_comp_alg_deflate_init_tfm(struct crypto_acomp *acomp_tfm)
275 {
276 	return qat_comp_alg_init_tfm(acomp_tfm, QAT_DEFLATE);
277 }
278 
279 static void qat_comp_alg_exit_tfm(struct crypto_acomp *acomp_tfm)
280 {
281 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
282 
283 	qat_compression_put_instance(ctx->inst);
284 	memset(ctx, 0, sizeof(*ctx));
285 }
286 
287 static int qat_comp_alg_compress_decompress(struct acomp_req *areq, enum direction dir,
288 					    unsigned int shdr, unsigned int sftr,
289 					    unsigned int dhdr, unsigned int dftr)
290 {
291 	struct qat_compression_req *qat_req = acomp_request_ctx(areq);
292 	struct crypto_acomp *acomp_tfm = crypto_acomp_reqtfm(areq);
293 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
294 	struct qat_compression_instance *inst = ctx->inst;
295 	gfp_t f = qat_algs_alloc_flags(&areq->base);
296 	struct qat_sgl_to_bufl_params params = {0};
297 	int slen = areq->slen - shdr - sftr;
298 	int dlen = areq->dlen - dhdr - dftr;
299 	dma_addr_t sfbuf, dfbuf;
300 	u8 *req = qat_req->req;
301 	size_t ovf_buff_sz;
302 	int ret;
303 
304 	params.sskip = shdr;
305 	params.dskip = dhdr;
306 
307 	if (!areq->src || !slen)
308 		return -EINVAL;
309 
310 	if (!areq->dst || !dlen)
311 		return -EINVAL;
312 
313 	if (dir == COMPRESSION) {
314 		params.extra_dst_buff = inst->dc_data->ovf_buff_p;
315 		ovf_buff_sz = inst->dc_data->ovf_buff_sz;
316 		params.sz_extra_dst_buff = ovf_buff_sz;
317 	}
318 
319 	ret = qat_bl_sgl_to_bufl(ctx->inst->accel_dev, areq->src, areq->dst,
320 				 &qat_req->buf, &params, f);
321 	if (unlikely(ret))
322 		return ret;
323 
324 	sfbuf = qat_req->buf.blp;
325 	dfbuf = qat_req->buf.bloutp;
326 	qat_req->qat_compression_ctx = ctx;
327 	qat_req->acompress_req = areq;
328 	qat_req->dir = dir;
329 
330 	if (dir == COMPRESSION) {
331 		qat_req->actual_dlen = dlen;
332 		dlen += ovf_buff_sz;
333 		qat_comp_create_compression_req(ctx->comp_ctx, req,
334 						(u64)(__force long)sfbuf, slen,
335 						(u64)(__force long)dfbuf, dlen,
336 						(u64)(__force long)qat_req);
337 	} else {
338 		qat_comp_create_decompression_req(ctx->comp_ctx, req,
339 						  (u64)(__force long)sfbuf, slen,
340 						  (u64)(__force long)dfbuf, dlen,
341 						  (u64)(__force long)qat_req);
342 	}
343 
344 	ret = qat_alg_send_dc_message(qat_req, inst, &areq->base);
345 	if (ret == -ENOSPC)
346 		qat_bl_free_bufl(inst->accel_dev, &qat_req->buf);
347 
348 	return ret;
349 }
350 
351 static int qat_comp_alg_compress(struct acomp_req *req)
352 {
353 	return qat_comp_alg_compress_decompress(req, COMPRESSION, 0, 0, 0, 0);
354 }
355 
356 static int qat_comp_alg_decompress(struct acomp_req *req)
357 {
358 	return qat_comp_alg_compress_decompress(req, DECOMPRESSION, 0, 0, 0, 0);
359 }
360 
361 static int qat_comp_alg_zstd_decompress(struct acomp_req *req)
362 {
363 	struct crypto_acomp *acomp_tfm = crypto_acomp_reqtfm(req);
364 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
365 	struct acomp_req *nreq = acomp_request_ctx(req);
366 	zstd_frame_header header;
367 	void *buffer;
368 	size_t zret;
369 	int ret;
370 
371 	buffer = kmap_local_page(sg_page(req->src)) + req->src->offset;
372 	zret = zstd_get_frame_header(&header, buffer, req->src->length);
373 	kunmap_local(buffer);
374 	if (zret) {
375 		dev_err(&GET_DEV(ctx->inst->accel_dev),
376 			"ZSTD-compressed data has an incomplete frame header\n");
377 		return -EINVAL;
378 	}
379 
380 	if (header.windowSize > QAT_ZSTD_MAX_BLOCK_SIZE ||
381 	    header.frameContentSize >= QAT_ZSTD_MAX_CONTENT_SIZE) {
382 		dev_dbg(&GET_DEV(ctx->inst->accel_dev), "Window size=0x%llx\n",
383 			header.windowSize);
384 
385 		memcpy(nreq, req, sizeof(*req));
386 		acomp_request_set_tfm(nreq, ctx->ftfm);
387 
388 		ret = crypto_acomp_decompress(nreq);
389 		req->dlen = nreq->dlen;
390 
391 		return ret;
392 	}
393 
394 	return qat_comp_alg_compress_decompress(req, DECOMPRESSION, 0, 0, 0, 0);
395 }
396 
397 static int qat_comp_lz4s_zstd_callback(struct qat_compression_req *qat_req, void *resp,
398 				       struct qat_callback_params *params)
399 {
400 	struct qat_compression_ctx *qat_ctx = qat_req->qat_compression_ctx;
401 	struct acomp_req *areq = qat_req->acompress_req;
402 	struct qat_zstd_scratch *scratch;
403 	struct crypto_acomp_stream *s;
404 	unsigned int lit_len = 0;
405 	ZSTD_Sequence *out_seqs;
406 	void *lz4s, *zstd;
407 	size_t comp_size;
408 	ZSTD_CCtx *ctx;
409 	void *literals;
410 	int seq_count;
411 	int ret = 0;
412 
413 	if (params->produced + QAT_ZSTD_LIT_COPY_LEN > QAT_ZSTD_SCRATCH_SIZE) {
414 		dev_dbg(&GET_DEV(qat_ctx->inst->accel_dev),
415 			"LZ4s-ZSTD: produced size (%u) + COPY_SIZE > QAT_ZSTD_SCRATCH_SIZE (%u)\n",
416 			params->produced, QAT_ZSTD_SCRATCH_SIZE);
417 		areq->dlen = 0;
418 		return -E2BIG;
419 	}
420 
421 	s = crypto_acomp_lock_stream_bh(&qat_zstd_streams);
422 	scratch = s->ctx;
423 
424 	lz4s = scratch->lz4s;
425 	zstd = lz4s;  /* Output buffer is same as lz4s */
426 	out_seqs = scratch->out_seqs;
427 	ctx = scratch->ctx;
428 	literals = scratch->literals;
429 
430 	if (likely(!params->plain)) {
431 		if (likely(sg_nents(areq->dst) == 1)) {
432 			zstd = sg_virt(areq->dst);
433 			lz4s = zstd;
434 		} else {
435 			memcpy_from_sglist(lz4s, areq->dst, 0, params->produced);
436 		}
437 
438 		seq_count = qat_alg_dec_lz4s(out_seqs, QAT_MAX_SEQUENCES, lz4s,
439 					     params->produced, literals, &lit_len);
440 		if (seq_count < 0) {
441 			ret = seq_count;
442 			comp_size = 0;
443 			goto out;
444 		}
445 	} else {
446 		out_seqs[0].litLength = areq->slen;
447 		out_seqs[0].offset = 0;
448 		out_seqs[0].matchLength = 0;
449 
450 		seq_count = 1;
451 	}
452 
453 	comp_size = zstd_compress_sequences_and_literals(ctx, zstd, params->dlen,
454 							 out_seqs, seq_count,
455 							 literals, lit_len,
456 							 QAT_ZSTD_SCRATCH_SIZE,
457 							 areq->slen);
458 	if (zstd_is_error(comp_size)) {
459 		if (comp_size == ZSTD_error_cannotProduce_uncompressedBlock)
460 			ret = -E2BIG;
461 		else
462 			ret = -EOPNOTSUPP;
463 
464 		comp_size = 0;
465 		goto out;
466 	}
467 
468 	if (comp_size > params->dlen) {
469 		dev_dbg(&GET_DEV(qat_ctx->inst->accel_dev),
470 			"LZ4s-ZSTD: compressed_size (%u) > output buffer size (%u)\n",
471 			(unsigned int)comp_size, params->dlen);
472 		ret = -EOVERFLOW;
473 		goto out;
474 	}
475 
476 	if (unlikely(sg_nents(areq->dst) != 1))
477 		memcpy_to_sglist(areq->dst, 0, zstd, comp_size);
478 
479 out:
480 	areq->dlen = comp_size;
481 	crypto_acomp_unlock_stream_bh(s);
482 
483 	return ret;
484 }
485 
486 static int qat_comp_alg_lz4s_zstd_init_tfm(struct crypto_acomp *acomp_tfm)
487 {
488 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
489 	struct crypto_tfm *tfm = crypto_acomp_tfm(acomp_tfm);
490 	int reqsize;
491 	int ret;
492 
493 	/* qat_comp_alg_init_tfm() wipes out the ctx */
494 	ret = qat_comp_alg_init_tfm(acomp_tfm, QAT_LZ4S);
495 	if (ret)
496 		return ret;
497 
498 	ctx->ftfm = crypto_alloc_acomp_node("zstd", 0, CRYPTO_ALG_NEED_FALLBACK,
499 					    tfm->node);
500 	if (IS_ERR(ctx->ftfm)) {
501 		qat_comp_alg_exit_tfm(acomp_tfm);
502 		return PTR_ERR(ctx->ftfm);
503 	}
504 
505 	reqsize = max(sizeof(struct qat_compression_req),
506 		      sizeof(struct acomp_req) + crypto_acomp_reqsize(ctx->ftfm));
507 
508 	acomp_tfm->reqsize = reqsize;
509 
510 	ctx->qat_comp_callback = qat_comp_lz4s_zstd_callback;
511 
512 	return 0;
513 }
514 
515 static int qat_comp_alg_zstd_init_tfm(struct crypto_acomp *acomp_tfm)
516 {
517 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
518 	struct crypto_tfm *tfm = crypto_acomp_tfm(acomp_tfm);
519 	int reqsize;
520 	int ret;
521 
522 	/* qat_comp_alg_init_tfm() wipes out the ctx */
523 	ret = qat_comp_alg_init_tfm(acomp_tfm, QAT_ZSTD);
524 	if (ret)
525 		return ret;
526 
527 	ctx->ftfm = crypto_alloc_acomp_node("zstd", 0, CRYPTO_ALG_NEED_FALLBACK,
528 					    tfm->node);
529 	if (IS_ERR(ctx->ftfm)) {
530 		qat_comp_alg_exit_tfm(acomp_tfm);
531 		return PTR_ERR(ctx->ftfm);
532 	}
533 
534 	reqsize = max(sizeof(struct qat_compression_req),
535 		      sizeof(struct acomp_req) + crypto_acomp_reqsize(ctx->ftfm));
536 
537 	acomp_tfm->reqsize = reqsize;
538 
539 	return 0;
540 }
541 
542 static void qat_comp_alg_zstd_exit_tfm(struct crypto_acomp *acomp_tfm)
543 {
544 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
545 
546 	if (ctx->ftfm)
547 		crypto_free_acomp(ctx->ftfm);
548 
549 	qat_comp_alg_exit_tfm(acomp_tfm);
550 }
551 
552 static int qat_comp_alg_lz4s_zstd_compress(struct acomp_req *req)
553 {
554 	struct crypto_acomp *acomp_tfm = crypto_acomp_reqtfm(req);
555 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
556 	struct acomp_req *nreq = acomp_request_ctx(req);
557 	int ret;
558 
559 	if (req->slen >= QAT_LZ4S_MIN_INPUT_SIZE && req->dlen >= QAT_LZ4S_MIN_INPUT_SIZE &&
560 	    req->slen <= QAT_LZ4S_MAX_OUTPUT_SIZE && req->dlen <= QAT_LZ4S_MAX_OUTPUT_SIZE)
561 		return qat_comp_alg_compress(req);
562 
563 	memcpy(nreq, req, sizeof(*req));
564 	acomp_request_set_tfm(nreq, ctx->ftfm);
565 
566 	ret = crypto_acomp_compress(nreq);
567 	req->dlen = nreq->dlen;
568 
569 	return ret;
570 }
571 
572 static int qat_comp_alg_sw_decompress(struct acomp_req *req)
573 {
574 	struct crypto_acomp *acomp_tfm = crypto_acomp_reqtfm(req);
575 	struct qat_compression_ctx *ctx = acomp_tfm_ctx(acomp_tfm);
576 	struct acomp_req *nreq = acomp_request_ctx(req);
577 	int ret;
578 
579 	memcpy(nreq, req, sizeof(*req));
580 	acomp_request_set_tfm(nreq, ctx->ftfm);
581 
582 	ret = crypto_acomp_decompress(nreq);
583 	req->dlen = nreq->dlen;
584 
585 	return ret;
586 }
587 
588 static struct acomp_alg qat_acomp_deflate[] = { {
589 	.base = {
590 		.cra_name = "deflate",
591 		.cra_driver_name = "qat_deflate",
592 		.cra_priority = 4001,
593 		.cra_flags = CRYPTO_ALG_ASYNC | CRYPTO_ALG_ALLOCATES_MEMORY,
594 		.cra_ctxsize = sizeof(struct qat_compression_ctx),
595 		.cra_reqsize = sizeof(struct qat_compression_req),
596 		.cra_module = THIS_MODULE,
597 	},
598 	.init = qat_comp_alg_deflate_init_tfm,
599 	.exit = qat_comp_alg_exit_tfm,
600 	.compress = qat_comp_alg_compress,
601 	.decompress = qat_comp_alg_decompress,
602 }};
603 
604 static struct acomp_alg qat_acomp_zstd_lz4s = {
605 	.base = {
606 		.cra_name = "zstd",
607 		.cra_driver_name = "qat_zstd",
608 		.cra_priority = 4001,
609 		.cra_flags = CRYPTO_ALG_ASYNC | CRYPTO_ALG_ALLOCATES_MEMORY |
610 			     CRYPTO_ALG_NEED_FALLBACK,
611 		.cra_reqsize = sizeof(struct qat_compression_req),
612 		.cra_ctxsize = sizeof(struct qat_compression_ctx),
613 		.cra_module = THIS_MODULE,
614 	},
615 	.init = qat_comp_alg_lz4s_zstd_init_tfm,
616 	.exit = qat_comp_alg_zstd_exit_tfm,
617 	.compress = qat_comp_alg_lz4s_zstd_compress,
618 	.decompress = qat_comp_alg_sw_decompress,
619 };
620 
621 static struct acomp_alg qat_acomp_zstd_native = {
622 	.base = {
623 		.cra_name = "zstd",
624 		.cra_driver_name = "qat_zstd",
625 		.cra_priority = 4001,
626 		.cra_flags = CRYPTO_ALG_ASYNC | CRYPTO_ALG_ALLOCATES_MEMORY |
627 			     CRYPTO_ALG_NEED_FALLBACK,
628 		.cra_reqsize = sizeof(struct qat_compression_req),
629 		.cra_ctxsize = sizeof(struct qat_compression_ctx),
630 		.cra_module = THIS_MODULE,
631 	},
632 	.init = qat_comp_alg_zstd_init_tfm,
633 	.exit = qat_comp_alg_zstd_exit_tfm,
634 	.compress = qat_comp_alg_compress,
635 	.decompress = qat_comp_alg_zstd_decompress,
636 };
637 
638 static int qat_comp_algs_register_deflate(void)
639 {
640 	int ret = 0;
641 
642 	mutex_lock(&algs_lock);
643 	if (++active_devs_deflate == 1) {
644 		ret = crypto_register_acomps(qat_acomp_deflate,
645 					     ARRAY_SIZE(qat_acomp_deflate));
646 		if (ret)
647 			active_devs_deflate--;
648 	}
649 	mutex_unlock(&algs_lock);
650 
651 	return ret;
652 }
653 
654 static void qat_comp_algs_unregister_deflate(void)
655 {
656 	mutex_lock(&algs_lock);
657 	if (--active_devs_deflate == 0)
658 		crypto_unregister_acomps(qat_acomp_deflate, ARRAY_SIZE(qat_acomp_deflate));
659 	mutex_unlock(&algs_lock);
660 }
661 
662 static int qat_comp_algs_register_lz4s(void)
663 {
664 	int ret = 0;
665 
666 	mutex_lock(&algs_lock);
667 	if (++active_devs_lz4s == 1) {
668 		ret = crypto_acomp_alloc_streams(&qat_zstd_streams);
669 		if (ret) {
670 			active_devs_lz4s--;
671 			goto unlock;
672 		}
673 
674 		ret = crypto_register_acomp(&qat_acomp_zstd_lz4s);
675 		if (ret) {
676 			crypto_acomp_free_streams(&qat_zstd_streams);
677 			active_devs_lz4s--;
678 		}
679 	}
680 unlock:
681 	mutex_unlock(&algs_lock);
682 
683 	return ret;
684 }
685 
686 static void qat_comp_algs_unregister_lz4s(void)
687 {
688 	mutex_lock(&algs_lock);
689 	if (--active_devs_lz4s == 0) {
690 		crypto_unregister_acomp(&qat_acomp_zstd_lz4s);
691 		crypto_acomp_free_streams(&qat_zstd_streams);
692 	}
693 	mutex_unlock(&algs_lock);
694 }
695 
696 static int qat_comp_algs_register_zstd(void)
697 {
698 	int ret = 0;
699 
700 	mutex_lock(&algs_lock);
701 	if (++active_devs_zstd == 1) {
702 		ret = crypto_register_acomp(&qat_acomp_zstd_native);
703 		if (ret)
704 			active_devs_zstd--;
705 	}
706 	mutex_unlock(&algs_lock);
707 
708 	return ret;
709 }
710 
711 static void qat_comp_algs_unregister_zstd(void)
712 {
713 	mutex_lock(&algs_lock);
714 	if (--active_devs_zstd == 0)
715 		crypto_unregister_acomp(&qat_acomp_zstd_native);
716 	mutex_unlock(&algs_lock);
717 }
718 
719 int qat_comp_algs_register(u32 caps)
720 {
721 	int ret;
722 
723 	ret = qat_comp_algs_register_deflate();
724 	if (ret)
725 		return ret;
726 
727 	if (caps & ADF_ACCEL_CAPABILITIES_EXT_ZSTD_LZ4S) {
728 		ret = qat_comp_algs_register_lz4s();
729 		if (ret)
730 			goto err_unregister_deflate;
731 	}
732 
733 	if (caps & ADF_ACCEL_CAPABILITIES_EXT_ZSTD) {
734 		ret = qat_comp_algs_register_zstd();
735 		if (ret)
736 			goto err_unregister_lz4s;
737 	}
738 
739 	return ret;
740 
741 err_unregister_lz4s:
742 	if (caps & ADF_ACCEL_CAPABILITIES_EXT_ZSTD_LZ4S)
743 		qat_comp_algs_unregister_lz4s();
744 err_unregister_deflate:
745 	qat_comp_algs_unregister_deflate();
746 
747 	return ret;
748 }
749 
750 void qat_comp_algs_unregister(u32 caps)
751 {
752 	qat_comp_algs_unregister_deflate();
753 
754 	if (caps & ADF_ACCEL_CAPABILITIES_EXT_ZSTD_LZ4S)
755 		qat_comp_algs_unregister_lz4s();
756 
757 	if (caps & ADF_ACCEL_CAPABILITIES_EXT_ZSTD)
758 		qat_comp_algs_unregister_zstd();
759 }
760