xref: /linux/drivers/crypto/intel/qat/qat_common/qat_asym_algs.c (revision 6e7fd890f1d6ac83805409e9c346240de2705584)
1 // SPDX-License-Identifier: (BSD-3-Clause OR GPL-2.0-only)
2 /* Copyright(c) 2014 - 2020 Intel Corporation */
3 #include <linux/module.h>
4 #include <crypto/internal/rsa.h>
5 #include <crypto/internal/akcipher.h>
6 #include <crypto/akcipher.h>
7 #include <crypto/kpp.h>
8 #include <crypto/internal/kpp.h>
9 #include <crypto/dh.h>
10 #include <linux/dma-mapping.h>
11 #include <linux/fips.h>
12 #include <crypto/scatterwalk.h>
13 #include "icp_qat_fw_pke.h"
14 #include "adf_accel_devices.h"
15 #include "qat_algs_send.h"
16 #include "adf_transport.h"
17 #include "adf_common_drv.h"
18 #include "qat_crypto.h"
19 
/*
 * NOTE(review): nothing in the visible part of this file touches these two.
 * Going by their names they presumably serialise algorithm (un)registration
 * and count the devices that registered -- confirm against the users further
 * down the file.
 */
20 static DEFINE_MUTEX(algs_lock);
21 static unsigned int active_devs;
22 
/*
 * Input parameter table of an RSA request.
 *
 * The enc, dec and dec_crt views name the slots used by the respective
 * operation; in_tab gives indexed access to the same storage. The request
 * paths fill the slots with DMA addresses, write 0 into the slot that follows
 * the last parameter, DMA-map the whole table and place its address in the
 * firmware message (see qat_rsa_enc() and qat_rsa_dec()).
 */
23 struct qat_rsa_input_params {
24 	union {
		/* Encrypt: message, public exponent, modulus */
25 		struct {
26 			dma_addr_t m;
27 			dma_addr_t e;
28 			dma_addr_t n;
29 		} enc;
		/* Decrypt without CRT: ciphertext, private exponent, modulus */
30 		struct {
31 			dma_addr_t c;
32 			dma_addr_t d;
33 			dma_addr_t n;
34 		} dec;
		/* Decrypt with CRT: ciphertext plus the five CRT components */
35 		struct {
36 			dma_addr_t c;
37 			dma_addr_t p;
38 			dma_addr_t q;
39 			dma_addr_t dp;
40 			dma_addr_t dq;
41 			dma_addr_t qinv;
42 		} dec_crt;
		/* Indexed view of the slots above */
43 		u64 in_tab[8];
44 	};
45 } __packed __aligned(64);
46 
/*
 * Output parameter table of an RSA request.
 *
 * Slot 0 holds the DMA address of the result buffer (c for encrypt, m for
 * decrypt); the request paths write 0 into out_tab[1] and DMA-map the table
 * so that its address can be placed in the firmware message.
 */
47 struct qat_rsa_output_params {
48 	union {
49 		struct {
50 			dma_addr_t c;
51 		} enc;
52 		struct {
53 			dma_addr_t m;
54 		} dec;
		/* Indexed view of the slots above */
55 		u64 out_tab[8];
56 	};
57 } __packed __aligned(64);
58 
/*
 * Per-transform RSA state.
 *
 * Each key component is kept in a buffer obtained from dma_alloc_coherent():
 * the char pointer is the CPU address, the dma_* field with the same suffix
 * is the matching device address that gets written into the parameter tables.
 */
59 struct qat_rsa_ctx {
	/* CPU addresses of the key components (NULL when not set) */
60 	char *n;
61 	char *e;
62 	char *d;
63 	char *p;
64 	char *q;
65 	char *dp;
66 	char *dq;
67 	char *qinv;
	/* Device addresses of the buffers above */
68 	dma_addr_t dma_n;
69 	dma_addr_t dma_e;
70 	dma_addr_t dma_d;
71 	dma_addr_t dma_p;
72 	dma_addr_t dma_q;
73 	dma_addr_t dma_dp;
74 	dma_addr_t dma_dq;
75 	dma_addr_t dma_qinv;
	/* Length of the modulus in bytes, leading zero bytes stripped */
76 	unsigned int key_sz;
	/* True once all CRT components are loaded; selects the CRT decrypt path */
77 	bool crt_mode;
	/* Crypto instance whose PKE ring and device are used for requests */
78 	struct qat_crypto_instance *inst;
79 } __packed __aligned(64);
80 
/*
 * Input parameter table of a DH request.
 *
 * qat_dh_compute_value() uses the three-slot "in" view (base, private key,
 * modulus) unless no source is supplied and the generator is 2, in which case
 * the two-slot "in_g2" view is used. in_tab gives indexed access to the same
 * storage; the slot after the last parameter is written with 0.
 */
81 struct qat_dh_input_params {
82 	union {
83 		struct {
84 			dma_addr_t b;
85 			dma_addr_t xa;
86 			dma_addr_t p;
87 		} in;
88 		struct {
89 			dma_addr_t xa;
90 			dma_addr_t p;
91 		} in_g2;
		/* Indexed view of the slots above */
92 		u64 in_tab[8];
93 	};
94 } __packed __aligned(64);
95 
/*
 * Output parameter table of a DH request: r is the DMA address of the result
 * buffer, out_tab the indexed view of the same storage (out_tab[1] is written
 * with 0 before the table is mapped).
 */
96 struct qat_dh_output_params {
97 	union {
98 		dma_addr_t r;
99 		u64 out_tab[8];
100 	};
101 } __packed __aligned(64);
102 
/*
 * Per-transform DH state.
 *
 * g, xa and p are buffers of p_size bytes obtained from dma_alloc_coherent();
 * dma_g, dma_xa and dma_p are their device addresses.
 */
103 struct qat_dh_ctx {
	/* Generator; left NULL when the generator is 2 (see g2) */
104 	char *g;
	/* Private key, right-aligned in a p_size buffer */
105 	char *xa;
	/* Prime modulus */
106 	char *p;
107 	dma_addr_t dma_g;
108 	dma_addr_t dma_xa;
109 	dma_addr_t dma_p;
	/* Size of the modulus in bytes; 0 when no parameters are set */
110 	unsigned int p_size;
	/* True when the generator equals 2 and therefore is not stored */
111 	bool g2;
	/* Crypto instance whose PKE ring and device are used for requests */
112 	struct qat_crypto_instance *inst;
	/* Fallback KPP transform, allocated in qat_dh_init_tfm() */
113 	struct crypto_kpp *ftfm;
	/* True when the current secret was handed to ftfm instead of the HW */
114 	bool fallback;
115 } __packed __aligned(64);
116 
/*
 * Per-request state, placed (64-byte aligned) in the request context of the
 * akcipher or kpp request.
 */
