xref: /linux/drivers/crypto/intel/keembay/keembay-ocs-ecc.c (revision fbf5df34a4dbcd09d433dd4f0916bf9b2ddb16de)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Intel Keem Bay OCS ECC Crypto Driver.
4  *
5  * Copyright (C) 2019-2021 Intel Corporation
6  */
7 
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9 
10 #include <crypto/ecc_curve.h>
11 #include <crypto/ecdh.h>
12 #include <crypto/engine.h>
13 #include <crypto/internal/ecc.h>
14 #include <crypto/internal/kpp.h>
15 #include <crypto/kpp.h>
16 #include <crypto/rng.h>
17 #include <linux/clk.h>
18 #include <linux/completion.h>
19 #include <linux/err.h>
20 #include <linux/fips.h>
21 #include <linux/interrupt.h>
22 #include <linux/io.h>
23 #include <linux/iopoll.h>
24 #include <linux/irq.h>
25 #include <linux/kernel.h>
26 #include <linux/module.h>
27 #include <linux/of.h>
28 #include <linux/platform_device.h>
29 #include <linux/scatterlist.h>
30 #include <linux/string.h>
31 
#define DRV_NAME			"keembay-ocs-ecc"

/* Crypto API priority of the algorithms registered by this driver. */
#define KMB_OCS_ECC_PRIORITY		350

/* OCS ECC register offsets, relative to the mapped MMIO base. */
#define HW_OFFS_OCS_ECC_COMMAND		0x00000000
#define HW_OFFS_OCS_ECC_STATUS		0x00000004
#define HW_OFFS_OCS_ECC_DATA_IN		0x00000080
#define HW_OFFS_OCS_ECC_CX_DATA_OUT	0x00000100
#define HW_OFFS_OCS_ECC_CY_DATA_OUT	0x00000180
#define HW_OFFS_OCS_ECC_ISR		0x00000400
#define HW_OFFS_OCS_ECC_IER		0x00000404

/* Register bit fields. */
#define HW_OCS_ECC_ISR_INT_STATUS_DONE	BIT(0)
/*
 * Shift amount of the instruction field in ECC_COMMAND. Note that BIT(0)
 * evaluates to 1, so instruction codes land at bit 1 (bit 0 is START).
 */
#define HW_OCS_ECC_COMMAND_INS_BP	BIT(0)
#define HW_OCS_ECC_COMMAND_START_VAL	BIT(0)

/* Operand size selector, OR-ed into every ECC_COMMAND write. */
#define OCS_ECC_OP_SIZE_384		BIT(8)
#define OCS_ECC_OP_SIZE_256		0

/* Place an instruction code into the ECC_COMMAND instruction field. */
#define OCS_ECC_INST(code)		((code) << HW_OCS_ECC_COMMAND_INS_BP)

/* ECC_COMMAND instructions: operand loads ... */
#define OCS_ECC_INST_WRITE_AX		OCS_ECC_INST(0x1)
#define OCS_ECC_INST_WRITE_AY		OCS_ECC_INST(0x2)
#define OCS_ECC_INST_WRITE_BX_D		OCS_ECC_INST(0x3)
#define OCS_ECC_INST_WRITE_BY_L		OCS_ECC_INST(0x4)
#define OCS_ECC_INST_WRITE_P		OCS_ECC_INST(0x5)
#define OCS_ECC_INST_WRITE_A		OCS_ECC_INST(0x6)
/* ... and computations. */
#define OCS_ECC_INST_CALC_D_IDX_A	OCS_ECC_INST(0x8)
#define OCS_ECC_INST_CALC_A_POW_B_MODP	OCS_ECC_INST(0xB)
#define OCS_ECC_INST_CALC_A_MUL_B_MODP	OCS_ECC_INST(0xC)
#define OCS_ECC_INST_CALC_A_ADD_B_MODP	OCS_ECC_INST(0xD)

/* Value written to the IER register to enable the completion interrupt. */
#define ECC_ENABLE_INTR			1

/* Idle polling interval and overall timeout, in microseconds. */
#define POLL_USEC			100
#define TIMEOUT_USEC			10000

/* Largest supported operand (NIST P-384), in 64-bit digits and in bytes. */
#define KMB_ECC_VLI_MAX_DIGITS		ECC_CURVE_NIST_P384_DIGITS
#define KMB_ECC_VLI_MAX_BYTES		(KMB_ECC_VLI_MAX_DIGITS \
					 << ECC_DIGITS_TO_BYTES_SHIFT)

