xref: /linux/drivers/crypto/starfive/jh7110-rsa.c (revision 6af91e3d2cfc8bb579b1aa2d22cd91f8c34acdf6)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * StarFive Public Key Algo acceleration driver
4  *
5  * Copyright (c) 2022 StarFive Technology
6  */
7 
8 #include <linux/crypto.h>
9 #include <linux/iopoll.h>
10 #include <crypto/akcipher.h>
11 #include <crypto/algapi.h>
12 #include <crypto/internal/akcipher.h>
13 #include <crypto/internal/rsa.h>
14 #include <crypto/scatterwalk.h>
15 
16 #include "jh7110-cryp.h"
17 
18 #define STARFIVE_PKA_REGS_OFFSET	0x400
19 #define STARFIVE_PKA_CACR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x0)
20 #define STARFIVE_PKA_CASR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x4)
21 #define STARFIVE_PKA_CAAR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x8)
22 #define STARFIVE_PKA_CAER_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x108)
23 #define STARFIVE_PKA_CANR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x208)
24 
25 /* R ^ 2 mod N and N0' */
26 #define CRYPTO_CMD_PRE			0x0
27 /* A * R mod N   ==> A */
28 #define CRYPTO_CMD_ARN			0x5
29 /* A * E * R mod N ==> A */
30 #define CRYPTO_CMD_AERN			0x6
31 /* A * A * R mod N ==> A */
32 #define CRYPTO_CMD_AARN			0x7
33 
34 #define STARFIVE_RSA_RESET		0x2
35 
36 static inline int starfive_pka_wait_done(struct starfive_cryp_ctx *ctx)
37 {
38 	struct starfive_cryp_dev *cryp = ctx->cryp;
39 	u32 status;
40 
41 	return readl_relaxed_poll_timeout(cryp->base + STARFIVE_PKA_CASR_OFFSET, status,
42 					  status & STARFIVE_PKA_DONE, 10, 100000);
43 }
44 
45 static void starfive_rsa_free_key(struct starfive_rsa_key *key)
46 {
47 	if (!key->key_sz)
48 		return;
49 
50 	kfree_sensitive(key->d);
51 	kfree_sensitive(key->e);
52 	kfree_sensitive(key->n);
53 	memset(key, 0, sizeof(*key));
54 }
55 
56 static unsigned int starfive_rsa_get_nbit(u8 *pa, u32 snum, int key_sz)
57 {
58 	u32 i;
59 	u8 value;
60 
61 	i = snum >> 3;
62 
63 	value = pa[key_sz - i - 1];
64 	value >>= snum & 0x7;
65 	value &= 0x1;
66 
67 	return value;
68 }
69 
70 static int starfive_rsa_montgomery_form(struct starfive_cryp_ctx *ctx,
71 					u32 *out, u32 *in, u8 mont,
72 					u32 *mod, int bit_len)
73 {
74 	struct starfive_cryp_dev *cryp = ctx->cryp;
75 	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
76 	int count = (ALIGN(rctx->total, 4) / 4) - 1;
77 	int loop;
78 	u32 temp;
79 	u8 opsize;
80 
81 	opsize = (bit_len - 1) >> 5;
82 	rctx->csr.pka.v = 0;
83 
84 	writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
85 
86 	for (loop = 0; loop <= opsize; loop++)
87 		writel(mod[opsize - loop], cryp->base + STARFIVE_PKA_CANR_OFFSET + loop * 4);
88 
89 	if (mont) {
90 		rctx->csr.pka.v = 0;
91 		rctx->csr.pka.cln_done = 1;
92 		rctx->csr.pka.opsize = opsize;
93 		rctx->csr.pka.exposize = opsize;
94 		rctx->csr.pka.cmd = CRYPTO_CMD_PRE;
95 		rctx->csr.pka.start = 1;
96 		rctx->csr.pka.not_r2 = 1;
97 		rctx->csr.pka.ie = 1;
98 
99 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
100 
101 		if (starfive_pka_wait_done(ctx))
102 			return -ETIMEDOUT;
103 
104 		for (loop = 0; loop <= opsize; loop++)
105 			writel(in[opsize - loop], cryp->base + STARFIVE_PKA_CAAR_OFFSET + loop * 4);
106 
107 		writel(0x1000000, cryp->base + STARFIVE_PKA_CAER_OFFSET);
108 
109 		for (loop = 1; loop <= opsize; loop++)
110 			writel(0, cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
111 
112 		rctx->csr.pka.v = 0;
113 		rctx->csr.pka.cln_done = 1;
114 		rctx->csr.pka.opsize = opsize;
115 		rctx->csr.pka.exposize = opsize;
116 		rctx->csr.pka.cmd = CRYPTO_CMD_AERN;
117 		rctx->csr.pka.start = 1;
118 		rctx->csr.pka.ie = 1;
119 
120 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
121 
122 		if (starfive_pka_wait_done(ctx))
123 			return -ETIMEDOUT;
124 	} else {
125 		rctx->csr.pka.v = 0;
126 		rctx->csr.pka.cln_done = 1;
127 		rctx->csr.pka.opsize = opsize;
128 		rctx->csr.pka.exposize = opsize;
129 		rctx->csr.pka.cmd = CRYPTO_CMD_PRE;
130 		rctx->csr.pka.start = 1;
131 		rctx->csr.pka.pre_expf = 1;
132 		rctx->csr.pka.ie = 1;
133 
134 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
135 
136 		if (starfive_pka_wait_done(ctx))
137 			return -ETIMEDOUT;
138 
139 		for (loop = 0; loop <= count; loop++)
140 			writel(in[count - loop], cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
141 
142 		/*pad with 0 up to opsize*/
143 		for (loop = count + 1; loop <= opsize; loop++)
144 			writel(0, cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
145 
146 		rctx->csr.pka.v = 0;
147 		rctx->csr.pka.cln_done = 1;
148 		rctx->csr.pka.opsize = opsize;
149 		rctx->csr.pka.exposize = opsize;
150 		rctx->csr.pka.cmd = CRYPTO_CMD_ARN;
151 		rctx->csr.pka.start = 1;
152 		rctx->csr.pka.ie = 1;
153 
154 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
155 
156 		if (starfive_pka_wait_done(ctx))
157 			return -ETIMEDOUT;
158 	}
159 
160 	for (loop = 0; loop <= opsize; loop++) {
161 		temp = readl(cryp->base + STARFIVE_PKA_CAAR_OFFSET + 0x4 * loop);
162 		out[opsize - loop] = temp;
163 	}
164 
165 	return 0;
166 }
167 
168 static int starfive_rsa_cpu_start(struct starfive_cryp_ctx *ctx, u32 *result,
169 				  u8 *de, u32 *n, int key_sz)
170 {
171 	struct starfive_cryp_dev *cryp = ctx->cryp;
172 	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
173 	struct starfive_rsa_key *key = &ctx->rsa_key;
174 	u32 temp;
175 	int ret = 0;
176 	int opsize, mlen, loop;
177 	unsigned int *mta;
178 
179 	opsize = (key_sz - 1) >> 2;
180 
181 	mta = kmalloc(key_sz, GFP_KERNEL);
182 	if (!mta)
183 		return -ENOMEM;
184 
185 	ret = starfive_rsa_montgomery_form(ctx, mta, (u32 *)rctx->rsa_data,
186 					   0, n, key_sz << 3);
187 	if (ret) {
188 		dev_err_probe(cryp->dev, ret, "Conversion to Montgomery failed");
189 		goto rsa_err;
190 	}
191 
192 	for (loop = 0; loop <= opsize; loop++)
193 		writel(mta[opsize - loop],
194 		       cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
195 
196 	for (loop = key->bitlen - 1; loop > 0; loop--) {
197 		mlen = starfive_rsa_get_nbit(de, loop - 1, key_sz);
198 
199 		rctx->csr.pka.v = 0;
200 		rctx->csr.pka.cln_done = 1;
201 		rctx->csr.pka.opsize = opsize;
202 		rctx->csr.pka.exposize = opsize;
203 		rctx->csr.pka.cmd = CRYPTO_CMD_AARN;
204 		rctx->csr.pka.start = 1;
205 		rctx->csr.pka.ie = 1;
206 
207 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
208 
209 		ret = -ETIMEDOUT;
210 		if (starfive_pka_wait_done(ctx))
211 			goto rsa_err;
212 
213 		if (mlen) {
214 			rctx->csr.pka.v = 0;
215 			rctx->csr.pka.cln_done = 1;
216 			rctx->csr.pka.opsize = opsize;
217 			rctx->csr.pka.exposize = opsize;
218 			rctx->csr.pka.cmd = CRYPTO_CMD_AERN;
219 			rctx->csr.pka.start = 1;
220 			rctx->csr.pka.ie = 1;
221 
222 			writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
223 
224 			if (starfive_pka_wait_done(ctx))
225 				goto rsa_err;
226 		}
227 	}
228 
229 	for (loop = 0; loop <= opsize; loop++) {
230 		temp = readl(cryp->base + STARFIVE_PKA_CAAR_OFFSET + 0x4 * loop);
231 		result[opsize - loop] = temp;
232 	}
233 
234 	ret = starfive_rsa_montgomery_form(ctx, result, result, 1, n, key_sz << 3);
235 	if (ret)
236 		dev_err_probe(cryp->dev, ret, "Conversion from Montgomery failed");
237 rsa_err:
238 	kfree(mta);
239 	return ret;
240 }
241 
242 static int starfive_rsa_start(struct starfive_cryp_ctx *ctx, u8 *result,
243 			      u8 *de, u8 *n, int key_sz)
244 {
245 	return starfive_rsa_cpu_start(ctx, (u32 *)result, de, (u32 *)n, key_sz);
246 }
247 
248 static int starfive_rsa_enc_core(struct starfive_cryp_ctx *ctx, int enc)
249 {
250 	struct starfive_cryp_dev *cryp = ctx->cryp;
251 	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
252 	struct starfive_rsa_key *key = &ctx->rsa_key;
253 	int ret = 0, shift = 0;
254 
255 	writel(STARFIVE_RSA_RESET, cryp->base + STARFIVE_PKA_CACR_OFFSET);
256 
257 	if (!IS_ALIGNED(rctx->total, sizeof(u32))) {
258 		shift = sizeof(u32) - (rctx->total & 0x3);
259 		memset(rctx->rsa_data, 0, shift);
260 	}
261 
262 	rctx->total = sg_copy_to_buffer(rctx->in_sg, sg_nents(rctx->in_sg),
263 					rctx->rsa_data + shift, rctx->total);
264 
265 	if (enc) {
266 		key->bitlen = key->e_bitlen;
267 		ret = starfive_rsa_start(ctx, rctx->rsa_data, key->e,
268 					 key->n, key->key_sz);
269 	} else {
270 		key->bitlen = key->d_bitlen;
271 		ret = starfive_rsa_start(ctx, rctx->rsa_data, key->d,
272 					 key->n, key->key_sz);
273 	}
274 
275 	if (ret)
276 		goto err_rsa_crypt;
277 
278 	sg_copy_buffer(rctx->out_sg, sg_nents(rctx->out_sg),
279 		       rctx->rsa_data, key->key_sz, 0, 0);
280 
281 err_rsa_crypt:
282 	writel(STARFIVE_RSA_RESET, cryp->base + STARFIVE_PKA_CACR_OFFSET);
283 	return ret;
284 }
285 
286 static int starfive_rsa_enc(struct akcipher_request *req)
287 {
288 	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
289 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
290 	struct starfive_cryp_dev *cryp = ctx->cryp;
291 	struct starfive_rsa_key *key = &ctx->rsa_key;
292 	struct starfive_cryp_request_ctx *rctx = akcipher_request_ctx(req);
293 	int ret;
294 
295 	if (!key->key_sz) {
296 		akcipher_request_set_tfm(req, ctx->akcipher_fbk);
297 		ret = crypto_akcipher_encrypt(req);
298 		akcipher_request_set_tfm(req, tfm);
299 		return ret;
300 	}
301 
302 	if (unlikely(!key->n || !key->e))
303 		return -EINVAL;
304 
305 	if (req->dst_len < key->key_sz)
306 		return dev_err_probe(cryp->dev, -EOVERFLOW,
307 				     "Output buffer length less than parameter n\n");
308 
309 	rctx->in_sg = req->src;
310 	rctx->out_sg = req->dst;
311 	rctx->total = req->src_len;
312 	ctx->rctx = rctx;
313 
314 	return starfive_rsa_enc_core(ctx, 1);
315 }
316 
317 static int starfive_rsa_dec(struct akcipher_request *req)
318 {
319 	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
320 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
321 	struct starfive_cryp_dev *cryp = ctx->cryp;
322 	struct starfive_rsa_key *key = &ctx->rsa_key;
323 	struct starfive_cryp_request_ctx *rctx = akcipher_request_ctx(req);
324 	int ret;
325 
326 	if (!key->key_sz) {
327 		akcipher_request_set_tfm(req, ctx->akcipher_fbk);
328 		ret = crypto_akcipher_decrypt(req);
329 		akcipher_request_set_tfm(req, tfm);
330 		return ret;
331 	}
332 
333 	if (unlikely(!key->n || !key->d))
334 		return -EINVAL;
335 
336 	if (req->dst_len < key->key_sz)
337 		return dev_err_probe(cryp->dev, -EOVERFLOW,
338 				     "Output buffer length less than parameter n\n");
339 
340 	rctx->in_sg = req->src;
341 	rctx->out_sg = req->dst;
342 	ctx->rctx = rctx;
343 	rctx->total = req->src_len;
344 
345 	return starfive_rsa_enc_core(ctx, 0);
346 }
347 
348 static int starfive_rsa_set_n(struct starfive_rsa_key *rsa_key,
349 			      const char *value, size_t vlen)
350 {
351 	const char *ptr = value;
352 	unsigned int bitslen;
353 	int ret;
354 
355 	while (!*ptr && vlen) {
356 		ptr++;
357 		vlen--;
358 	}
359 	rsa_key->key_sz = vlen;
360 	bitslen = rsa_key->key_sz << 3;
361 
362 	/* check valid key size */
363 	if (bitslen & 0x1f)
364 		return -EINVAL;
365 
366 	ret = -ENOMEM;
367 	rsa_key->n = kmemdup(ptr, rsa_key->key_sz, GFP_KERNEL);
368 	if (!rsa_key->n)
369 		goto err;
370 
371 	return 0;
372  err:
373 	rsa_key->key_sz = 0;
374 	rsa_key->n = NULL;
375 	starfive_rsa_free_key(rsa_key);
376 	return ret;
377 }
378 
379 static int starfive_rsa_set_e(struct starfive_rsa_key *rsa_key,
380 			      const char *value, size_t vlen)
381 {
382 	const char *ptr = value;
383 	unsigned char pt;
384 	int loop;
385 
386 	while (!*ptr && vlen) {
387 		ptr++;
388 		vlen--;
389 	}
390 	pt = *ptr;
391 
392 	if (!rsa_key->key_sz || !vlen || vlen > rsa_key->key_sz) {
393 		rsa_key->e = NULL;
394 		return -EINVAL;
395 	}
396 
397 	rsa_key->e = kzalloc(rsa_key->key_sz, GFP_KERNEL);
398 	if (!rsa_key->e)
399 		return -ENOMEM;
400 
401 	for (loop = 8; loop > 0; loop--) {
402 		if (pt >> (loop - 1))
403 			break;
404 	}
405 
406 	rsa_key->e_bitlen = (vlen - 1) * 8 + loop;
407 
408 	memcpy(rsa_key->e + (rsa_key->key_sz - vlen), ptr, vlen);
409 
410 	return 0;
411 }
412 
413 static int starfive_rsa_set_d(struct starfive_rsa_key *rsa_key,
414 			      const char *value, size_t vlen)
415 {
416 	const char *ptr = value;
417 	unsigned char pt;
418 	int loop;
419 	int ret;
420 
421 	while (!*ptr && vlen) {
422 		ptr++;
423 		vlen--;
424 	}
425 	pt = *ptr;
426 
427 	ret = -EINVAL;
428 	if (!rsa_key->key_sz || !vlen || vlen > rsa_key->key_sz)
429 		goto err;
430 
431 	ret = -ENOMEM;
432 	rsa_key->d = kzalloc(rsa_key->key_sz, GFP_KERNEL);
433 	if (!rsa_key->d)
434 		goto err;
435 
436 	for (loop = 8; loop > 0; loop--) {
437 		if (pt >> (loop - 1))
438 			break;
439 	}
440 
441 	rsa_key->d_bitlen = (vlen - 1) * 8 + loop;
442 
443 	memcpy(rsa_key->d + (rsa_key->key_sz - vlen), ptr, vlen);
444 
445 	return 0;
446  err:
447 	rsa_key->d = NULL;
448 	return ret;
449 }
450 
451 static int starfive_rsa_setkey(struct crypto_akcipher *tfm, const void *key,
452 			       unsigned int keylen, bool private)
453 {
454 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
455 	struct rsa_key raw_key = {NULL};
456 	struct starfive_rsa_key *rsa_key = &ctx->rsa_key;
457 	int ret;
458 
459 	if (private)
460 		ret = rsa_parse_priv_key(&raw_key, key, keylen);
461 	else
462 		ret = rsa_parse_pub_key(&raw_key, key, keylen);
463 	if (ret < 0)
464 		goto err;
465 
466 	starfive_rsa_free_key(rsa_key);
467 
468 	/* Use fallback for mod > 256 + 1 byte prefix */
469 	if (raw_key.n_sz > STARFIVE_RSA_MAX_KEYSZ + 1)
470 		return 0;
471 
472 	ret = starfive_rsa_set_n(rsa_key, raw_key.n, raw_key.n_sz);
473 	if (ret)
474 		return ret;
475 
476 	ret = starfive_rsa_set_e(rsa_key, raw_key.e, raw_key.e_sz);
477 	if (ret)
478 		goto err;
479 
480 	if (private) {
481 		ret = starfive_rsa_set_d(rsa_key, raw_key.d, raw_key.d_sz);
482 		if (ret)
483 			goto err;
484 	}
485 
486 	if (!rsa_key->n || !rsa_key->e) {
487 		ret = -EINVAL;
488 		goto err;
489 	}
490 
491 	if (private && !rsa_key->d) {
492 		ret = -EINVAL;
493 		goto err;
494 	}
495 
496 	return 0;
497  err:
498 	starfive_rsa_free_key(rsa_key);
499 	return ret;
500 }
501 
502 static int starfive_rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,
503 				    unsigned int keylen)
504 {
505 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
506 	int ret;
507 
508 	ret = crypto_akcipher_set_pub_key(ctx->akcipher_fbk, key, keylen);
509 	if (ret)
510 		return ret;
511 
512 	return starfive_rsa_setkey(tfm, key, keylen, false);
513 }
514 
515 static int starfive_rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,
516 				     unsigned int keylen)
517 {
518 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
519 	int ret;
520 
521 	ret = crypto_akcipher_set_priv_key(ctx->akcipher_fbk, key, keylen);
522 	if (ret)
523 		return ret;
524 
525 	return starfive_rsa_setkey(tfm, key, keylen, true);
526 }
527 
528 static unsigned int starfive_rsa_max_size(struct crypto_akcipher *tfm)
529 {
530 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
531 
532 	if (ctx->rsa_key.key_sz)
533 		return ctx->rsa_key.key_sz;
534 
535 	return crypto_akcipher_maxsize(ctx->akcipher_fbk);
536 }
537 
538 static int starfive_rsa_init_tfm(struct crypto_akcipher *tfm)
539 {
540 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
541 
542 	ctx->cryp = starfive_cryp_find_dev(ctx);
543 	if (!ctx->cryp)
544 		return -ENODEV;
545 
546 	ctx->akcipher_fbk = crypto_alloc_akcipher("rsa-generic", 0, 0);
547 	if (IS_ERR(ctx->akcipher_fbk))
548 		return PTR_ERR(ctx->akcipher_fbk);
549 
550 	akcipher_set_reqsize(tfm, sizeof(struct starfive_cryp_request_ctx) +
551 			     sizeof(struct crypto_akcipher) + 32);
552 
553 	return 0;
554 }
555 
556 static void starfive_rsa_exit_tfm(struct crypto_akcipher *tfm)
557 {
558 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
559 	struct starfive_rsa_key *key = (struct starfive_rsa_key *)&ctx->rsa_key;
560 
561 	crypto_free_akcipher(ctx->akcipher_fbk);
562 	starfive_rsa_free_key(key);
563 }
564 
565 static struct akcipher_alg starfive_rsa = {
566 	.encrypt = starfive_rsa_enc,
567 	.decrypt = starfive_rsa_dec,
568 	.sign = starfive_rsa_dec,
569 	.verify = starfive_rsa_enc,
570 	.set_pub_key = starfive_rsa_set_pub_key,
571 	.set_priv_key = starfive_rsa_set_priv_key,
572 	.max_size = starfive_rsa_max_size,
573 	.init = starfive_rsa_init_tfm,
574 	.exit = starfive_rsa_exit_tfm,
575 	.base = {
576 		.cra_name = "rsa",
577 		.cra_driver_name = "starfive-rsa",
578 		.cra_flags = CRYPTO_ALG_TYPE_AKCIPHER |
579 			     CRYPTO_ALG_NEED_FALLBACK,
580 		.cra_priority = 3000,
581 		.cra_module = THIS_MODULE,
582 		.cra_ctxsize = sizeof(struct starfive_cryp_ctx),
583 	},
584 };
585 
586 int starfive_rsa_register_algs(void)
587 {
588 	return crypto_register_akcipher(&starfive_rsa);
589 }
590 
591 void starfive_rsa_unregister_algs(void)
592 {
593 	crypto_unregister_akcipher(&starfive_rsa);
594 }
595