117 struct qat_asym_request {
	/* Input parameter table handed to the firmware */
118 	union {
119 		struct qat_rsa_input_params rsa;
120 		struct qat_dh_input_params dh;
121 	} in;
	/* Output parameter table handed to the firmware */
122 	union {
123 		struct qat_rsa_output_params rsa;
124 		struct qat_dh_output_params dh;
125 	} out;
	/* DMA addresses of the "in" and "out" tables above */
126 	dma_addr_t phy_in;
127 	dma_addr_t phy_out;
	/*
	 * Bounce buffers used when the caller's src/dst cannot be mapped
	 * directly; NULL when the caller's buffer is used as is.
	 */
128 	char *src_align;
129 	char *dst_align;
	/* Firmware request message */
130 	struct icp_qat_fw_pke_request req;
	/* Transform context the request belongs to */
131 	union {
132 		struct qat_rsa_ctx *rsa;
133 		struct qat_dh_ctx *dh;
134 	} ctx;
	/* Originating crypto API request */
135 	union {
136 		struct akcipher_request *rsa;
137 		struct kpp_request *dh;
138 	} areq;
	/* NOTE(review): not accessed in the visible part of this file */
139 	int err;
	/* Completion handler, invoked from qat_alg_asym_callback() */
140 	void (*cb)(struct icp_qat_fw_pke_resp *resp);
	/* Generic send descriptor filled by qat_alg_send_asym_message() */
141 	struct qat_alg_req alg_req;
142 } __aligned(64);
143 
144 static int qat_alg_send_asym_message(struct qat_asym_request *qat_req,
145 				     struct qat_crypto_instance *inst,
146 				     struct crypto_async_request *base)
147 {
148 	struct qat_alg_req *alg_req = &qat_req->alg_req;
149 
150 	alg_req->fw_req = (u32 *)&qat_req->req;
151 	alg_req->tx_ring = inst->pke_tx;
152 	alg_req->base = base;
153 	alg_req->backlog = &inst->backlog;
154 
155 	return qat_alg_send_message(alg_req);
156 }
157 
158 static void qat_dh_cb(struct icp_qat_fw_pke_resp *resp)
159 {
160 	struct qat_asym_request *req = (void *)(__force long)resp->opaque;
161 	struct kpp_request *areq = req->areq.dh;
162 	struct device *dev = &GET_DEV(req->ctx.dh->inst->accel_dev);
163 	int err = ICP_QAT_FW_PKE_RESP_PKE_STAT_GET(
164 				resp->pke_resp_hdr.comn_resp_flags);
165 
166 	err = (err == ICP_QAT_FW_COMN_STATUS_FLAG_OK) ? 0 : -EINVAL;
167 
168 	if (areq->src) {
169 		dma_unmap_single(dev, req->in.dh.in.b, req->ctx.dh->p_size,
170 				 DMA_TO_DEVICE);
171 		kfree_sensitive(req->src_align);
172 	}
173 
174 	areq->dst_len = req->ctx.dh->p_size;
175 	dma_unmap_single(dev, req->out.dh.r, req->ctx.dh->p_size,
176 			 DMA_FROM_DEVICE);
177 	if (req->dst_align) {
178 		scatterwalk_map_and_copy(req->dst_align, areq->dst, 0,
179 					 areq->dst_len, 1);
180 		kfree_sensitive(req->dst_align);
181 	}
182 
183 	dma_unmap_single(dev, req->phy_in, sizeof(struct qat_dh_input_params),
184 			 DMA_TO_DEVICE);
185 	dma_unmap_single(dev, req->phy_out,
186 			 sizeof(struct qat_dh_output_params),
187 			 DMA_TO_DEVICE);
188 
189 	kpp_request_complete(areq, err);
190 }
191 
#define PKE_DH_1536 0x390c1a49
#define PKE_DH_G2_1536 0x2e0b1a3e
#define PKE_DH_2048 0x4d0c1a60
#define PKE_DH_G2_2048 0x3e0b1a55
#define PKE_DH_3072 0x510c1a77
#define PKE_DH_G2_3072 0x3a0b1a6c
#define PKE_DH_4096 0x690c1a8e
#define PKE_DH_G2_4096 0x4a0b1a83

/*
 * Look up the firmware function id for a DH operation.
 *
 * @len: size of the modulus in bytes
 * @g2:  true to select the variant that takes no explicit base
 *
 * Return: the function id, or 0 if the modulus size is not supported.
 */
static unsigned long qat_dh_fn_id(unsigned int len, bool g2)
{
	static const struct {
		unsigned int bits;
		unsigned long fn_id;
		unsigned long fn_id_g2;
	} dh_fn_ids[] = {
		{ 1536, PKE_DH_1536, PKE_DH_G2_1536 },
		{ 2048, PKE_DH_2048, PKE_DH_G2_2048 },
		{ 3072, PKE_DH_3072, PKE_DH_G2_3072 },
		{ 4096, PKE_DH_4096, PKE_DH_G2_4096 },
		{ 0, 0, 0 },	/* end of table */
	};
	const unsigned int modulus_bits = len << 3;
	unsigned int i;

	for (i = 0; dh_fn_ids[i].bits; i++) {
		if (dh_fn_ids[i].bits != modulus_bits)
			continue;

		if (g2)
			return dh_fn_ids[i].fn_id_g2;
		return dh_fn_ids[i].fn_id;
	}

	return 0;
}
218 
/*
 * qat_dh_compute_value() - build and submit one DH request to the device.
 * @req: kpp request. When req->src is set it supplies the base; otherwise the
 *       generator stored in the transform context is used as base.
 *
 * Validates the sizes, fills the firmware message and the input/output
 * parameter tables that live in the request context, DMA-maps the data
 * buffers and the two tables, and submits the message through
 * qat_alg_send_asym_message(). The mappings and bounce buffers are released
 * by qat_dh_cb() on completion.
 *
 * Return: -EINVAL if no secret is set, if src is longer than the modulus or
 * if the modulus size has no firmware function; -EOVERFLOW (with
 * req->dst_len updated to the required size) if dst is too small; -ENOMEM if
 * a bounce buffer allocation or a DMA mapping fails; otherwise the value
 * returned by qat_alg_send_asym_message(). When that value is -ENOSPC,
 * everything mapped or allocated here is released before returning.
 */
219 static int qat_dh_compute_value(struct kpp_request *req)
220 {
221 	struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
222 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
223 	struct qat_crypto_instance *inst = ctx->inst;
224 	struct device *dev = &GET_DEV(inst->accel_dev);
	/* The request context is over-allocated so it can be aligned here */
225 	struct qat_asym_request *qat_req =
226 			PTR_ALIGN(kpp_request_ctx(req), 64);
227 	struct icp_qat_fw_pke_request *msg = &qat_req->req;
228 	gfp_t flags = qat_algs_alloc_flags(&req->base);
229 	int n_input_params = 0;
230 	u8 *vaddr;
231 	int ret;
232 
233 	if (unlikely(!ctx->xa))
234 		return -EINVAL;
235 
	/* Tell the caller how much room is needed when dst is too small */
236 	if (req->dst_len < ctx->p_size) {
237 		req->dst_len = ctx->p_size;
238 		return -EOVERFLOW;
239 	}
240 
241 	if (req->src_len > ctx->p_size)
242 		return -EINVAL;
243 
244 	memset(msg, '\0', sizeof(*msg));
245 	ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
246 					  ICP_QAT_FW_COMN_REQ_FLAG_SET);
247 
	/* The g2 variant is only selected when the base is the generator 2 */
248 	msg->pke_hdr.cd_pars.func_id = qat_dh_fn_id(ctx->p_size,
249 						    !req->src && ctx->g2);
250 	if (unlikely(!msg->pke_hdr.cd_pars.func_id))
251 		return -EINVAL;
252 
253 	qat_req->cb = qat_dh_cb;
254 	qat_req->ctx.dh = ctx;
255 	qat_req->areq.dh = req;
256 	msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
257 	msg->pke_hdr.comn_req_flags =
258 		ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
259 					    QAT_COMN_CD_FLD_TYPE_64BIT_ADR);
260 
261 	/*
262 	 * If no source is provided use g as base
263 	 */
264 	if (req->src) {
		/* in.b is filled in below, once src has been mapped */
265 		qat_req->in.dh.in.xa = ctx->dma_xa;
266 		qat_req->in.dh.in.p = ctx->dma_p;
267 		n_input_params = 3;
268 	} else {
269 		if (ctx->g2) {
270 			qat_req->in.dh.in_g2.xa = ctx->dma_xa;
271 			qat_req->in.dh.in_g2.p = ctx->dma_p;
272 			n_input_params = 2;
273 		} else {
274 			qat_req->in.dh.in.b = ctx->dma_g;
275 			qat_req->in.dh.in.xa = ctx->dma_xa;
276 			qat_req->in.dh.in.p = ctx->dma_p;
277 			n_input_params = 3;
278 		}
279 	}
280 
	/* Every failure from here until the send is an allocation/mapping one */