/* Exponent used to compute x^3 during public-key validation. */
#define POW_CUBE			3
74 
75 /**
76  * struct ocs_ecc_dev - ECC device context
77  * @list: List of device contexts
78  * @dev: OCS ECC device
79  * @base_reg: IO base address of OCS ECC
80  * @engine: Crypto engine for the device
81  * @irq_done: IRQ done completion.
82  * @irq: IRQ number
83  */
84 struct ocs_ecc_dev {
85 	struct list_head list;
86 	struct device *dev;
87 	void __iomem *base_reg;
88 	struct crypto_engine *engine;
89 	struct completion irq_done;
90 	int irq;
91 };
92 
93 /**
94  * struct ocs_ecc_ctx - Transformation context.
95  * @ecc_dev:	 The ECC driver associated with this context.
96  * @curve:	 The elliptic curve used by this transformation.
97  * @private_key: The private key.
98  */
99 struct ocs_ecc_ctx {
100 	struct ocs_ecc_dev *ecc_dev;
101 	const struct ecc_curve *curve;
102 	u64 private_key[KMB_ECC_VLI_MAX_DIGITS];
103 };
104 
105 /* Driver data. */
106 struct ocs_ecc_drv {
107 	struct list_head dev_list;
108 	spinlock_t lock;	/* Protects dev_list. */
109 };
110 
111 /* Global variable holding the list of OCS ECC devices (only one expected). */
112 static struct ocs_ecc_drv ocs_ecc = {
113 	.dev_list = LIST_HEAD_INIT(ocs_ecc.dev_list),
114 	.lock = __SPIN_LOCK_UNLOCKED(ocs_ecc.lock),
115 };
116 
/* Return the OCS ECC transformation context owning @req. */
static inline struct ocs_ecc_ctx *kmb_ocs_ecc_tctx(struct kpp_request *req)
{
	struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);

	return kpp_tfm_ctx(tfm);
}
122 
123 /* Converts number of digits to number of bytes. */
124 static inline unsigned int digits_to_bytes(unsigned int n)
125 {
126 	return n << ECC_DIGITS_TO_BYTES_SHIFT;
127 }
128 
129 /*
130  * Wait for ECC idle i.e when an operation (other than write operations)
131  * is done.
132  */
133 static inline int ocs_ecc_wait_idle(struct ocs_ecc_dev *dev)
134 {
135 	u32 value;
136 
137 	return readl_poll_timeout((dev->base_reg + HW_OFFS_OCS_ECC_STATUS),
138 				  value,
139 				  !(value & HW_OCS_ECC_ISR_INT_STATUS_DONE),
140 				  POLL_USEC, TIMEOUT_USEC);
141 }
142 
143 static void ocs_ecc_cmd_start(struct ocs_ecc_dev *ecc_dev, u32 op_size)
144 {
145 	iowrite32(op_size | HW_OCS_ECC_COMMAND_START_VAL,
146 		  ecc_dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
147 }
148 
149 /* Direct write of u32 buffer to ECC engine with associated instruction. */
150 static void ocs_ecc_write_cmd_and_data(struct ocs_ecc_dev *dev,
151 				       u32 op_size,
152 				       u32 inst,
153 				       const void *data_in,
154 				       size_t data_size)
155 {
156 	iowrite32(op_size | inst, dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
157 
158 	/* MMIO Write src uint32 to dst. */
159 	memcpy_toio(dev->base_reg + HW_OFFS_OCS_ECC_DATA_IN, data_in,
160 		    data_size);
161 }
162 
163 /* Start OCS ECC operation and wait for its completion. */
164 static int ocs_ecc_trigger_op(struct ocs_ecc_dev *ecc_dev, u32 op_size,
165 			      u32 inst)
166 {
167 	reinit_completion(&ecc_dev->irq_done);
168 
169 	iowrite32(ECC_ENABLE_INTR, ecc_dev->base_reg + HW_OFFS_OCS_ECC_IER);
170 	iowrite32(op_size | inst, ecc_dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
171 
172 	return wait_for_completion_interruptible(&ecc_dev->irq_done);
173 }
174 
175 /**
176  * ocs_ecc_read_cx_out() - Read the CX data output buffer.
177  * @dev:	The OCS ECC device to read from.
178  * @cx_out:	The buffer where to store the CX value. Must be at least
179  *		@byte_count byte long.
180  * @byte_count:	The amount of data to read.
181  */
182 static inline void ocs_ecc_read_cx_out(struct ocs_ecc_dev *dev, void *cx_out,
183 				       size_t byte_count)
184 {
185 	memcpy_fromio(cx_out, dev->base_reg + HW_OFFS_OCS_ECC_CX_DATA_OUT,
186 		      byte_count);
187 }
188 
189 /**
190  * ocs_ecc_read_cy_out() - Read the CX data output buffer.
191  * @dev:	The OCS ECC device to read from.
192  * @cy_out:	The buffer where to store the CY value. Must be at least
193  *		@byte_count byte long.
194  * @byte_count:	The amount of data to read.
195  */
196 static inline void ocs_ecc_read_cy_out(struct ocs_ecc_dev *dev, void *cy_out,
197 				       size_t byte_count)
198 {
199 	memcpy_fromio(cy_out, dev->base_reg + HW_OFFS_OCS_ECC_CY_DATA_OUT,
200 		      byte_count);
201 }
202 
203 static struct ocs_ecc_dev *kmb_ocs_ecc_find_dev(struct ocs_ecc_ctx *tctx)
204 {
205 	if (tctx->ecc_dev)
206 		return tctx->ecc_dev;
207 
208 	spin_lock(&ocs_ecc.lock);
209 
210 	/* Only a single OCS device available. */
211 	tctx->ecc_dev = list_first_entry(&ocs_ecc.dev_list, struct ocs_ecc_dev,
212 					 list);
213 
214 	spin_unlock(&ocs_ecc.lock);
215 
216 	return tctx->ecc_dev;
217 }
218 
219 /* Do point multiplication using OCS ECC HW. */
220 static int kmb_ecc_point_mult(struct ocs_ecc_dev *ecc_dev,
221 			      struct ecc_point *result,
222 			      const struct ecc_point *point,
223 			      u64 *scalar,
224 			      const struct ecc_curve *curve)
225 {
226 	u8 sca[KMB_ECC_VLI_MAX_BYTES]; /* Use the maximum data size. */
227 	u32 op_size = (curve->g.ndigits > ECC_CURVE_NIST_P256_DIGITS) ?
228 		      OCS_ECC_OP_SIZE_384 : OCS_ECC_OP_SIZE_256;
229 	size_t nbytes = digits_to_bytes(curve->g.ndigits);
230 	int rc = 0;
231 
232 	/* Generate random nbytes for Simple and Differential SCA protection. */
233 	rc = crypto_stdrng_get_bytes(sca, nbytes);
234 	if (rc)
235 		return rc;
236 
237 	/* Wait engine to be idle before starting new operation. */
238 	rc = ocs_ecc_wait_idle(ecc_dev);
239 	if (rc)
240 		return rc;
241 
242 	/* Send ecc_start pulse as well as indicating operation size. */
243 	ocs_ecc_cmd_start(ecc_dev, op_size);
244 
245 	/* Write ax param; Base point (Gx). */
246 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AX,
247 				   point->x, nbytes);
248 
249 	/* Write ay param; Base point (Gy). */
250 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AY,
251 				   point->y, nbytes);
252 
253 	/*
254 	 * Write the private key into DATA_IN reg.
255 	 *
256 	 * Since DATA_IN register is used to write different values during the
257 	 * computation private Key value is overwritten with
258 	 * side-channel-resistance value.
259 	 */
260 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_BX_D,
261 				   scalar, nbytes);
262 
263 	/* Write operand by/l. */
264 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_BY_L,
265 				   sca, nbytes);
266 	memzero_explicit(sca, sizeof(sca));
267 
268 	/* Write p = curve prime(GF modulus). */
269 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_P,
270 				   curve->p, nbytes);
271 
272 	/* Write a = curve coefficient. */
273 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_A,
274 				   curve->a, nbytes);
275 
276 	/* Make hardware perform the multiplication. */
277 	rc = ocs_ecc_trigger_op(ecc_dev, op_size, OCS_ECC_INST_CALC_D_IDX_A);
278 	if (rc)
279 		return rc;
280 
281 	/* Read result. */
282 	ocs_ecc_read_cx_out(ecc_dev, result->x, nbytes);
283 	ocs_ecc_read_cy_out(ecc_dev, result->y, nbytes);
284 
285 	return 0;
286 }
287 
288 /**
289  * kmb_ecc_do_scalar_op() - Perform Scalar operation using OCS ECC HW.
290  * @ecc_dev:	The OCS ECC device to use.
291  * @scalar_out:	Where to store the output scalar.
292  * @scalar_a:	Input scalar operand 'a'.
293  * @scalar_b:	Input scalar operand 'b'
294  * @curve:	The curve on which the operation is performed.
295  * @ndigits:	The size of the operands (in digits).
296  * @inst:	The operation to perform (as an OCS ECC instruction).
297  *
298  * Return:	0 on success, negative error code otherwise.
299  */
300 static int kmb_ecc_do_scalar_op(struct ocs_ecc_dev *ecc_dev, u64 *scalar_out,
301 				const u64 *scalar_a, const u64 *scalar_b,
302 				const struct ecc_curve *curve,
303 				unsigned int ndigits, const u32 inst)
304 {
305 	u32 op_size = (ndigits > ECC_CURVE_NIST_P256_DIGITS) ?
306 		      OCS_ECC_OP_SIZE_384 : OCS_ECC_OP_SIZE_256;
307 	size_t nbytes = digits_to_bytes(ndigits);
308 	int rc;
309 
310 	/* Wait engine to be idle before starting new operation. */
311 	rc = ocs_ecc_wait_idle(ecc_dev);
312 	if (rc)
313 		return rc;
314 
315 	/* Send ecc_start pulse as well as indicating operation size. */
316 	ocs_ecc_cmd_start(ecc_dev, op_size);
317 
318 	/* Write ax param (Base point (Gx).*/
319 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AX,
320 				   scalar_a, nbytes);
321 
322 	/* Write ay param Base point (Gy).*/
323 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AY,
324 				   scalar_b, nbytes);
325 
326 	/* Write p = curve prime(GF modulus).*/
327 	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_P,
328 				   curve->p, nbytes);
329 
330 	/* Give instruction A.B or A+B to ECC engine. */
331 	rc = ocs_ecc_trigger_op(ecc_dev, op_size, inst);
332 	if (rc)
333 		return rc;
334 
335 	ocs_ecc_read_cx_out(ecc_dev, scalar_out, nbytes);
336 
337 	if (vli_is_zero(scalar_out, ndigits))
338 		return -EINVAL;
339 
340 	return 0;
341 }
342 
343 /* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */
344 static int kmb_ocs_ecc_is_pubkey_valid_partial(struct ocs_ecc_dev *ecc_dev,
345 					       const struct ecc_curve *curve,
346 					       struct ecc_point *pk)
347 {
348 	u64 xxx[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
349 	u64 yy[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
350 	u64 w[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
351 	int rc;
352 
353 	if (WARN_ON(pk->ndigits != curve->g.ndigits))
354 		return -EINVAL;
355 
356 	/* Check 1: Verify key is not the zero point. */
357 	if (ecc_point_is_zero(pk))
358 		return -EINVAL;
359 
360 	/* Check 2: Verify key is in the range [0, p-1]. */
361 	if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1)
362 		return -EINVAL;
363 
364 	if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1)
365 		return -EINVAL;
366 
367 	/* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */
368 
369 	 /* y^2 */
370 	/* Compute y^2 -> store in yy */
371 	rc = kmb_ecc_do_scalar_op(ecc_dev, yy, pk->y, pk->y, curve, pk->ndigits,
372 				  OCS_ECC_INST_CALC_A_MUL_B_MODP);
373 	if (rc)
374 		goto exit;
375 
376 	/* x^3 */
377 	/* Assigning w = 3, used for calculating x^3. */
378 	w[0] = POW_CUBE;
379 	/* Load the next stage.*/
380 	rc = kmb_ecc_do_scalar_op(ecc_dev, xxx, pk->x, w, curve, pk->ndigits,
381 				  OCS_ECC_INST_CALC_A_POW_B_MODP);
382 	if (rc)
383 		goto exit;
384 
385 	/* Do a*x -> store in w. */
386 	rc = kmb_ecc_do_scalar_op(ecc_dev, w, curve->a, pk->x, curve,
387 				  pk->ndigits,
388 				  OCS_ECC_INST_CALC_A_MUL_B_MODP);
389 	if (rc)
390 		goto exit;
391 
392 	/* Do ax + b == w + b; store in w. */
393 	rc = kmb_ecc_do_scalar_op(ecc_dev, w, w, curve->b, curve,
394 				  pk->ndigits,
395 				  OCS_ECC_INST_CALC_A_ADD_B_MODP);
396 	if (rc)
397 		goto exit;
398 
399 	/* x^3 + ax + b == x^3 + w -> store in w. */
400 	rc = kmb_ecc_do_scalar_op(ecc_dev, w, xxx, w, curve, pk->ndigits,
401 				  OCS_ECC_INST_CALC_A_ADD_B_MODP);
402 	if (rc)
403 		goto exit;
404 
405 	/* Compare y^2 == x^3 + a·x + b. */
406 	rc = vli_cmp(yy, w, pk->ndigits);
407 	if (rc)
408 		rc = -EINVAL;
409 
410 exit:
411 	memzero_explicit(xxx, sizeof(xxx));
412 	memzero_explicit(yy, sizeof(yy));
413 	memzero_explicit(w, sizeof(w));
414 
415 	return rc;
416 }
417 
418 /* SP800-56A section 5.6.2.3.3 full verification */
419 static int kmb_ocs_ecc_is_pubkey_valid_full(struct ocs_ecc_dev *ecc_dev,
420 					    const struct ecc_curve *curve,
421 					    struct ecc_point *pk)
422 {
423 	struct ecc_point *nQ;
424 	int rc;
425 
426 	/* Checks 1 through 3 */
427 	rc = kmb_ocs_ecc_is_pubkey_valid_partial(ecc_dev, curve, pk);
428 	if (rc)
429 		return rc;
430 
431 	/* Check 4: Verify that nQ is the zero point. */
432 	nQ = ecc_alloc_point(pk->ndigits);
433 	if (!nQ)
434 		return -ENOMEM;
435 
436 	rc = kmb_ecc_point_mult(ecc_dev, nQ, pk, curve->n, curve);
437 	if (rc)
438 		goto exit;
439 
440 	if (!ecc_point_is_zero(nQ))
441 		rc = -EINVAL;
442 
443 exit:
444 	ecc_free_point(nQ);
445 
446 	return rc;
447 }
448 
449 static int kmb_ecc_is_key_valid(const struct ecc_curve *curve,
450 				const u64 *private_key, size_t private_key_len)
451 {
452 	size_t ndigits = curve->g.ndigits;
453 	u64 one[KMB_ECC_VLI_MAX_DIGITS] = {1};
454 	u64 res[KMB_ECC_VLI_MAX_DIGITS];
455 
456 	if (private_key_len != digits_to_bytes(ndigits))
457 		return -EINVAL;
458 
459 	if (!private_key)
460 		return -EINVAL;
461 
462 	/* Make sure the private key is in the range [2, n-3]. */
463 	if (vli_cmp(one, private_key, ndigits) != -1)
464 		return -EINVAL;
465 
466 	vli_sub(res, curve->n, one, ndigits);
467 	vli_sub(res, res, one, ndigits);
468 	if (vli_cmp(res, private_key, ndigits) != 1)
469 		return -EINVAL;
470 
471 	return 0;
472 }
473 
474 /*
475  * ECC private keys are generated using the method of extra random bits,
476  * equivalent to that described in FIPS 186-4, Appendix B.4.1.
477  *
478  * d = (c mod(n–1)) + 1    where c is a string of random bits, 64 bits longer
479  *                         than requested
480  * 0 <= c mod(n-1) <= n-2  and implies that
481  * 1 <= d <= n-1
482  *
483  * This method generates a private key uniformly distributed in the range
484  * [1, n-1].
485  */
486 static int kmb_ecc_gen_privkey(const struct ecc_curve *curve, u64 *privkey)
487 {
488 	size_t nbytes = digits_to_bytes(curve->g.ndigits);
489 	u64 priv[KMB_ECC_VLI_MAX_DIGITS];
490 	size_t nbits;
491 	int rc;
492 
493 	nbits = vli_num_bits(curve->n, curve->g.ndigits);
494 
495 	/* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */
496 	if (nbits < 160 || curve->g.ndigits > ARRAY_SIZE(priv))
497 		return -EINVAL;
498 
499 	/*
500 	 * FIPS 186-4 recommends that the private key should be obtained from a
501 	 * RBG with a security strength equal to or greater than the security
502 	 * strength associated with N.
503 	 *
504 	 * The maximum security strength identified by NIST SP800-57pt1r4 for
505 	 * ECC is 256 (N >= 512).
506 	 *
507 	 * This condition is met by stdrng because it selects a favored DRBG
508 	 * with a security strength of 256.
509 	 */
510 	rc = crypto_stdrng_get_bytes(priv, nbytes);
511 	if (rc)
512 		goto cleanup;
513 
514 	rc = kmb_ecc_is_key_valid(curve, priv, nbytes);
515 	if (rc)
516 		goto cleanup;
517 
518 	ecc_swap_digits(priv, privkey, curve->g.ndigits);
519 
520 cleanup:
521 	memzero_explicit(&priv, sizeof(priv));
522 
523 	return rc;
524 }
525 
526 static int kmb_ocs_ecdh_set_secret(struct crypto_kpp *tfm, const void *buf,
527 				   unsigned int len)
528 {
529 	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
530 	struct ecdh params;
531 	int rc = 0;
532 
533 	rc = crypto_ecdh_decode_key(buf, len, &params);
534 	if (rc)
535 		goto cleanup;
536 
537 	/* Ensure key size is not bigger then expected. */
538 	if (params.key_size > digits_to_bytes(tctx->curve->g.ndigits)) {
539 		rc = -EINVAL;
540 		goto cleanup;
541 	}
542 
543 	/* Auto-generate private key is not provided. */
544 	if (!params.key || !params.key_size) {
545 		rc = kmb_ecc_gen_privkey(tctx->curve, tctx->private_key);
546 		goto cleanup;
547 	}
548 
549 	rc = kmb_ecc_is_key_valid(tctx->curve, (const u64 *)params.key,
550 				  params.key_size);
551 	if (rc)
552 		goto cleanup;
553 
554 	ecc_swap_digits((const u64 *)params.key, tctx->private_key,
555 			tctx->curve->g.ndigits);
556 cleanup:
557 	memzero_explicit(&params, sizeof(params));
558 
559 	if (rc)
560 		tctx->curve = NULL;
561 
562 	return rc;
563 }
564 
565 /* Compute shared secret. */
566 static int kmb_ecc_do_shared_secret(struct ocs_ecc_ctx *tctx,
567 				    struct kpp_request *req)
568 {
569 	struct ocs_ecc_dev *ecc_dev = tctx->ecc_dev;
570 	const struct ecc_curve *curve = tctx->curve;
571 	u64 shared_secret[KMB_ECC_VLI_MAX_DIGITS];
572 	u64 pubk_buf[KMB_ECC_VLI_MAX_DIGITS * 2];
573 	size_t copied, nbytes, pubk_len;
574 	struct ecc_point *pk, *result;
575 	int rc;
576 
577 	nbytes = digits_to_bytes(curve->g.ndigits);
578 
579 	/* Public key is a point, thus it has two coordinates */
580 	pubk_len = 2 * nbytes;
581 
582 	/* Copy public key from SG list to pubk_buf. */
583 	copied = sg_copy_to_buffer(req->src,
584 				   sg_nents_for_len(req->src, pubk_len),
585 				   pubk_buf, pubk_len);
586 	if (copied != pubk_len)
587 		return -EINVAL;
588 
589 	/* Allocate and initialize public key point. */
590 	pk = ecc_alloc_point(curve->g.ndigits);
591 	if (!pk)
592 		return -ENOMEM;
593 
594 	ecc_swap_digits(pubk_buf, pk->x, curve->g.ndigits);
595 	ecc_swap_digits(&pubk_buf[curve->g.ndigits], pk->y, curve->g.ndigits);
596 
597 	/*
598 	 * Check the public key for following
599 	 * Check 1: Verify key is not the zero point.
600 	 * Check 2: Verify key is in the range [1, p-1].
601 	 * Check 3: Verify that y^2 == (x^3 + a·x + b) mod p
602 	 */
603 	rc = kmb_ocs_ecc_is_pubkey_valid_partial(ecc_dev, curve, pk);
604 	if (rc)
605 		goto exit_free_pk;
606 
607 	/* Allocate point for storing computed shared secret. */
608 	result = ecc_alloc_point(pk->ndigits);
609 	if (!result) {
610 		rc = -ENOMEM;
611 		goto exit_free_pk;
612 	}
613 
614 	/* Calculate the shared secret.*/
615 	rc = kmb_ecc_point_mult(ecc_dev, result, pk, tctx->private_key, curve);
616 	if (rc)
617 		goto exit_free_result;
618 
619 	if (ecc_point_is_zero(result)) {
620 		rc = -EFAULT;
621 		goto exit_free_result;
622 	}
623 
624 	/* Copy shared secret from point to buffer. */
625 	ecc_swap_digits(result->x, shared_secret, result->ndigits);
626 
627 	/* Request might ask for less bytes than what we have. */
628 	nbytes = min_t(size_t, nbytes, req->dst_len);
629 
630 	copied = sg_copy_from_buffer(req->dst,
631 				     sg_nents_for_len(req->dst, nbytes),
632 				     shared_secret, nbytes);
633 
634 	if (copied != nbytes)
635 		rc = -EINVAL;
636 
637 	memzero_explicit(shared_secret, sizeof(shared_secret));
638 
639 exit_free_result:
640 	ecc_free_point(result);
641 
642 exit_free_pk:
643 	ecc_free_point(pk);
644 
645 	return rc;
646 }
647 
648 /* Compute public key. */
649 static int kmb_ecc_do_public_key(struct ocs_ecc_ctx *tctx,
650 				 struct kpp_request *req)
651 {
652 	const struct ecc_curve *curve = tctx->curve;
653 	u64 pubk_buf[KMB_ECC_VLI_MAX_DIGITS * 2];
654 	struct ecc_point *pk;
655 	size_t pubk_len;
656 	size_t copied;
657 	int rc;
658 
659 	/* Public key is a point, so it has double the digits. */
660 	pubk_len = 2 * digits_to_bytes(curve->g.ndigits);
661 
662 	pk = ecc_alloc_point(curve->g.ndigits);
663 	if (!pk)
664 		return -ENOMEM;
665 
666 	/* Public Key(pk) = priv * G. */
667 	rc = kmb_ecc_point_mult(tctx->ecc_dev, pk, &curve->g, tctx->private_key,
668 				curve);
669 	if (rc)
670 		goto exit;
671 
672 	/* SP800-56A rev 3 5.6.2.1.3 key check */
673 	if (kmb_ocs_ecc_is_pubkey_valid_full(tctx->ecc_dev, curve, pk)) {
674 		rc = -EAGAIN;
675 		goto exit;
676 	}
677 
678 	/* Copy public key from point to buffer. */
679 	ecc_swap_digits(pk->x, pubk_buf, pk->ndigits);
680 	ecc_swap_digits(pk->y, &pubk_buf[pk->ndigits], pk->ndigits);
681 
682 	/* Copy public key to req->dst. */
683 	copied = sg_copy_from_buffer(req->dst,
684 				     sg_nents_for_len(req->dst, pubk_len),
685 				     pubk_buf, pubk_len);
686 
687 	if (copied != pubk_len)
688 		rc = -EINVAL;
689 
690 exit:
691 	ecc_free_point(pk);
692 
693 	return rc;
694 }
695 
696 static int kmb_ocs_ecc_do_one_request(struct crypto_engine *engine,
697 				      void *areq)
698 {
699 	struct kpp_request *req = container_of(areq, struct kpp_request, base);
700 	struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
701 	struct ocs_ecc_dev *ecc_dev = tctx->ecc_dev;
702 	int rc;
703 
704 	if (req->src)
705 		rc = kmb_ecc_do_shared_secret(tctx, req);
706 	else
707 		rc = kmb_ecc_do_public_key(tctx, req);
708 
709 	crypto_finalize_kpp_request(ecc_dev->engine, req, rc);
710 
711 	return 0;
712 }
713 
714 static int kmb_ocs_ecdh_generate_public_key(struct kpp_request *req)
715 {
716 	struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
717 	const struct ecc_curve *curve = tctx->curve;
718 
719 	/* Ensure kmb_ocs_ecdh_set_secret() has been successfully called. */
720 	if (!tctx->curve)
721 		return -EINVAL;
722 
723 	/* Ensure dst is present. */
724 	if (!req->dst)
725 		return -EINVAL;
726 
727 	/* Check the request dst is big enough to hold the public key. */
728 	if (req->dst_len < (2 * digits_to_bytes(curve->g.ndigits)))
729 		return -EINVAL;
730 
731 	/* 'src' is not supposed to be present when generate pubk is called. */
732 	if (req->src)
733 		return -EINVAL;
734 
735 	return crypto_transfer_kpp_request_to_engine(tctx->ecc_dev->engine,
736 						     req);
737 }
738 
739 static int kmb_ocs_ecdh_compute_shared_secret(struct kpp_request *req)
740 {
741 	struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
742 	const struct ecc_curve *curve = tctx->curve;
743 
744 	/* Ensure kmb_ocs_ecdh_set_secret() has been successfully called. */
745 	if (!tctx->curve)
746 		return -EINVAL;
747 
748 	/* Ensure dst is present. */
749 	if (!req->dst)
750 		return -EINVAL;
751 
752 	/* Ensure src is present. */
753 	if (!req->src)
754 		return -EINVAL;
755 
756 	/*
757 	 * req->src is expected to the (other-side) public key, so its length
758 	 * must be 2 * coordinate size (in bytes).
759 	 */
760 	if (req->src_len != 2 * digits_to_bytes(curve->g.ndigits))
761 		return -EINVAL;
762 
763 	return crypto_transfer_kpp_request_to_engine(tctx->ecc_dev->engine,
764 						     req);
765 }
766 
767 static int kmb_ecc_tctx_init(struct ocs_ecc_ctx *tctx, unsigned int curve_id)
768 {
769 	memset(tctx, 0, sizeof(*tctx));
770 
771 	tctx->ecc_dev = kmb_ocs_ecc_find_dev(tctx);
772 
773 	if (IS_ERR(tctx->ecc_dev)) {
774 		pr_err("Failed to find the device : %ld\n",
775 		       PTR_ERR(tctx->ecc_dev));
776 		return PTR_ERR(tctx->ecc_dev);
777 	}
778 
779 	tctx->curve = ecc_get_curve(curve_id);
780 	if (!tctx->curve)
781 		return -EOPNOTSUPP;
782 
783 	return 0;
784 }
785 
786 static int kmb_ocs_ecdh_nist_p256_init_tfm(struct crypto_kpp *tfm)
787 {
788 	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
789 
790 	return kmb_ecc_tctx_init(tctx, ECC_CURVE_NIST_P256);
791 }
792 
793 static int kmb_ocs_ecdh_nist_p384_init_tfm(struct crypto_kpp *tfm)
794 {
795 	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
796 
797 	return kmb_ecc_tctx_init(tctx, ECC_CURVE_NIST_P384);
798 }
799 
800 static void kmb_ocs_ecdh_exit_tfm(struct crypto_kpp *tfm)
801 {
802 	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
803 
804 	memzero_explicit(tctx->private_key, sizeof(*tctx->private_key));
805 }
806 
807 static unsigned int kmb_ocs_ecdh_max_size(struct crypto_kpp *tfm)
808 {
809 	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
810 
811 	/* Public key is made of two coordinates, so double the digits. */
812 	return digits_to_bytes(tctx->curve->g.ndigits) * 2;
813 }
814 
815 static struct kpp_engine_alg ocs_ecdh_p256 = {
816 	.base.set_secret = kmb_ocs_ecdh_set_secret,
817 	.base.generate_public_key = kmb_ocs_ecdh_generate_public_key,
818 	.base.compute_shared_secret = kmb_ocs_ecdh_compute_shared_secret,
819 	.base.init = kmb_ocs_ecdh_nist_p256_init_tfm,
820 	.base.exit = kmb_ocs_ecdh_exit_tfm,
821 	.base.max_size = kmb_ocs_ecdh_max_size,
822 	.base.base = {
823 		.cra_name = "ecdh-nist-p256",
824 		.cra_driver_name = "ecdh-nist-p256-keembay-ocs",
825 		.cra_priority = KMB_OCS_ECC_PRIORITY,
826 		.cra_module = THIS_MODULE,
827 		.cra_ctxsize = sizeof(struct ocs_ecc_ctx),
828 	},
829 	.op.do_one_request = kmb_ocs_ecc_do_one_request,
830 };
831 
832 static struct kpp_engine_alg ocs_ecdh_p384 = {
833 	.base.set_secret = kmb_ocs_ecdh_set_secret,
834 	.base.generate_public_key = kmb_ocs_ecdh_generate_public_key,
835 	.base.compute_shared_secret = kmb_ocs_ecdh_compute_shared_secret,
836 	.base.init = kmb_ocs_ecdh_nist_p384_init_tfm,
837 	.base.exit = kmb_ocs_ecdh_exit_tfm,
838 	.base.max_size = kmb_ocs_ecdh_max_size,
839 	.base.base = {
840 		.cra_name = "ecdh-nist-p384",
841 		.cra_driver_name = "ecdh-nist-p384-keembay-ocs",
842 		.cra_priority = KMB_OCS_ECC_PRIORITY,
843 		.cra_module = THIS_MODULE,
844 		.cra_ctxsize = sizeof(struct ocs_ecc_ctx),
845 	},
846 	.op.do_one_request = kmb_ocs_ecc_do_one_request,
847 };
848 
849 static irqreturn_t ocs_ecc_irq_handler(int irq, void *dev_id)
850 {
851 	struct ocs_ecc_dev *ecc_dev = dev_id;
852 	u32 status;
853 
854 	/*
855 	 * Read the status register and write it back to clear the
856 	 * DONE_INT_STATUS bit.
857 	 */
858 	status = ioread32(ecc_dev->base_reg + HW_OFFS_OCS_ECC_ISR);
859 	iowrite32(status, ecc_dev->base_reg + HW_OFFS_OCS_ECC_ISR);
860 
861 	if (!(status & HW_OCS_ECC_ISR_INT_STATUS_DONE))
862 		return IRQ_NONE;
863 
864 	complete(&ecc_dev->irq_done);
865 
866 	return IRQ_HANDLED;
867 }
868 
869 static int kmb_ocs_ecc_probe(struct platform_device *pdev)
870 {
871 	struct device *dev = &pdev->dev;
872 	struct ocs_ecc_dev *ecc_dev;
873 	int rc;
874 
875 	ecc_dev = devm_kzalloc(dev, sizeof(*ecc_dev), GFP_KERNEL);
876 	if (!ecc_dev)
877 		return -ENOMEM;
878 
879 	ecc_dev->dev = dev;
880 
881 	platform_set_drvdata(pdev, ecc_dev);
882 
883 	INIT_LIST_HEAD(&ecc_dev->list);
884 	init_completion(&ecc_dev->irq_done);
885 
886 	/* Get base register address. */
887 	ecc_dev->base_reg = devm_platform_ioremap_resource(pdev, 0);
888 	if (IS_ERR(ecc_dev->base_reg)) {
889 		dev_err(dev, "Failed to get base address\n");
890 		rc = PTR_ERR(ecc_dev->base_reg);
891 		goto list_del;
892 	}
893 
894 	/* Get and request IRQ */
895 	ecc_dev->irq = platform_get_irq(pdev, 0);
896 	if (ecc_dev->irq < 0) {
897 		rc = ecc_dev->irq;
898 		goto list_del;
899 	}
900 
901 	rc = devm_request_threaded_irq(dev, ecc_dev->irq, ocs_ecc_irq_handler,
902 				       NULL, 0, "keembay-ocs-ecc", ecc_dev);
903 	if (rc < 0) {
904 		dev_err(dev, "Could not request IRQ\n");
905 		goto list_del;
906 	}
907 
908 	/* Add device to the list of OCS ECC devices. */
909 	spin_lock(&ocs_ecc.lock);
910 	list_add_tail(&ecc_dev->list, &ocs_ecc.dev_list);
911 	spin_unlock(&ocs_ecc.lock);
912 
913 	/* Initialize crypto engine. */
914 	ecc_dev->engine = crypto_engine_alloc_init(dev, 1);
915 	if (!ecc_dev->engine) {
916 		dev_err(dev, "Could not allocate crypto engine\n");
917 		rc = -ENOMEM;
918 		goto list_del;
919 	}
920 
921 	rc = crypto_engine_start(ecc_dev->engine);
922 	if (rc) {
923 		dev_err(dev, "Could not start crypto engine\n");
924 		goto cleanup;
925 	}
926 
927 	/* Register the KPP algo. */
928 	rc = crypto_engine_register_kpp(&ocs_ecdh_p256);
929 	if (rc) {
930 		dev_err(dev,
931 			"Could not register OCS algorithms with Crypto API\n");
932 		goto cleanup;
933 	}
934 
935 	rc = crypto_engine_register_kpp(&ocs_ecdh_p384);
936 	if (rc) {
937 		dev_err(dev,
938 			"Could not register OCS algorithms with Crypto API\n");
939 		goto ocs_ecdh_p384_error;
940 	}
941 
942 	return 0;
943 
944 ocs_ecdh_p384_error:
945 	crypto_engine_unregister_kpp(&ocs_ecdh_p256);
946 
947 cleanup:
948 	crypto_engine_exit(ecc_dev->engine);
949 
950 list_del:
951 	spin_lock(&ocs_ecc.lock);
952 	list_del(&ecc_dev->list);
953 	spin_unlock(&ocs_ecc.lock);
954 
955 	return rc;
956 }
957 
958 static void kmb_ocs_ecc_remove(struct platform_device *pdev)
959 {
960 	struct ocs_ecc_dev *ecc_dev;
961 
962 	ecc_dev = platform_get_drvdata(pdev);
963 
964 	crypto_engine_unregister_kpp(&ocs_ecdh_p384);
965 	crypto_engine_unregister_kpp(&ocs_ecdh_p256);
966 
967 	spin_lock(&ocs_ecc.lock);
968 	list_del(&ecc_dev->list);
969 	spin_unlock(&ocs_ecc.lock);
970 
971 	crypto_engine_exit(ecc_dev->engine);
972 }
973 
974 /* Device tree driver match. */
975 static const struct of_device_id kmb_ocs_ecc_of_match[] = {
976 	{
977 		.compatible = "intel,keembay-ocs-ecc",
978 	},
979 	{}
980 };
981 
982 /* The OCS driver is a platform device. */
983 static struct platform_driver kmb_ocs_ecc_driver = {
984 	.probe = kmb_ocs_ecc_probe,
985 	.remove = kmb_ocs_ecc_remove,
986 	.driver = {
987 			.name = DRV_NAME,
988 			.of_match_table = kmb_ocs_ecc_of_match,
989 		},
990 };
991 module_platform_driver(kmb_ocs_ecc_driver);
992 
993 MODULE_LICENSE("GPL");
994 MODULE_DESCRIPTION("Intel Keem Bay OCS ECC Driver");
995 MODULE_ALIAS_CRYPTO("ecdh-nist-p256");
996 MODULE_ALIAS_CRYPTO("ecdh-nist-p384");
997 MODULE_ALIAS_CRYPTO("ecdh-nist-p256-keembay-ocs");
998 MODULE_ALIAS_CRYPTO("ecdh-nist-p384-keembay-ocs");
999