xref: /linux/drivers/crypto/starfive/jh7110-rsa.c (revision 4e73826089ce899357580bbf6e0afe4e6f9900b7)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * StarFive Public Key Algo acceleration driver
4  *
5  * Copyright (c) 2022 StarFive Technology
6  */
7 
8 #include <linux/crypto.h>
9 #include <linux/iopoll.h>
10 #include <crypto/akcipher.h>
11 #include <crypto/algapi.h>
12 #include <crypto/internal/akcipher.h>
13 #include <crypto/internal/rsa.h>
14 #include <crypto/scatterwalk.h>
15 
16 #include "jh7110-cryp.h"
17 
18 #define STARFIVE_PKA_REGS_OFFSET	0x400
19 #define STARFIVE_PKA_CACR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x0)
20 #define STARFIVE_PKA_CASR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x4)
21 #define STARFIVE_PKA_CAAR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x8)
22 #define STARFIVE_PKA_CAER_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x108)
23 #define STARFIVE_PKA_CANR_OFFSET	(STARFIVE_PKA_REGS_OFFSET + 0x208)
24 
25 /* R ^ 2 mod N and N0' */
26 #define CRYPTO_CMD_PRE			0x0
27 /* A * R mod N   ==> A */
28 #define CRYPTO_CMD_ARN			0x5
29 /* A * E * R mod N ==> A */
30 #define CRYPTO_CMD_AERN			0x6
31 /* A * A * R mod N ==> A */
32 #define CRYPTO_CMD_AARN			0x7
33 
34 #define STARFIVE_RSA_MAX_KEYSZ		256
35 #define STARFIVE_RSA_RESET		0x2
36 
37 static inline int starfive_pka_wait_done(struct starfive_cryp_ctx *ctx)
38 {
39 	struct starfive_cryp_dev *cryp = ctx->cryp;
40 	u32 status;
41 
42 	return readl_relaxed_poll_timeout(cryp->base + STARFIVE_PKA_CASR_OFFSET, status,
43 					  status & STARFIVE_PKA_DONE, 10, 100000);
44 }
45 
46 static void starfive_rsa_free_key(struct starfive_rsa_key *key)
47 {
48 	kfree_sensitive(key->d);
49 	kfree_sensitive(key->e);
50 	kfree_sensitive(key->n);
51 	memset(key, 0, sizeof(*key));
52 }
53 
54 static unsigned int starfive_rsa_get_nbit(u8 *pa, u32 snum, int key_sz)
55 {
56 	u32 i;
57 	u8 value;
58 
59 	i = snum >> 3;
60 
61 	value = pa[key_sz - i - 1];
62 	value >>= snum & 0x7;
63 	value &= 0x1;
64 
65 	return value;
66 }
67 
68 static int starfive_rsa_montgomery_form(struct starfive_cryp_ctx *ctx,
69 					u32 *out, u32 *in, u8 mont,
70 					u32 *mod, int bit_len)
71 {
72 	struct starfive_cryp_dev *cryp = ctx->cryp;
73 	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
74 	int count = rctx->total / sizeof(u32) - 1;
75 	int loop;
76 	u32 temp;
77 	u8 opsize;
78 
79 	opsize = (bit_len - 1) >> 5;
80 	rctx->csr.pka.v = 0;
81 
82 	writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
83 
84 	for (loop = 0; loop <= opsize; loop++)
85 		writel(mod[opsize - loop], cryp->base + STARFIVE_PKA_CANR_OFFSET + loop * 4);
86 
87 	if (mont) {
88 		rctx->csr.pka.v = 0;
89 		rctx->csr.pka.cln_done = 1;
90 		rctx->csr.pka.opsize = opsize;
91 		rctx->csr.pka.exposize = opsize;
92 		rctx->csr.pka.cmd = CRYPTO_CMD_PRE;
93 		rctx->csr.pka.start = 1;
94 		rctx->csr.pka.not_r2 = 1;
95 		rctx->csr.pka.ie = 1;
96 
97 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
98 
99 		if (starfive_pka_wait_done(ctx))
100 			return -ETIMEDOUT;
101 
102 		for (loop = 0; loop <= opsize; loop++)
103 			writel(in[opsize - loop], cryp->base + STARFIVE_PKA_CAAR_OFFSET + loop * 4);
104 
105 		writel(0x1000000, cryp->base + STARFIVE_PKA_CAER_OFFSET);
106 
107 		for (loop = 1; loop <= opsize; loop++)
108 			writel(0, cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
109 
110 		rctx->csr.pka.v = 0;
111 		rctx->csr.pka.cln_done = 1;
112 		rctx->csr.pka.opsize = opsize;
113 		rctx->csr.pka.exposize = opsize;
114 		rctx->csr.pka.cmd = CRYPTO_CMD_AERN;
115 		rctx->csr.pka.start = 1;
116 		rctx->csr.pka.ie = 1;
117 
118 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
119 
120 		if (starfive_pka_wait_done(ctx))
121 			return -ETIMEDOUT;
122 	} else {
123 		rctx->csr.pka.v = 0;
124 		rctx->csr.pka.cln_done = 1;
125 		rctx->csr.pka.opsize = opsize;
126 		rctx->csr.pka.exposize = opsize;
127 		rctx->csr.pka.cmd = CRYPTO_CMD_PRE;
128 		rctx->csr.pka.start = 1;
129 		rctx->csr.pka.pre_expf = 1;
130 		rctx->csr.pka.ie = 1;
131 
132 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
133 
134 		if (starfive_pka_wait_done(ctx))
135 			return -ETIMEDOUT;
136 
137 		for (loop = 0; loop <= count; loop++)
138 			writel(in[count - loop], cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
139 
140 		/*pad with 0 up to opsize*/
141 		for (loop = count + 1; loop <= opsize; loop++)
142 			writel(0, cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
143 
144 		rctx->csr.pka.v = 0;
145 		rctx->csr.pka.cln_done = 1;
146 		rctx->csr.pka.opsize = opsize;
147 		rctx->csr.pka.exposize = opsize;
148 		rctx->csr.pka.cmd = CRYPTO_CMD_ARN;
149 		rctx->csr.pka.start = 1;
150 		rctx->csr.pka.ie = 1;
151 
152 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
153 
154 		if (starfive_pka_wait_done(ctx))
155 			return -ETIMEDOUT;
156 	}
157 
158 	for (loop = 0; loop <= opsize; loop++) {
159 		temp = readl(cryp->base + STARFIVE_PKA_CAAR_OFFSET + 0x4 * loop);
160 		out[opsize - loop] = temp;
161 	}
162 
163 	return 0;
164 }
165 
166 static int starfive_rsa_cpu_start(struct starfive_cryp_ctx *ctx, u32 *result,
167 				  u8 *de, u32 *n, int key_sz)
168 {
169 	struct starfive_cryp_dev *cryp = ctx->cryp;
170 	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
171 	struct starfive_rsa_key *key = &ctx->rsa_key;
172 	u32 temp;
173 	int ret = 0;
174 	int opsize, mlen, loop;
175 	unsigned int *mta;
176 
177 	opsize = (key_sz - 1) >> 2;
178 
179 	mta = kmalloc(key_sz, GFP_KERNEL);
180 	if (!mta)
181 		return -ENOMEM;
182 
183 	ret = starfive_rsa_montgomery_form(ctx, mta, (u32 *)rctx->rsa_data,
184 					   0, n, key_sz << 3);
185 	if (ret) {
186 		dev_err_probe(cryp->dev, ret, "Conversion to Montgomery failed");
187 		goto rsa_err;
188 	}
189 
190 	for (loop = 0; loop <= opsize; loop++)
191 		writel(mta[opsize - loop],
192 		       cryp->base + STARFIVE_PKA_CAER_OFFSET + loop * 4);
193 
194 	for (loop = key->bitlen - 1; loop > 0; loop--) {
195 		mlen = starfive_rsa_get_nbit(de, loop - 1, key_sz);
196 
197 		rctx->csr.pka.v = 0;
198 		rctx->csr.pka.cln_done = 1;
199 		rctx->csr.pka.opsize = opsize;
200 		rctx->csr.pka.exposize = opsize;
201 		rctx->csr.pka.cmd = CRYPTO_CMD_AARN;
202 		rctx->csr.pka.start = 1;
203 		rctx->csr.pka.ie = 1;
204 
205 		writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
206 
207 		ret = -ETIMEDOUT;
208 		if (starfive_pka_wait_done(ctx))
209 			goto rsa_err;
210 
211 		if (mlen) {
212 			rctx->csr.pka.v = 0;
213 			rctx->csr.pka.cln_done = 1;
214 			rctx->csr.pka.opsize = opsize;
215 			rctx->csr.pka.exposize = opsize;
216 			rctx->csr.pka.cmd = CRYPTO_CMD_AERN;
217 			rctx->csr.pka.start = 1;
218 			rctx->csr.pka.ie = 1;
219 
220 			writel(rctx->csr.pka.v, cryp->base + STARFIVE_PKA_CACR_OFFSET);
221 
222 			if (starfive_pka_wait_done(ctx))
223 				goto rsa_err;
224 		}
225 	}
226 
227 	for (loop = 0; loop <= opsize; loop++) {
228 		temp = readl(cryp->base + STARFIVE_PKA_CAAR_OFFSET + 0x4 * loop);
229 		result[opsize - loop] = temp;
230 	}
231 
232 	ret = starfive_rsa_montgomery_form(ctx, result, result, 1, n, key_sz << 3);
233 	if (ret)
234 		dev_err_probe(cryp->dev, ret, "Conversion from Montgomery failed");
235 rsa_err:
236 	kfree(mta);
237 	return ret;
238 }
239 
240 static int starfive_rsa_start(struct starfive_cryp_ctx *ctx, u8 *result,
241 			      u8 *de, u8 *n, int key_sz)
242 {
243 	return starfive_rsa_cpu_start(ctx, (u32 *)result, de, (u32 *)n, key_sz);
244 }
245 
246 static int starfive_rsa_enc_core(struct starfive_cryp_ctx *ctx, int enc)
247 {
248 	struct starfive_cryp_dev *cryp = ctx->cryp;
249 	struct starfive_cryp_request_ctx *rctx = ctx->rctx;
250 	struct starfive_rsa_key *key = &ctx->rsa_key;
251 	int ret = 0;
252 
253 	writel(STARFIVE_RSA_RESET, cryp->base + STARFIVE_PKA_CACR_OFFSET);
254 
255 	rctx->total = sg_copy_to_buffer(rctx->in_sg, rctx->nents,
256 					rctx->rsa_data, rctx->total);
257 
258 	if (enc) {
259 		key->bitlen = key->e_bitlen;
260 		ret = starfive_rsa_start(ctx, rctx->rsa_data, key->e,
261 					 key->n, key->key_sz);
262 	} else {
263 		key->bitlen = key->d_bitlen;
264 		ret = starfive_rsa_start(ctx, rctx->rsa_data, key->d,
265 					 key->n, key->key_sz);
266 	}
267 
268 	if (ret)
269 		goto err_rsa_crypt;
270 
271 	sg_copy_buffer(rctx->out_sg, sg_nents(rctx->out_sg),
272 		       rctx->rsa_data, key->key_sz, 0, 0);
273 
274 err_rsa_crypt:
275 	writel(STARFIVE_RSA_RESET, cryp->base + STARFIVE_PKA_CACR_OFFSET);
276 	kfree(rctx->rsa_data);
277 	return ret;
278 }
279 
280 static int starfive_rsa_enc(struct akcipher_request *req)
281 {
282 	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
283 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
284 	struct starfive_cryp_dev *cryp = ctx->cryp;
285 	struct starfive_rsa_key *key = &ctx->rsa_key;
286 	struct starfive_cryp_request_ctx *rctx = akcipher_request_ctx(req);
287 	int ret;
288 
289 	if (!key->key_sz) {
290 		akcipher_request_set_tfm(req, ctx->akcipher_fbk);
291 		ret = crypto_akcipher_encrypt(req);
292 		akcipher_request_set_tfm(req, tfm);
293 		return ret;
294 	}
295 
296 	if (unlikely(!key->n || !key->e))
297 		return -EINVAL;
298 
299 	if (req->dst_len < key->key_sz)
300 		return dev_err_probe(cryp->dev, -EOVERFLOW,
301 				     "Output buffer length less than parameter n\n");
302 
303 	rctx->in_sg = req->src;
304 	rctx->out_sg = req->dst;
305 	rctx->total = req->src_len;
306 	rctx->nents = sg_nents(rctx->in_sg);
307 	ctx->rctx = rctx;
308 
309 	return starfive_rsa_enc_core(ctx, 1);
310 }
311 
312 static int starfive_rsa_dec(struct akcipher_request *req)
313 {
314 	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
315 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
316 	struct starfive_cryp_dev *cryp = ctx->cryp;
317 	struct starfive_rsa_key *key = &ctx->rsa_key;
318 	struct starfive_cryp_request_ctx *rctx = akcipher_request_ctx(req);
319 	int ret;
320 
321 	if (!key->key_sz) {
322 		akcipher_request_set_tfm(req, ctx->akcipher_fbk);
323 		ret = crypto_akcipher_decrypt(req);
324 		akcipher_request_set_tfm(req, tfm);
325 		return ret;
326 	}
327 
328 	if (unlikely(!key->n || !key->d))
329 		return -EINVAL;
330 
331 	if (req->dst_len < key->key_sz)
332 		return dev_err_probe(cryp->dev, -EOVERFLOW,
333 				     "Output buffer length less than parameter n\n");
334 
335 	rctx->in_sg = req->src;
336 	rctx->out_sg = req->dst;
337 	ctx->rctx = rctx;
338 	rctx->total = req->src_len;
339 
340 	return starfive_rsa_enc_core(ctx, 0);
341 }
342 
343 static int starfive_rsa_set_n(struct starfive_rsa_key *rsa_key,
344 			      const char *value, size_t vlen)
345 {
346 	const char *ptr = value;
347 	unsigned int bitslen;
348 	int ret;
349 
350 	while (!*ptr && vlen) {
351 		ptr++;
352 		vlen--;
353 	}
354 	rsa_key->key_sz = vlen;
355 	bitslen = rsa_key->key_sz << 3;
356 
357 	/* check valid key size */
358 	if (bitslen & 0x1f)
359 		return -EINVAL;
360 
361 	ret = -ENOMEM;
362 	rsa_key->n = kmemdup(ptr, rsa_key->key_sz, GFP_KERNEL);
363 	if (!rsa_key->n)
364 		goto err;
365 
366 	return 0;
367  err:
368 	rsa_key->key_sz = 0;
369 	rsa_key->n = NULL;
370 	starfive_rsa_free_key(rsa_key);
371 	return ret;
372 }
373 
374 static int starfive_rsa_set_e(struct starfive_rsa_key *rsa_key,
375 			      const char *value, size_t vlen)
376 {
377 	const char *ptr = value;
378 	unsigned char pt;
379 	int loop;
380 
381 	while (!*ptr && vlen) {
382 		ptr++;
383 		vlen--;
384 	}
385 	pt = *ptr;
386 
387 	if (!rsa_key->key_sz || !vlen || vlen > rsa_key->key_sz) {
388 		rsa_key->e = NULL;
389 		return -EINVAL;
390 	}
391 
392 	rsa_key->e = kzalloc(rsa_key->key_sz, GFP_KERNEL);
393 	if (!rsa_key->e)
394 		return -ENOMEM;
395 
396 	for (loop = 8; loop > 0; loop--) {
397 		if (pt >> (loop - 1))
398 			break;
399 	}
400 
401 	rsa_key->e_bitlen = (vlen - 1) * 8 + loop;
402 
403 	memcpy(rsa_key->e + (rsa_key->key_sz - vlen), ptr, vlen);
404 
405 	return 0;
406 }
407 
408 static int starfive_rsa_set_d(struct starfive_rsa_key *rsa_key,
409 			      const char *value, size_t vlen)
410 {
411 	const char *ptr = value;
412 	unsigned char pt;
413 	int loop;
414 	int ret;
415 
416 	while (!*ptr && vlen) {
417 		ptr++;
418 		vlen--;
419 	}
420 	pt = *ptr;
421 
422 	ret = -EINVAL;
423 	if (!rsa_key->key_sz || !vlen || vlen > rsa_key->key_sz)
424 		goto err;
425 
426 	ret = -ENOMEM;
427 	rsa_key->d = kzalloc(rsa_key->key_sz, GFP_KERNEL);
428 	if (!rsa_key->d)
429 		goto err;
430 
431 	for (loop = 8; loop > 0; loop--) {
432 		if (pt >> (loop - 1))
433 			break;
434 	}
435 
436 	rsa_key->d_bitlen = (vlen - 1) * 8 + loop;
437 
438 	memcpy(rsa_key->d + (rsa_key->key_sz - vlen), ptr, vlen);
439 
440 	return 0;
441  err:
442 	rsa_key->d = NULL;
443 	return ret;
444 }
445 
446 static int starfive_rsa_setkey(struct crypto_akcipher *tfm, const void *key,
447 			       unsigned int keylen, bool private)
448 {
449 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
450 	struct rsa_key raw_key = {NULL};
451 	struct starfive_rsa_key *rsa_key = &ctx->rsa_key;
452 	int ret;
453 
454 	if (private)
455 		ret = rsa_parse_priv_key(&raw_key, key, keylen);
456 	else
457 		ret = rsa_parse_pub_key(&raw_key, key, keylen);
458 	if (ret < 0)
459 		goto err;
460 
461 	starfive_rsa_free_key(rsa_key);
462 
463 	/* Use fallback for mod > 256 + 1 byte prefix */
464 	if (raw_key.n_sz > STARFIVE_RSA_MAX_KEYSZ + 1)
465 		return 0;
466 
467 	ret = starfive_rsa_set_n(rsa_key, raw_key.n, raw_key.n_sz);
468 	if (ret)
469 		return ret;
470 
471 	ret = starfive_rsa_set_e(rsa_key, raw_key.e, raw_key.e_sz);
472 	if (ret)
473 		goto err;
474 
475 	if (private) {
476 		ret = starfive_rsa_set_d(rsa_key, raw_key.d, raw_key.d_sz);
477 		if (ret)
478 			goto err;
479 	}
480 
481 	if (!rsa_key->n || !rsa_key->e) {
482 		ret = -EINVAL;
483 		goto err;
484 	}
485 
486 	if (private && !rsa_key->d) {
487 		ret = -EINVAL;
488 		goto err;
489 	}
490 
491 	return 0;
492  err:
493 	starfive_rsa_free_key(rsa_key);
494 	return ret;
495 }
496 
497 static int starfive_rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,
498 				    unsigned int keylen)
499 {
500 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
501 	int ret;
502 
503 	ret = crypto_akcipher_set_pub_key(ctx->akcipher_fbk, key, keylen);
504 	if (ret)
505 		return ret;
506 
507 	return starfive_rsa_setkey(tfm, key, keylen, false);
508 }
509 
510 static int starfive_rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,
511 				     unsigned int keylen)
512 {
513 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
514 	int ret;
515 
516 	ret = crypto_akcipher_set_priv_key(ctx->akcipher_fbk, key, keylen);
517 	if (ret)
518 		return ret;
519 
520 	return starfive_rsa_setkey(tfm, key, keylen, true);
521 }
522 
523 static unsigned int starfive_rsa_max_size(struct crypto_akcipher *tfm)
524 {
525 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
526 
527 	if (ctx->rsa_key.key_sz)
528 		return ctx->rsa_key.key_sz;
529 
530 	return crypto_akcipher_maxsize(ctx->akcipher_fbk);
531 }
532 
533 static int starfive_rsa_init_tfm(struct crypto_akcipher *tfm)
534 {
535 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
536 
537 	ctx->akcipher_fbk = crypto_alloc_akcipher("rsa-generic", 0, 0);
538 	if (IS_ERR(ctx->akcipher_fbk))
539 		return PTR_ERR(ctx->akcipher_fbk);
540 
541 	ctx->cryp = starfive_cryp_find_dev(ctx);
542 	if (!ctx->cryp) {
543 		crypto_free_akcipher(ctx->akcipher_fbk);
544 		return -ENODEV;
545 	}
546 
547 	akcipher_set_reqsize(tfm, sizeof(struct starfive_cryp_request_ctx) +
548 			     sizeof(struct crypto_akcipher) + 32);
549 
550 	return 0;
551 }
552 
553 static void starfive_rsa_exit_tfm(struct crypto_akcipher *tfm)
554 {
555 	struct starfive_cryp_ctx *ctx = akcipher_tfm_ctx(tfm);
556 	struct starfive_rsa_key *key = (struct starfive_rsa_key *)&ctx->rsa_key;
557 
558 	crypto_free_akcipher(ctx->akcipher_fbk);
559 	starfive_rsa_free_key(key);
560 }
561 
562 static struct akcipher_alg starfive_rsa = {
563 	.encrypt = starfive_rsa_enc,
564 	.decrypt = starfive_rsa_dec,
565 	.sign = starfive_rsa_dec,
566 	.verify = starfive_rsa_enc,
567 	.set_pub_key = starfive_rsa_set_pub_key,
568 	.set_priv_key = starfive_rsa_set_priv_key,
569 	.max_size = starfive_rsa_max_size,
570 	.init = starfive_rsa_init_tfm,
571 	.exit = starfive_rsa_exit_tfm,
572 	.base = {
573 		.cra_name = "rsa",
574 		.cra_driver_name = "starfive-rsa",
575 		.cra_flags = CRYPTO_ALG_TYPE_AKCIPHER |
576 			     CRYPTO_ALG_NEED_FALLBACK,
577 		.cra_priority = 3000,
578 		.cra_module = THIS_MODULE,
579 		.cra_ctxsize = sizeof(struct starfive_cryp_ctx),
580 	},
581 };
582 
583 int starfive_rsa_register_algs(void)
584 {
585 	return crypto_register_akcipher(&starfive_rsa);
586 }
587 
588 void starfive_rsa_unregister_algs(void)
589 {
590 	crypto_unregister_akcipher(&starfive_rsa);
591 }
592