281 	ret = -ENOMEM;
282 	if (req->src) {
283 		/*
284 		 * src can be of any size in valid range, but HW expects it to
285 		 * be the same as modulo p so in case it is different we need
286 		 * to allocate a new buf and copy src data.
287 		 * In other case we just need to map the user provided buffer.
288 		 * Also need to make sure that it is in contiguous buffer.
289 		 */
290 		if (sg_is_last(req->src) && req->src_len == ctx->p_size) {
291 			qat_req->src_align = NULL;
292 			vaddr = sg_virt(req->src);
293 		} else {
			/* Right-align src in a zeroed buffer of p_size bytes */
294 			int shift = ctx->p_size - req->src_len;
295 
296 			qat_req->src_align = kzalloc(ctx->p_size, flags);
297 			if (unlikely(!qat_req->src_align))
298 				return ret;
299 
300 			scatterwalk_map_and_copy(qat_req->src_align + shift,
301 						 req->src, 0, req->src_len, 0);
302 
303 			vaddr = qat_req->src_align;
304 		}
305 
306 		qat_req->in.dh.in.b = dma_map_single(dev, vaddr, ctx->p_size,
307 						     DMA_TO_DEVICE);
308 		if (unlikely(dma_mapping_error(dev, qat_req->in.dh.in.b)))
309 			goto unmap_src;
310 	}
311 	/*
312 	 * dst can be of any size in valid range, but HW expects it to be the
313 	 * same as modulo p so in case it is different we need to allocate a
314 	 * new buf; the result is copied from it into dst on completion.
315 	 * In other case we just need to map the user provided buffer.
316 	 * Also need to make sure that it is in contiguous buffer.
317 	 */
318 	if (sg_is_last(req->dst) && req->dst_len == ctx->p_size) {
319 		qat_req->dst_align = NULL;
320 		vaddr = sg_virt(req->dst);
321 	} else {
322 		qat_req->dst_align = kzalloc(ctx->p_size, flags);
323 		if (unlikely(!qat_req->dst_align))
324 			goto unmap_src;
325 
326 		vaddr = qat_req->dst_align;
327 	}
328 	qat_req->out.dh.r = dma_map_single(dev, vaddr, ctx->p_size,
329 					   DMA_FROM_DEVICE);
330 	if (unlikely(dma_mapping_error(dev, qat_req->out.dh.r)))
331 		goto unmap_dst;
332 
	/* Write 0 into the slot that follows the last entry of each table */
333 	qat_req->in.dh.in_tab[n_input_params] = 0;
334 	qat_req->out.dh.out_tab[1] = 0;
335 	/* Mapping in.in.b or in.in_g2.xa is the same */
336 	qat_req->phy_in = dma_map_single(dev, &qat_req->in.dh,
337 					 sizeof(struct qat_dh_input_params),
338 					 DMA_TO_DEVICE);
339 	if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
340 		goto unmap_dst;
341 
342 	qat_req->phy_out = dma_map_single(dev, &qat_req->out.dh,
343 					  sizeof(struct qat_dh_output_params),
344 					  DMA_TO_DEVICE);
345 	if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
346 		goto unmap_in_params;
347 
348 	msg->pke_mid.src_data_addr = qat_req->phy_in;
349 	msg->pke_mid.dest_data_addr = qat_req->phy_out;
	/* The response carries this back so the callback can find qat_req */
350 	msg->pke_mid.opaque = (u64)(__force long)qat_req;
351 	msg->input_param_count = n_input_params;
352 	msg->output_param_count = 1;
353 
	/*
	 * Only -ENOSPC is unwound here. NOTE(review): any other return value
	 * is passed through untouched, presumably because the request is then
	 * owned by the ring/backlog and cleaned up in qat_dh_cb() -- confirm
	 * against qat_alg_send_message().
	 */
354 	ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
355 	if (ret == -ENOSPC)
356 		goto unmap_all;
357 
358 	return ret;
359 
	/* Unwind in reverse order of setup; each label falls through */
360 unmap_all:
361 	if (!dma_mapping_error(dev, qat_req->phy_out))
362 		dma_unmap_single(dev, qat_req->phy_out,
363 				 sizeof(struct qat_dh_output_params),
364 				 DMA_TO_DEVICE);
365 unmap_in_params:
366 	if (!dma_mapping_error(dev, qat_req->phy_in))
367 		dma_unmap_single(dev, qat_req->phy_in,
368 				 sizeof(struct qat_dh_input_params),
369 				 DMA_TO_DEVICE);
370 unmap_dst:
371 	if (!dma_mapping_error(dev, qat_req->out.dh.r))
372 		dma_unmap_single(dev, qat_req->out.dh.r, ctx->p_size,
373 				 DMA_FROM_DEVICE);
374 	kfree_sensitive(qat_req->dst_align);
375 unmap_src:
	/* Without a source, in.b is the context's generator and stays mapped */
376 	if (req->src) {
377 		if (!dma_mapping_error(dev, qat_req->in.dh.in.b))
378 			dma_unmap_single(dev, qat_req->in.dh.in.b,
379 					 ctx->p_size,
380 					 DMA_TO_DEVICE);
381 		kfree_sensitive(qat_req->src_align);
382 	}
383 	return ret;
384 }
385 
386 static int qat_dh_generate_public_key(struct kpp_request *req)
387 {
388 	struct kpp_request *nreq = kpp_request_ctx(req);
389 	struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
390 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
391 
392 	if (ctx->fallback) {
393 		memcpy(nreq, req, sizeof(*req));
394 		kpp_request_set_tfm(nreq, ctx->ftfm);
395 		return crypto_kpp_generate_public_key(nreq);
396 	}
397 
398 	return qat_dh_compute_value(req);
399 }
400 
401 static int qat_dh_compute_shared_secret(struct kpp_request *req)
402 {
403 	struct kpp_request *nreq = kpp_request_ctx(req);
404 	struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
405 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
406 
407 	if (ctx->fallback) {
408 		memcpy(nreq, req, sizeof(*req));
409 		kpp_request_set_tfm(nreq, ctx->ftfm);
410 		return crypto_kpp_compute_shared_secret(nreq);
411 	}
412 
413 	return qat_dh_compute_value(req);
414 }
415 
/*
 * Check whether a modulus of @p_len bits is one of the accepted sizes.
 *
 * Return: 0 for 1536, 2048, 3072 or 4096 bits, -EINVAL otherwise.
 */
static int qat_dh_check_params_length(unsigned int p_len)
{
	if (p_len == 1536 || p_len == 2048 ||
	    p_len == 3072 || p_len == 4096)
		return 0;

	return -EINVAL;
}
427 
428 static int qat_dh_set_params(struct qat_dh_ctx *ctx, struct dh *params)
429 {
430 	struct qat_crypto_instance *inst = ctx->inst;
431 	struct device *dev = &GET_DEV(inst->accel_dev);
432 
433 	ctx->p_size = params->p_size;
434 	ctx->p = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_p, GFP_KERNEL);
435 	if (!ctx->p)
436 		return -ENOMEM;
437 	memcpy(ctx->p, params->p, ctx->p_size);
438 
439 	/* If g equals 2 don't copy it */
440 	if (params->g_size == 1 && *(char *)params->g == 0x02) {
441 		ctx->g2 = true;
442 		return 0;
443 	}
444 
445 	ctx->g = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_g, GFP_KERNEL);
446 	if (!ctx->g)
447 		return -ENOMEM;
448 	memcpy(ctx->g + (ctx->p_size - params->g_size), params->g,
449 	       params->g_size);
450 
451 	return 0;
452 }
453 
454 static void qat_dh_clear_ctx(struct device *dev, struct qat_dh_ctx *ctx)
455 {
456 	if (ctx->g) {
457 		memset(ctx->g, 0, ctx->p_size);
458 		dma_free_coherent(dev, ctx->p_size, ctx->g, ctx->dma_g);
459 		ctx->g = NULL;
460 	}
461 	if (ctx->xa) {
462 		memset(ctx->xa, 0, ctx->p_size);
463 		dma_free_coherent(dev, ctx->p_size, ctx->xa, ctx->dma_xa);
464 		ctx->xa = NULL;
465 	}
466 	if (ctx->p) {
467 		memset(ctx->p, 0, ctx->p_size);
468 		dma_free_coherent(dev, ctx->p_size, ctx->p, ctx->dma_p);
469 		ctx->p = NULL;
470 	}
471 	ctx->p_size = 0;
472 	ctx->g2 = false;
473 }
474 
475 static int qat_dh_set_secret(struct crypto_kpp *tfm, const void *buf,
476 			     unsigned int len)
477 {
478 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
479 	struct device *dev = &GET_DEV(ctx->inst->accel_dev);
480 	struct dh params;
481 	int ret;
482 
483 	if (crypto_dh_decode_key(buf, len, &params) < 0)
484 		return -EINVAL;
485 
486 	if (qat_dh_check_params_length(params.p_size << 3)) {
487 		ctx->fallback = true;
488 		return crypto_kpp_set_secret(ctx->ftfm, buf, len);
489 	}
490 
491 	ctx->fallback = false;
492 
493 	/* Free old secret if any */
494 	qat_dh_clear_ctx(dev, ctx);
495 
496 	ret = qat_dh_set_params(ctx, &params);
497 	if (ret < 0)
498 		goto err_clear_ctx;
499 
500 	ctx->xa = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_xa,
501 				     GFP_KERNEL);
502 	if (!ctx->xa) {
503 		ret = -ENOMEM;
504 		goto err_clear_ctx;
505 	}
506 	memcpy(ctx->xa + (ctx->p_size - params.key_size), params.key,
507 	       params.key_size);
508 
509 	return 0;
510 
511 err_clear_ctx:
512 	qat_dh_clear_ctx(dev, ctx);
513 	return ret;
514 }
515 
516 static unsigned int qat_dh_max_size(struct crypto_kpp *tfm)
517 {
518 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
519 
520 	if (ctx->fallback)
521 		return crypto_kpp_maxsize(ctx->ftfm);
522 
523 	return ctx->p_size;
524 }
525 
526 static int qat_dh_init_tfm(struct crypto_kpp *tfm)
527 {
528 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
529 	struct qat_crypto_instance *inst =
530 			qat_crypto_get_instance_node(numa_node_id());
531 	const char *alg = kpp_alg_name(tfm);
532 	unsigned int reqsize;
533 
534 	if (!inst)
535 		return -EINVAL;
536 
537 	ctx->ftfm = crypto_alloc_kpp(alg, 0, CRYPTO_ALG_NEED_FALLBACK);
538 	if (IS_ERR(ctx->ftfm))
539 		return PTR_ERR(ctx->ftfm);
540 
541 	crypto_kpp_set_flags(ctx->ftfm, crypto_kpp_get_flags(tfm));
542 
543 	reqsize = max(sizeof(struct qat_asym_request) + 64,
544 		      sizeof(struct kpp_request) + crypto_kpp_reqsize(ctx->ftfm));
545 
546 	kpp_set_reqsize(tfm, reqsize);
547 
548 	ctx->p_size = 0;
549 	ctx->g2 = false;
550 	ctx->inst = inst;
551 	return 0;
552 }
553 
554 static void qat_dh_exit_tfm(struct crypto_kpp *tfm)
555 {
556 	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
557 	struct device *dev = &GET_DEV(ctx->inst->accel_dev);
558 
559 	if (ctx->ftfm)
560 		crypto_free_kpp(ctx->ftfm);
561 
562 	qat_dh_clear_ctx(dev, ctx);
563 	qat_crypto_put_instance(ctx->inst);
564 }
565 
566 static void qat_rsa_cb(struct icp_qat_fw_pke_resp *resp)
567 {
568 	struct qat_asym_request *req = (void *)(__force long)resp->opaque;
569 	struct akcipher_request *areq = req->areq.rsa;
570 	struct device *dev = &GET_DEV(req->ctx.rsa->inst->accel_dev);
571 	int err = ICP_QAT_FW_PKE_RESP_PKE_STAT_GET(
572 				resp->pke_resp_hdr.comn_resp_flags);
573 
574 	err = (err == ICP_QAT_FW_COMN_STATUS_FLAG_OK) ? 0 : -EINVAL;
575 
576 	dma_unmap_single(dev, req->in.rsa.enc.m, req->ctx.rsa->key_sz,
577 			 DMA_TO_DEVICE);
578 
579 	kfree_sensitive(req->src_align);
580 
581 	areq->dst_len = req->ctx.rsa->key_sz;
582 	dma_unmap_single(dev, req->out.rsa.enc.c, req->ctx.rsa->key_sz,
583 			 DMA_FROM_DEVICE);
584 	if (req->dst_align) {
585 		scatterwalk_map_and_copy(req->dst_align, areq->dst, 0,
586 					 areq->dst_len, 1);
587 
588 		kfree_sensitive(req->dst_align);
589 	}
590 
591 	dma_unmap_single(dev, req->phy_in, sizeof(struct qat_rsa_input_params),
592 			 DMA_TO_DEVICE);
593 	dma_unmap_single(dev, req->phy_out,
594 			 sizeof(struct qat_rsa_output_params),
595 			 DMA_TO_DEVICE);
596 
597 	akcipher_request_complete(areq, err);
598 }
599 
600 void qat_alg_asym_callback(void *_resp)
601 {
602 	struct icp_qat_fw_pke_resp *resp = _resp;
603 	struct qat_asym_request *areq = (void *)(__force long)resp->opaque;
604 	struct qat_instance_backlog *backlog = areq->alg_req.backlog;
605 
606 	areq->cb(resp);
607 
608 	qat_alg_send_backlog(backlog);
609 }
610 
#define PKE_RSA_EP_512 0x1c161b21
#define PKE_RSA_EP_1024 0x35111bf7
#define PKE_RSA_EP_1536 0x4d111cdc
#define PKE_RSA_EP_2048 0x6e111dba
#define PKE_RSA_EP_3072 0x7d111ea3
#define PKE_RSA_EP_4096 0xa5101f7e

/*
 * Look up the firmware function id for an RSA encrypt operation.
 *
 * @len: size of the modulus in bytes
 *
 * Return: the function id, or 0 if the key size is not supported.
 */
static unsigned long qat_rsa_enc_fn_id(unsigned int len)
{
	unsigned long fn_id = 0;

	switch (len << 3) {
	case 512:
		fn_id = PKE_RSA_EP_512;
		break;
	case 1024:
		fn_id = PKE_RSA_EP_1024;
		break;
	case 1536:
		fn_id = PKE_RSA_EP_1536;
		break;
	case 2048:
		fn_id = PKE_RSA_EP_2048;
		break;
	case 3072:
		fn_id = PKE_RSA_EP_3072;
		break;
	case 4096:
		fn_id = PKE_RSA_EP_4096;
		break;
	default:
		break;
	}

	return fn_id;
}
639 
#define PKE_RSA_DP1_512 0x1c161b3c
#define PKE_RSA_DP1_1024 0x35111c12
#define PKE_RSA_DP1_1536 0x4d111cf7
#define PKE_RSA_DP1_2048 0x6e111dda
#define PKE_RSA_DP1_3072 0x7d111ebe
#define PKE_RSA_DP1_4096 0xa5101f98

/*
 * Look up the firmware function id for an RSA decrypt operation that uses the
 * private exponent directly (no CRT).
 *
 * @len: size of the modulus in bytes
 *
 * Return: the function id, or 0 if the key size is not supported.
 */
static unsigned long qat_rsa_dec_fn_id(unsigned int len)
{
	static const struct {
		unsigned int bits;
		unsigned long fn_id;
	} dp1_fn_ids[] = {
		{  512, PKE_RSA_DP1_512 },
		{ 1024, PKE_RSA_DP1_1024 },
		{ 1536, PKE_RSA_DP1_1536 },
		{ 2048, PKE_RSA_DP1_2048 },
		{ 3072, PKE_RSA_DP1_3072 },
		{ 4096, PKE_RSA_DP1_4096 },
		{ 0, 0 },	/* end of table */
	};
	const unsigned int key_bits = len << 3;
	unsigned int i;

	for (i = 0; dp1_fn_ids[i].bits; i++) {
		if (dp1_fn_ids[i].bits == key_bits)
			return dp1_fn_ids[i].fn_id;
	}

	return 0;
}
668 
#define PKE_RSA_DP2_512 0x1c131b57
#define PKE_RSA_DP2_1024 0x26131c2d
#define PKE_RSA_DP2_1536 0x45111d12
#define PKE_RSA_DP2_2048 0x59121dfa
#define PKE_RSA_DP2_3072 0x81121ed9
#define PKE_RSA_DP2_4096 0xb1111fb2

/*
 * Look up the firmware function id for an RSA decrypt operation that uses the
 * CRT key components.
 *
 * @len: size of the modulus in bytes
 *
 * Return: the function id, or 0 if the key size is not supported.
 */
static unsigned long qat_rsa_dec_fn_id_crt(unsigned int len)
{
	const unsigned int key_bits = len << 3;

	if (key_bits == 512)
		return PKE_RSA_DP2_512;
	if (key_bits == 1024)
		return PKE_RSA_DP2_1024;
	if (key_bits == 1536)
		return PKE_RSA_DP2_1536;
	if (key_bits == 2048)
		return PKE_RSA_DP2_2048;
	if (key_bits == 3072)
		return PKE_RSA_DP2_3072;
	if (key_bits == 4096)
		return PKE_RSA_DP2_4096;

	return 0;
}
697 
/*
 * qat_rsa_enc() - build and submit one RSA encrypt request to the device.
 * @req: akcipher request; src is the message, dst receives the result.
 *
 * Validates the sizes, fills the firmware message and the input/output
 * parameter tables that live in the request context (message, public exponent
 * and modulus in; one result out), DMA-maps the data buffers and the two
 * tables, and submits the message through qat_alg_send_asym_message(). The
 * mappings and bounce buffers are released by qat_rsa_cb() on completion.
 *
 * Return: -EINVAL if n or e is not set, if src is longer than the key or if
 * the key size has no firmware function; -EOVERFLOW (with req->dst_len
 * updated to the required size) if dst is too small; -ENOMEM if a bounce
 * buffer allocation or a DMA mapping fails; otherwise the value returned by
 * qat_alg_send_asym_message(). When that value is -ENOSPC, everything mapped
 * or allocated here is released before returning.
 */
698 static int qat_rsa_enc(struct akcipher_request *req)
699 {
700 	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
701 	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
702 	struct qat_crypto_instance *inst = ctx->inst;
703 	struct device *dev = &GET_DEV(inst->accel_dev);
	/* The request context is aligned to 64 bytes before use */
704 	struct qat_asym_request *qat_req =
705 			PTR_ALIGN(akcipher_request_ctx(req), 64);
706 	struct icp_qat_fw_pke_request *msg = &qat_req->req;
707 	gfp_t flags = qat_algs_alloc_flags(&req->base);
708 	u8 *vaddr;
709 	int ret;
710 
711 	if (unlikely(!ctx->n || !ctx->e))
712 		return -EINVAL;
713 
	/* Tell the caller how much room is needed when dst is too small */
714 	if (req->dst_len < ctx->key_sz) {
715 		req->dst_len = ctx->key_sz;
716 		return -EOVERFLOW;
717 	}
718 
719 	if (req->src_len > ctx->key_sz)
720 		return -EINVAL;
721 
722 	memset(msg, '\0', sizeof(*msg));
723 	ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
724 					  ICP_QAT_FW_COMN_REQ_FLAG_SET);
725 	msg->pke_hdr.cd_pars.func_id = qat_rsa_enc_fn_id(ctx->key_sz);
726 	if (unlikely(!msg->pke_hdr.cd_pars.func_id))
727 		return -EINVAL;
728 
729 	qat_req->cb = qat_rsa_cb;
730 	qat_req->ctx.rsa = ctx;
731 	qat_req->areq.rsa = req;
732 	msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
733 	msg->pke_hdr.comn_req_flags =
734 		ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
735 					    QAT_COMN_CD_FLD_TYPE_64BIT_ADR);
736 
	/* enc.m is filled in below, once src has been mapped */
737 	qat_req->in.rsa.enc.e = ctx->dma_e;
738 	qat_req->in.rsa.enc.n = ctx->dma_n;
	/* Every failure from here until the send is an allocation/mapping one */
739 	ret = -ENOMEM;
740 
741 	/*
742 	 * src can be of any size in valid range, but HW expects it to be the
743 	 * same as modulo n so in case it is different we need to allocate a
744 	 * new buf and copy src data.
745 	 * In other case we just need to map the user provided buffer.
746 	 * Also need to make sure that it is in contiguous buffer.
747 	 */
748 	if (sg_is_last(req->src) && req->src_len == ctx->key_sz) {
749 		qat_req->src_align = NULL;
750 		vaddr = sg_virt(req->src);
751 	} else {
		/* Right-align src in a zeroed buffer of key_sz bytes */
752 		int shift = ctx->key_sz - req->src_len;
753 
754 		qat_req->src_align = kzalloc(ctx->key_sz, flags);
755 		if (unlikely(!qat_req->src_align))
756 			return ret;
757 
758 		scatterwalk_map_and_copy(qat_req->src_align + shift, req->src,
759 					 0, req->src_len, 0);
760 		vaddr = qat_req->src_align;
761 	}
762 
763 	qat_req->in.rsa.enc.m = dma_map_single(dev, vaddr, ctx->key_sz,
764 					       DMA_TO_DEVICE);
765 	if (unlikely(dma_mapping_error(dev, qat_req->in.rsa.enc.m)))
766 		goto unmap_src;
767 
	/*
	 * Same rule for dst: map it directly when it is a single entry of
	 * exactly key_sz bytes, otherwise receive the result in a bounce
	 * buffer that qat_rsa_cb() copies into dst.
	 */
768 	if (sg_is_last(req->dst) && req->dst_len == ctx->key_sz) {
769 		qat_req->dst_align = NULL;
770 		vaddr = sg_virt(req->dst);
771 	} else {
772 		qat_req->dst_align = kzalloc(ctx->key_sz, flags);
773 		if (unlikely(!qat_req->dst_align))
774 			goto unmap_src;
775 		vaddr = qat_req->dst_align;
776 	}
777 
778 	qat_req->out.rsa.enc.c = dma_map_single(dev, vaddr, ctx->key_sz,
779 						DMA_FROM_DEVICE);
780 	if (unlikely(dma_mapping_error(dev, qat_req->out.rsa.enc.c)))
781 		goto unmap_dst;
782 
	/* Write 0 into the slot that follows the last entry of each table */
783 	qat_req->in.rsa.in_tab[3] = 0;
784 	qat_req->out.rsa.out_tab[1] = 0;
785 	qat_req->phy_in = dma_map_single(dev, &qat_req->in.rsa,
786 					 sizeof(struct qat_rsa_input_params),
787 					 DMA_TO_DEVICE);
788 	if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
789 		goto unmap_dst;
790 
791 	qat_req->phy_out = dma_map_single(dev, &qat_req->out.rsa,
792 					  sizeof(struct qat_rsa_output_params),
793 					  DMA_TO_DEVICE);
794 	if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
795 		goto unmap_in_params;
796 
797 	msg->pke_mid.src_data_addr = qat_req->phy_in;
798 	msg->pke_mid.dest_data_addr = qat_req->phy_out;
	/* The response carries this back so the callback can find qat_req */
799 	msg->pke_mid.opaque = (u64)(__force long)qat_req;
800 	msg->input_param_count = 3;
801 	msg->output_param_count = 1;
802 
	/*
	 * Only -ENOSPC is unwound here. NOTE(review): any other return value
	 * is passed through untouched, presumably because the request is then
	 * owned by the ring/backlog and cleaned up in qat_rsa_cb() -- confirm
	 * against qat_alg_send_message().
	 */
803 	ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
804 	if (ret == -ENOSPC)
805 		goto unmap_all;
806 
807 	return ret;
808 
	/* Unwind in reverse order of setup; each label falls through */
809 unmap_all:
810 	if (!dma_mapping_error(dev, qat_req->phy_out))
811 		dma_unmap_single(dev, qat_req->phy_out,
812 				 sizeof(struct qat_rsa_output_params),
813 				 DMA_TO_DEVICE);
814 unmap_in_params:
815 	if (!dma_mapping_error(dev, qat_req->phy_in))
816 		dma_unmap_single(dev, qat_req->phy_in,
817 				 sizeof(struct qat_rsa_input_params),
818 				 DMA_TO_DEVICE);
819 unmap_dst:
820 	if (!dma_mapping_error(dev, qat_req->out.rsa.enc.c))
821 		dma_unmap_single(dev, qat_req->out.rsa.enc.c,
822 				 ctx->key_sz, DMA_FROM_DEVICE);
823 	kfree_sensitive(qat_req->dst_align);
824 unmap_src:
825 	if (!dma_mapping_error(dev, qat_req->in.rsa.enc.m))
826 		dma_unmap_single(dev, qat_req->in.rsa.enc.m, ctx->key_sz,
827 				 DMA_TO_DEVICE);
828 	kfree_sensitive(qat_req->src_align);
829 	return ret;
830 }
831 
/*
 * qat_rsa_dec() - build and submit one RSA decrypt request to the device.
 * @req: akcipher request; src is the ciphertext, dst receives the result.
 *
 * Validates the sizes, fills the firmware message and the input/output
 * parameter tables that live in the request context, DMA-maps the data
 * buffers and the two tables, and submits the message through
 * qat_alg_send_asym_message(). With ctx->crt_mode set the CRT function id and
 * the six-entry CRT table (c, p, q, dp, dq, qinv) are used, otherwise the
 * three-entry table (c, d, n). The mappings and bounce buffers are released
 * by qat_rsa_cb() on completion.
 *
 * Return: -EINVAL if n or d is not set, if src is longer than the key or if
 * the key size has no firmware function; -EOVERFLOW (with req->dst_len
 * updated to the required size) if dst is too small; -ENOMEM if a bounce
 * buffer allocation or a DMA mapping fails; otherwise the value returned by
 * qat_alg_send_asym_message(). When that value is -ENOSPC, everything mapped
 * or allocated here is released before returning.
 */
832 static int qat_rsa_dec(struct akcipher_request *req)
833 {
834 	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
835 	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
836 	struct qat_crypto_instance *inst = ctx->inst;
837 	struct device *dev = &GET_DEV(inst->accel_dev);
	/* The request context is aligned to 64 bytes before use */
838 	struct qat_asym_request *qat_req =
839 			PTR_ALIGN(akcipher_request_ctx(req), 64);
840 	struct icp_qat_fw_pke_request *msg = &qat_req->req;
841 	gfp_t flags = qat_algs_alloc_flags(&req->base);
842 	u8 *vaddr;
843 	int ret;
844 
845 	if (unlikely(!ctx->n || !ctx->d))
846 		return -EINVAL;
847 
	/* Tell the caller how much room is needed when dst is too small */
848 	if (req->dst_len < ctx->key_sz) {
849 		req->dst_len = ctx->key_sz;
850 		return -EOVERFLOW;
851 	}
852 
853 	if (req->src_len > ctx->key_sz)
854 		return -EINVAL;
855 
856 	memset(msg, '\0', sizeof(*msg));
857 	ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
858 					  ICP_QAT_FW_COMN_REQ_FLAG_SET);
859 	msg->pke_hdr.cd_pars.func_id = ctx->crt_mode ?
860 		qat_rsa_dec_fn_id_crt(ctx->key_sz) :
861 		qat_rsa_dec_fn_id(ctx->key_sz);
862 	if (unlikely(!msg->pke_hdr.cd_pars.func_id))
863 		return -EINVAL;
864 
865 	qat_req->cb = qat_rsa_cb;
866 	qat_req->ctx.rsa = ctx;
867 	qat_req->areq.rsa = req;
868 	msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
869 	msg->pke_hdr.comn_req_flags =
870 		ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
871 					    QAT_COMN_CD_FLD_TYPE_64BIT_ADR);
872 
	/* Key slots first; the ciphertext slot is filled once src is mapped */
873 	if (ctx->crt_mode) {
874 		qat_req->in.rsa.dec_crt.p = ctx->dma_p;
875 		qat_req->in.rsa.dec_crt.q = ctx->dma_q;
876 		qat_req->in.rsa.dec_crt.dp = ctx->dma_dp;
877 		qat_req->in.rsa.dec_crt.dq = ctx->dma_dq;
878 		qat_req->in.rsa.dec_crt.qinv = ctx->dma_qinv;
879 	} else {
880 		qat_req->in.rsa.dec.d = ctx->dma_d;
881 		qat_req->in.rsa.dec.n = ctx->dma_n;
882 	}
	/* Every failure from here until the send is an allocation/mapping one */
883 	ret = -ENOMEM;
884 
885 	/*
886 	 * src can be of any size in valid range, but HW expects it to be the
887 	 * same as modulo n so in case it is different we need to allocate a
888 	 * new buf and copy src data.
889 	 * In other case we just need to map the user provided buffer.
890 	 * Also need to make sure that it is in contiguous buffer.
891 	 */
892 	if (sg_is_last(req->src) && req->src_len == ctx->key_sz) {
893 		qat_req->src_align = NULL;
894 		vaddr = sg_virt(req->src);
895 	} else {
		/* Right-align src in a zeroed buffer of key_sz bytes */
896 		int shift = ctx->key_sz - req->src_len;
897 
898 		qat_req->src_align = kzalloc(ctx->key_sz, flags);
899 		if (unlikely(!qat_req->src_align))
900 			return ret;
901 
902 		scatterwalk_map_and_copy(qat_req->src_align + shift, req->src,
903 					 0, req->src_len, 0);
904 		vaddr = qat_req->src_align;
905 	}
906 
	/* dec.c and dec_crt.c are the same slot of the input table */
907 	qat_req->in.rsa.dec.c = dma_map_single(dev, vaddr, ctx->key_sz,
908 					       DMA_TO_DEVICE);
909 	if (unlikely(dma_mapping_error(dev, qat_req->in.rsa.dec.c)))
910 		goto unmap_src;
911 
	/*
	 * Same rule for dst: map it directly when it is a single entry of
	 * exactly key_sz bytes, otherwise receive the result in a bounce
	 * buffer that qat_rsa_cb() copies into dst.
	 */
912 	if (sg_is_last(req->dst) && req->dst_len == ctx->key_sz) {
913 		qat_req->dst_align = NULL;
914 		vaddr = sg_virt(req->dst);
915 	} else {
916 		qat_req->dst_align = kzalloc(ctx->key_sz, flags);
917 		if (unlikely(!qat_req->dst_align))
918 			goto unmap_src;
919 		vaddr = qat_req->dst_align;
920 	}
921 	qat_req->out.rsa.dec.m = dma_map_single(dev, vaddr, ctx->key_sz,
922 						DMA_FROM_DEVICE);
923 	if (unlikely(dma_mapping_error(dev, qat_req->out.rsa.dec.m)))
924 		goto unmap_dst;
925 
	/* Write 0 into the slot that follows the last entry of each table */
926 	if (ctx->crt_mode)
927 		qat_req->in.rsa.in_tab[6] = 0;
928 	else
929 		qat_req->in.rsa.in_tab[3] = 0;
930 	qat_req->out.rsa.out_tab[1] = 0;
931 	qat_req->phy_in = dma_map_single(dev, &qat_req->in.rsa,
932 					 sizeof(struct qat_rsa_input_params),
933 					 DMA_TO_DEVICE);
934 	if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
935 		goto unmap_dst;
936 
937 	qat_req->phy_out = dma_map_single(dev, &qat_req->out.rsa,
938 					  sizeof(struct qat_rsa_output_params),
939 					  DMA_TO_DEVICE);
940 	if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
941 		goto unmap_in_params;
942 
943 	msg->pke_mid.src_data_addr = qat_req->phy_in;
944 	msg->pke_mid.dest_data_addr = qat_req->phy_out;
	/* The response carries this back so the callback can find qat_req */
945 	msg->pke_mid.opaque = (u64)(__force long)qat_req;
946 	if (ctx->crt_mode)
947 		msg->input_param_count = 6;
948 	else
949 		msg->input_param_count = 3;
950 
951 	msg->output_param_count = 1;
952 
	/*
	 * Only -ENOSPC is unwound here. NOTE(review): any other return value
	 * is passed through untouched, presumably because the request is then
	 * owned by the ring/backlog and cleaned up in qat_rsa_cb() -- confirm
	 * against qat_alg_send_message().
	 */
953 	ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
954 	if (ret == -ENOSPC)
955 		goto unmap_all;
956 
957 	return ret;
958 
	/* Unwind in reverse order of setup; each label falls through */
959 unmap_all:
960 	if (!dma_mapping_error(dev, qat_req->phy_out))
961 		dma_unmap_single(dev, qat_req->phy_out,
962 				 sizeof(struct qat_rsa_output_params),
963 				 DMA_TO_DEVICE);
964 unmap_in_params:
965 	if (!dma_mapping_error(dev, qat_req->phy_in))
966 		dma_unmap_single(dev, qat_req->phy_in,
967 				 sizeof(struct qat_rsa_input_params),
968 				 DMA_TO_DEVICE);
969 unmap_dst:
970 	if (!dma_mapping_error(dev, qat_req->out.rsa.dec.m))
971 		dma_unmap_single(dev, qat_req->out.rsa.dec.m,
972 				 ctx->key_sz, DMA_FROM_DEVICE);
973 	kfree_sensitive(qat_req->dst_align);
974 unmap_src:
975 	if (!dma_mapping_error(dev, qat_req->in.rsa.dec.c))
976 		dma_unmap_single(dev, qat_req->in.rsa.dec.c, ctx->key_sz,
977 				 DMA_TO_DEVICE);
978 	kfree_sensitive(qat_req->src_align);
979 	return ret;
980 }
981 
982 static int qat_rsa_set_n(struct qat_rsa_ctx *ctx, const char *value,
983 			 size_t vlen)
984 {
985 	struct qat_crypto_instance *inst = ctx->inst;
986 	struct device *dev = &GET_DEV(inst->accel_dev);
987 	const char *ptr = value;
988 	int ret;
989 
990 	while (!*ptr && vlen) {
991 		ptr++;
992 		vlen--;
993 	}
994 
995 	ctx->key_sz = vlen;
996 	ret = -EINVAL;
997 	/* invalid key size provided */
998 	if (!qat_rsa_enc_fn_id(ctx->key_sz))
999 		goto err;
1000 
1001 	ret = -ENOMEM;
1002 	ctx->n = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_n, GFP_KERNEL);
1003 	if (!ctx->n)
1004 		goto err;
1005 
1006 	memcpy(ctx->n, ptr, ctx->key_sz);
1007 	return 0;
1008 err:
1009 	ctx->key_sz = 0;
1010 	ctx->n = NULL;
1011 	return ret;
1012 }
1013 
1014 static int qat_rsa_set_e(struct qat_rsa_ctx *ctx, const char *value,
1015 			 size_t vlen)
1016 {
1017 	struct qat_crypto_instance *inst = ctx->inst;
1018 	struct device *dev = &GET_DEV(inst->accel_dev);
1019 	const char *ptr = value;
1020 
1021 	while (!*ptr && vlen) {
1022 		ptr++;
1023 		vlen--;
1024 	}
1025 
1026 	if (!ctx->key_sz || !vlen || vlen > ctx->key_sz) {
1027 		ctx->e = NULL;
1028 		return -EINVAL;
1029 	}
1030 
1031 	ctx->e = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_e, GFP_KERNEL);
1032 	if (!ctx->e)
1033 		return -ENOMEM;
1034 
1035 	memcpy(ctx->e + (ctx->key_sz - vlen), ptr, vlen);
1036 	return 0;
1037 }
1038 
1039 static int qat_rsa_set_d(struct qat_rsa_ctx *ctx, const char *value,
1040 			 size_t vlen)
1041 {
1042 	struct qat_crypto_instance *inst = ctx->inst;
1043 	struct device *dev = &GET_DEV(inst->accel_dev);
1044 	const char *ptr = value;
1045 	int ret;
1046 
1047 	while (!*ptr && vlen) {
1048 		ptr++;
1049 		vlen--;
1050 	}
1051 
1052 	ret = -EINVAL;
1053 	if (!ctx->key_sz || !vlen || vlen > ctx->key_sz)
1054 		goto err;
1055 
1056 	ret = -ENOMEM;
1057 	ctx->d = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_d, GFP_KERNEL);
1058 	if (!ctx->d)
1059 		goto err;
1060 
1061 	memcpy(ctx->d + (ctx->key_sz - vlen), ptr, vlen);
1062 	return 0;
1063 err:
1064 	ctx->d = NULL;
1065 	return ret;
1066 }
1067 
/*
 * Advance *@ptr past leading zero bytes and shrink *@len accordingly.
 *
 * @ptr: in/out pointer to the first byte of the value
 * @len: in/out number of bytes available at *@ptr
 *
 * On return *@ptr points at the first non-zero byte, or one past the end with
 * *@len == 0 if every byte was zero. The length is tested before the byte is
 * read, so nothing beyond *@len bytes is touched and a zero-length value is
 * never dereferenced.
 */
static void qat_rsa_drop_leading_zeros(const char **ptr, unsigned int *len)
{
	while (*len && !**ptr) {
		(*ptr)++;
		(*len)--;
	}
}
1075 
1076 static void qat_rsa_setkey_crt(struct qat_rsa_ctx *ctx, struct rsa_key *rsa_key)
1077 {
1078 	struct qat_crypto_instance *inst = ctx->inst;
1079 	struct device *dev = &GET_DEV(inst->accel_dev);
1080 	const char *ptr;
1081 	unsigned int len;
1082 	unsigned int half_key_sz = ctx->key_sz / 2;
1083 
1084 	/* p */
1085 	ptr = rsa_key->p;
1086 	len = rsa_key->p_sz;
1087 	qat_rsa_drop_leading_zeros(&ptr, &len);
1088 	if (!len)
1089 		goto err;
1090 	ctx->p = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_p, GFP_KERNEL);
1091 	if (!ctx->p)
1092 		goto err;
1093 	memcpy(ctx->p + (half_key_sz - len), ptr, len);
1094 
1095 	/* q */
1096 	ptr = rsa_key->q;
1097 	len = rsa_key->q_sz;
1098 	qat_rsa_drop_leading_zeros(&ptr, &len);
1099 	if (!len)
1100 		goto free_p;
1101 	ctx->q = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_q, GFP_KERNEL);
1102 	if (!ctx->q)
1103 		goto free_p;
1104 	memcpy(ctx->q + (half_key_sz - len), ptr, len);
1105 
1106 	/* dp */
1107 	ptr = rsa_key->dp;
1108 	len = rsa_key->dp_sz;
1109 	qat_rsa_drop_leading_zeros(&ptr, &len);
1110 	if (!len)
1111 		goto free_q;
1112 	ctx->dp = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_dp,
1113 				     GFP_KERNEL);
1114 	if (!ctx->dp)
1115 		goto free_q;
1116 	memcpy(ctx->dp + (half_key_sz - len), ptr, len);
1117 
1118 	/* dq */
1119 	ptr = rsa_key->dq;
1120 	len = rsa_key->dq_sz;
1121 	qat_rsa_drop_leading_zeros(&ptr, &len);
1122 	if (!len)
1123 		goto free_dp;
1124 	ctx->dq = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_dq,
1125 				     GFP_KERNEL);
1126 	if (!ctx->dq)
1127 		goto free_dp;
1128 	memcpy(ctx->dq + (half_key_sz - len), ptr, len);
1129 
1130 	/* qinv */
1131 	ptr = rsa_key->qinv;
1132 	len = rsa_key->qinv_sz;
1133 	qat_rsa_drop_leading_zeros(&ptr, &len);
1134 	if (!len)
1135 		goto free_dq;
1136 	ctx->qinv = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_qinv,
1137 				       GFP_KERNEL);
1138 	if (!ctx->qinv)
1139 		goto free_dq;
1140 	memcpy(ctx->qinv + (half_key_sz - len), ptr, len);
1141 
1142 	ctx->crt_mode = true;
1143 	return;
1144 
1145 free_dq:
1146 	memset(ctx->dq, '\0', half_key_sz);
1147 	dma_free_coherent(dev, half_key_sz, ctx->dq, ctx->dma_dq);
1148 	ctx->dq = NULL;
1149 free_dp:
1150 	memset(ctx->dp, '\0', half_key_sz);
1151 	dma_free_coherent(dev, half_key_sz, ctx->dp, ctx->dma_dp);
1152 	ctx->dp = NULL;
1153 free_q:
1154 	memset(ctx->q, '\0', half_key_sz);
1155 	dma_free_coherent(dev, half_key_sz, ctx->q, ctx->dma_q);
1156 	ctx->q = NULL;
1157 free_p:
1158 	memset(ctx->p, '\0', half_key_sz);
1159 	dma_free_coherent(dev, half_key_sz, ctx->p, ctx->dma_p);
1160 	ctx->p = NULL;
1161 err:
1162 	ctx->crt_mode = false;
1163 }
1164 
1165 static void qat_rsa_clear_ctx(struct device *dev, struct qat_rsa_ctx *ctx)
1166 {
1167 	unsigned int half_key_sz = ctx->key_sz / 2;
1168 
1169 	/* Free the old key if any */
1170 	if (ctx->n)
1171 		dma_free_coherent(dev, ctx->key_sz, ctx->n, ctx->dma_n);
1172 	if (ctx->e)
1173 		dma_free_coherent(dev, ctx->key_sz, ctx->e, ctx->dma_e);
1174 	if (ctx->d) {
1175 		memset(ctx->d, '\0', ctx->key_sz);
1176 		dma_free_coherent(dev, ctx->key_sz, ctx->d, ctx->dma_d);
1177 	}
1178 	if (ctx->p) {
1179 		memset(ctx->p, '\0', half_key_sz);
1180 		dma_free_coherent(dev, half_key_sz, ctx->p, ctx->dma_p);
1181 	}
1182 	if (ctx->q) {
1183 		memset(ctx->q, '\0', half_key_sz);
1184 		dma_free_coherent(dev, half_key_sz, ctx->q, ctx->dma_q);
1185 	}
1186 	if (ctx->dp) {
1187 		memset(ctx->dp, '\0', half_key_sz);
1188 		dma_free_coherent(dev, half_key_sz, ctx->dp, ctx->dma_dp);
1189 	}
1190 	if (ctx->dq) {
1191 		memset(ctx->dq, '\0', half_key_sz);
1192 		dma_free_coherent(dev, half_key_sz, ctx->dq, ctx->dma_dq);
1193 	}
1194 	if (ctx->qinv) {
1195 		memset(ctx->qinv, '\0', half_key_sz);
1196 		dma_free_coherent(dev, half_key_sz, ctx->qinv, ctx->dma_qinv);
1197 	}
1198 
1199 	ctx->n = NULL;
1200 	ctx->e = NULL;
1201 	ctx->d = NULL;
1202 	ctx->p = NULL;
1203 	ctx->q = NULL;
1204 	ctx->dp = NULL;
1205 	ctx->dq = NULL;
1206 	ctx->qinv = NULL;
1207 	ctx->crt_mode = false;
1208 	ctx->key_sz = 0;
1209 }
1210 
1211 static int qat_rsa_setkey(struct crypto_akcipher *tfm, const void *key,
1212 			  unsigned int keylen, bool private)
1213 {
1214 	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1215 	struct device *dev = &GET_DEV(ctx->inst->accel_dev);
1216 	struct rsa_key rsa_key;
1217 	int ret;
1218 
1219 	qat_rsa_clear_ctx(dev, ctx);
1220 
1221 	if (private)
1222 		ret = rsa_parse_priv_key(&rsa_key, key, keylen);
1223 	else
1224 		ret = rsa_parse_pub_key(&rsa_key, key, keylen);
1225 	if (ret < 0)
1226 		goto free;
1227 
1228 	ret = qat_rsa_set_n(ctx, rsa_key.n, rsa_key.n_sz);
1229 	if (ret < 0)
1230 		goto free;
1231 	ret = qat_rsa_set_e(ctx, rsa_key.e, rsa_key.e_sz);
1232 	if (ret < 0)
1233 		goto free;
1234 	if (private) {
1235 		ret = qat_rsa_set_d(ctx, rsa_key.d, rsa_key.d_sz);
1236 		if (ret < 0)
1237 			goto free;
1238 		qat_rsa_setkey_crt(ctx, &rsa_key);
1239 	}
1240 
1241 	if (!ctx->n || !ctx->e) {
1242 		/* invalid key provided */
1243 		ret = -EINVAL;
1244 		goto free;
1245 	}
1246 	if (private && !ctx->d) {
1247 		/* invalid private key provided */
1248 		ret = -EINVAL;
1249 		goto free;
1250 	}
1251 
1252 	return 0;
1253 free:
1254 	qat_rsa_clear_ctx(dev, ctx);
1255 	return ret;
1256 }
1257 
1258 static int qat_rsa_setpubkey(struct crypto_akcipher *tfm, const void *key,
1259 			     unsigned int keylen)
1260 {
1261 	return qat_rsa_setkey(tfm, key, keylen, false);
1262 }
1263 
1264 static int qat_rsa_setprivkey(struct crypto_akcipher *tfm, const void *key,
1265 			      unsigned int keylen)
1266 {
1267 	return qat_rsa_setkey(tfm, key, keylen, true);
1268 }
1269 
1270 static unsigned int qat_rsa_max_size(struct crypto_akcipher *tfm)
1271 {
1272 	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1273 
1274 	return ctx->key_sz;
1275 }
1276 
1277 static int qat_rsa_init_tfm(struct crypto_akcipher *tfm)
1278 {
1279 	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1280 	struct qat_crypto_instance *inst =
1281 			qat_crypto_get_instance_node(numa_node_id());
1282 
1283 	if (!inst)
1284 		return -EINVAL;
1285 
1286 	akcipher_set_reqsize(tfm, sizeof(struct qat_asym_request) + 64);
1287 
1288 	ctx->key_sz = 0;
1289 	ctx->inst = inst;
1290 	return 0;
1291 }
1292 
1293 static void qat_rsa_exit_tfm(struct crypto_akcipher *tfm)
1294 {
1295 	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1296 	struct device *dev = &GET_DEV(ctx->inst->accel_dev);
1297 
1298 	qat_rsa_clear_ctx(dev, ctx);
1299 	qat_crypto_put_instance(ctx->inst);
1300 }
1301 
1302 static struct akcipher_alg rsa = {
1303 	.encrypt = qat_rsa_enc,
1304 	.decrypt = qat_rsa_dec,
1305 	.set_pub_key = qat_rsa_setpubkey,
1306 	.set_priv_key = qat_rsa_setprivkey,
1307 	.max_size = qat_rsa_max_size,
1308 	.init = qat_rsa_init_tfm,
1309 	.exit = qat_rsa_exit_tfm,
1310 	.base = {
1311 		.cra_name = "rsa",
1312 		.cra_driver_name = "qat-rsa",
1313 		.cra_priority = 1000,
1314 		.cra_module = THIS_MODULE,
1315 		.cra_ctxsize = sizeof(struct qat_rsa_ctx),
1316 	},
1317 };
1318 
1319 static struct kpp_alg dh = {
1320 	.set_secret = qat_dh_set_secret,
1321 	.generate_public_key = qat_dh_generate_public_key,
1322 	.compute_shared_secret = qat_dh_compute_shared_secret,
1323 	.max_size = qat_dh_max_size,
1324 	.init = qat_dh_init_tfm,
1325 	.exit = qat_dh_exit_tfm,
1326 	.base = {
1327 		.cra_name = "dh",
1328 		.cra_driver_name = "qat-dh",
1329 		.cra_priority = 1000,
1330 		.cra_module = THIS_MODULE,
1331 		.cra_ctxsize = sizeof(struct qat_dh_ctx),
1332 		.cra_flags = CRYPTO_ALG_NEED_FALLBACK,
1333 	},
1334 };
1335 
1336 int qat_asym_algs_register(void)
1337 {
1338 	int ret = 0;
1339 
1340 	mutex_lock(&algs_lock);
1341 	if (++active_devs == 1) {
1342 		rsa.base.cra_flags = 0;
1343 		ret = crypto_register_akcipher(&rsa);
1344 		if (ret)
1345 			goto unlock;
1346 		ret = crypto_register_kpp(&dh);
1347 	}
1348 unlock:
1349 	mutex_unlock(&algs_lock);
1350 	return ret;
1351 }
1352 
1353 void qat_asym_algs_unregister(void)
1354 {
1355 	mutex_lock(&algs_lock);
1356 	if (--active_devs == 0) {
1357 		crypto_unregister_akcipher(&rsa);
1358 		crypto_unregister_kpp(&dh);
1359 	}
1360 	mutex_unlock(&algs_lock);
1361 }
1362