xref: /freebsd/crypto/libecc/src/examples/sig/rsa/rsa.c (revision f0865ec9906d5a18fa2a3b61381f22ce16e606ad)
1 /*
2  *  Copyright (C) 2021 - This file is part of libecc project
3  *
4  *  Authors:
5  *      Ryad BENADJILA <ryadbenadjila@gmail.com>
6  *      Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr>
7  *
8  *  This software is licensed under a dual BSD and GPL v2 license.
9  *  See LICENSE file at the root folder of the project.
10  */
11 #include "rsa.h"
12 #include "rsa_tests.h"
13 
14 
15 /* We include the rand external dependency because we have to generate
16  * some random data for the padding.
17  */
18 #include <libecc/external_deps/rand.h>
19 /* We include the printf external dependency for printf output */
20 #include <libecc/external_deps/print.h>
21 /* We include our common helpers */
22 #include "../common/common.h"
23 
24 
25 /*
26  * The purpose of this example is to implement the RSA
27  * related algorithms as per RFC 8017 and ISO/IEC 9796-2 based
28  * on libecc arithmetic primitives.
29  *
30  * XXX: Please be aware that libecc has been designed for Elliptic
31  * Curve cryptography, and as such the arithmetic primitives are
32  * not optimized for big numbers >= 1024 bits usually used for RSA.
33  * Additionally, a hard limit of our NN values makes it impossible
34  * to exceed ~5300 bits in the best case (words of size 64 bits).
35  *
36  * All in all, please see this as a proof of concept of implementing
37  * RFC 8017 rather than a production code. Use it at your own risk!
38  *
39  * !! DISCLAIMER !!
40  * ================
41  * Although some efforts have been put on providing a clean code and although many of
42  * the underlying arithmetic primitives are constant time, only basic efforts have
43  * been deployed to prevent advanced side channels (e.g. to protect the private values
44  * against elaborate microarchitectural side-channels and so on). The modular exponentiation
45  * uses a Montgomery Ladder, and message blinding is performed to mitigate basic SCA.
46  * Please note that the modular exponentiation is NOT constant time wrt the MSB of
47  * the private exponent, which should be OK in the general case as this leak is less
48  * critical than for DSA and ECDSA nonces in scalar multiplication (raising HNP issues
49  * in these last cases).
50  * Optionally, when BLINDING=1 is activated, exponent blinding is used by adding a
51  * "small" (128 bits) multiple of the "order" (this is left as optional because of
52  * the big impacts on performance), somehow limiting the modular exponentiation MSB
53  * issue at the expense of performance.
54  *
55  * Padding oracles (Bleichenbacher, Manger) in RSA PKCS#1 v1.5 and RSA-OAEP decryption
56  * primitives are taken into account, although no absolute guarantee can be made on this
57  * (and mostly: these oracles also heavily depend on what the upper layer callers do).
58  *
59  * Fault injection attacks "a la Bellcore" are protected using a sanity check that
60  * the exponentiation to the public exponent provides the same input as the operation
61  * using the private exponent.
62  *
63  * !!NOTE: only the *_hardened* suffixed APIs are protected, the non suffixed ones are
64  * *NOT* protected. This is mainly due to the fact that the protections use the public
65  * key while the RFC APIs handling private operations only take the private key as
66  * input. Hence, please *USE* the *_hardened* APIs if unsure about your usage context!
67  *
68  * Also, as for all other libecc primitives, beware of randomness sources. By default,
69  * the library uses the OS random sources (e.g. "/dev/urandom"), but the user
70  * is encouraged to adapt the ../external_deps/rand.c source file to combine
71  * multiple sources and add entropy there depending on the context where this
72  * code is integrated. The security level of all the cryptographic primitives
73  * heavily relies on random sources quality.
74  *
75  * All-in-all, this piece of code can be useful in some contexts, or risky to
76  * use in other sensitive ones where advanced side-channels or fault attacks
77  * have to be considered. Use this RSA code knowingly and at your own risk!
78  *
79  */
80 
/*
 * Fill the RSA public key 'pub' from two byte buffers: the modulus 'n'
 * ('nlen' bytes) and the public exponent 'e' ('elen' bytes).
 *
 * Returns 0 on success and a non-zero value on error. On error, the
 * content of 'pub' is wiped so that no partially imported key is left.
 */
int rsa_import_pub_key(rsa_pub_key *pub, const u8 *n,
                       u16 nlen, const u8 *e, u16 elen)
{
	int ret;

	MUST_HAVE((pub != NULL), ret, err);

	/* Modulus first, then the exponent only if the modulus went through */
	ret = nn_init_from_buf(&(pub->n), n, nlen);
	if(!ret){
		ret = nn_init_from_buf(&(pub->e), e, elen);
	}

err:
	/* Do not leave a half imported key behind us */
	if((pub != NULL) && ret){
		IGNORE_RET_VAL(local_memset(pub, 0, sizeof(rsa_pub_key)));
	}

	return ret;
}
99 
/*
 * Fill the RSA private key 'priv' in its non-CRT form from byte buffers.
 *
 * The modulus 'n' and the private exponent 'd' are always imported. The
 * primes 'p' and 'q' are optional, but they must be either both provided
 * or both NULL:
 *   - both NULL: the key is tagged RSA_SIMPLE and only (n, d) are stored,
 *   - both set:  the key is tagged RSA_SIMPLE_PQ and (n, d, p, q) are stored.
 *
 * Returns 0 on success and a non-zero value on error. On error, the
 * content of 'priv' is wiped.
 */
int rsa_import_simple_priv_key(rsa_priv_key *priv,
                               const u8 *n, u16 nlen, const u8 *d, u16 dlen,
			       const u8 *p, u16 plen, const u8 *q, u16 qlen)
{
	int ret;

	MUST_HAVE((priv != NULL), ret, err);

	/* Refuse the case where only one of the two primes is given */
	MUST_HAVE(((p != NULL) && (q != NULL)) || ((p == NULL) && (q == NULL)), ret, err);

	if((p != NULL) && (q != NULL)){
		/* Full flavour: modulus, private exponent and both primes */
		priv->type = RSA_SIMPLE_PQ;
		ret = nn_init_from_buf(&(priv->key.s_pq.n), n, nlen); EG(ret, err);
		ret = nn_init_from_buf(&(priv->key.s_pq.d), d, dlen); EG(ret, err);
		ret = nn_init_from_buf(&(priv->key.s_pq.p), p, plen); EG(ret, err);
		ret = nn_init_from_buf(&(priv->key.s_pq.q), q, qlen); EG(ret, err);
	}
	else{
		/* Minimal flavour: modulus and private exponent only */
		priv->type = RSA_SIMPLE;
		ret = nn_init_from_buf(&(priv->key.s.n), n, nlen); EG(ret, err);
		ret = nn_init_from_buf(&(priv->key.s.d), d, dlen); EG(ret, err);
	}

err:
	/* Wipe everything when the import did not fully succeed */
	if((priv != NULL) && ret){
		IGNORE_RET_VAL(local_memset(priv, 0, sizeof(rsa_priv_key)));
	}

	return ret;
}
131 
/*
 * Fill the RSA private key 'priv' in its CRT form from byte buffers.
 *
 * The five mandatory values p, q, dP, dQ and qInv are always imported and
 * the key is tagged RSA_CRT. When 'coeffs' is not NULL, 'u' additional
 * (r, d, t) triplets are imported: triplet number k is read from
 * coeffs[3k], coeffs[3k + 1] and coeffs[3k + 2], with the matching lengths
 * taken from 'coeffslens'. In that case 'coeffslens' must not be NULL and
 * 'u' must satisfy 0 < u < MAX_CRT_COEFFS. When 'coeffs' is NULL, 'u' is
 * ignored and the stored triplets count is 0.
 *
 * Returns 0 on success and a non-zero value on error. On error, the
 * content of 'priv' is wiped.
 */
int rsa_import_crt_priv_key(rsa_priv_key *priv,
                            const u8 *p, u16 plen,
                            const u8 *q, u16 qlen,
                            const u8 *dP, u16 dPlen,
                            const u8 *dQ, u16 dQlen,
                            const u8 *qInv, u16 qInvlen,
                            const u8 **coeffs, u16 *coeffslens, u8 u)
{
	int ret;

	MUST_HAVE((priv != NULL), ret, err);

	priv->type = RSA_CRT;

	/* Mandatory CRT values */
	ret = nn_init_from_buf(&(priv->key.crt.p), p, plen); EG(ret, err);
	ret = nn_init_from_buf(&(priv->key.crt.q), q, qlen); EG(ret, err);
	ret = nn_init_from_buf(&(priv->key.crt.dP), dP, dPlen); EG(ret, err);
	ret = nn_init_from_buf(&(priv->key.crt.dQ), dQ, dQlen); EG(ret, err);
	ret = nn_init_from_buf(&(priv->key.crt.qInv), qInv, qInvlen); EG(ret, err);

	/* No additional triplet unless told otherwise below */
	priv->key.crt.u = 0;

	/* Optional additional (r, d, t) triplets */
	if(coeffs != NULL){
		unsigned int k, base;
		rsa_priv_key_crt_coeffs *slot;

		MUST_HAVE((coeffslens != NULL), ret, err);
		MUST_HAVE((u > 0) && (u < MAX_CRT_COEFFS), ret, err);

		priv->key.crt.u = u;

		for(k = 0; k < (unsigned int)u; k++){
			/* Triplet k lives at positions 3k, 3k + 1 and 3k + 2 */
			base = (3 * k);
			slot = &(priv->key.crt.coeffs[k]);

			ret = nn_init_from_buf(&(slot->r), coeffs[base],     coeffslens[base]);     EG(ret, err);
			ret = nn_init_from_buf(&(slot->d), coeffs[base + 1], coeffslens[base + 1]); EG(ret, err);
			ret = nn_init_from_buf(&(slot->t), coeffs[base + 2], coeffslens[base + 2]); EG(ret, err);
		}
	}

err:
	/* Wipe everything when the import did not fully succeed */
	if((priv != NULL) && ret){
		IGNORE_RET_VAL(local_memset(priv, 0, sizeof(rsa_priv_key)));
	}
	return ret;
}
178 
179 /* I2OSP - Integer-to-Octet-String primitive
180  * (as decribed in section 4.1 of RFC 8017)
181  */
rsa_i2osp(nn_src_t x,u8 * buf,u32 buflen)182 int rsa_i2osp(nn_src_t x, u8 *buf, u32 buflen)
183 {
184 	int ret;
185 
186 	/* Size check */
187 	MUST_HAVE((buflen <= 0xffff), ret, err);
188 	ret = _i2osp(x, buf, (u16)buflen);
189 
190 err:
191 	return ret;
192 }
193 
194 /* OS2IP - Octet-String-to-Integer primitive
195  * (as decribed in section 4.2 of RFC 8017)
196  */
rsa_os2ip(nn_t x,const u8 * buf,u32 buflen)197 int rsa_os2ip(nn_t x, const u8 *buf, u32 buflen)
198 {
199 	int ret;
200 
201 	/* Size check */
202 	MUST_HAVE((buflen <= 0xffff), ret, err);
203 	ret = _os2ip(x, buf, (u16)buflen);
204 
205 err:
206 	return ret;
207 }
208 
209 /* The raw RSAEP function as defined in RFC 8017 section 5.1.1
210  *     Input: an RSA public key and a big int message
211  *     Output: a big int ciphertext
212  *     Assumption:  RSA public key K is valid
213  */
rsaep(const rsa_pub_key * pub,nn_src_t m,nn_t c)214 int rsaep(const rsa_pub_key *pub, nn_src_t m, nn_t c)
215 {
216 	int ret, cmp;
217 	nn_src_t n, e;
218 
219 	/* Sanity checks */
220 	MUST_HAVE((pub != NULL), ret, err);
221 
222 	/* Make things more readable */
223 	n = &(pub->n);
224 	e = &(pub->e);
225 
226 	/* Sanity checks */
227 	ret = nn_check_initialized(n); EG(ret, err);
228 	ret = nn_check_initialized(e); EG(ret, err);
229 
230 	/* Check that m is indeed in [0, n-1], trigger an error if not */
231 	MUST_HAVE((!nn_cmp(m, n, &cmp)) && (cmp < 0), ret, err);
232 
233 	/* Compute c = m^e mod n
234 	 * NOTE: we use our internal *insecure* modular exponentation as we
235 	 * are handling public key and data.
236 	 */
237 	ret = _nn_mod_pow_insecure(c, m, e, n);
238 
239 err:
240 	PTR_NULLIFY(n);
241 	PTR_NULLIFY(e);
242 
243 	return ret;
244 }
245 
246 #ifdef USE_SIG_BLINDING
247 #define RSA_EXPONENT_BLINDING_SIZE 128
248 /*
249  * Blind an exponent with a "small" multiple (of size "bits") of the input mod or (mod-1).
250  * We use a relatively small multiple mainly because of potential big performance impacts on
251  * modular exponentiation.
252  */
_rsa_blind_exponent(nn_src_t e,nn_src_t mod,nn_t out,bitcnt_t bits,u8 dec)253 ATTRIBUTE_WARN_UNUSED_RET static int _rsa_blind_exponent(nn_src_t e, nn_src_t mod, nn_t out, bitcnt_t bits, u8 dec)
254 {
255 	int ret, check;
256 	nn b;
257 	b.magic = WORD(0);
258 
259 	ret = nn_init(&b, 0); EG(ret, err);
260 	ret = nn_init(out, 0); EG(ret, err);
261 
262 	ret = nn_one(out); EG(ret, err);
263 	ret = nn_lshift(out, out, bits); EG(ret, err);
264 	ret = nn_iszero(out, &check); EG(ret, err);
265 	/* Check for overflow */
266 	MUST_HAVE(!check, ret, err);
267 
268 	/* Get a random value of "bits" count */
269 	ret = nn_get_random_mod(&b, out); EG(ret, err);
270 
271 	if(dec){
272 		ret = nn_copy(out, mod); EG(ret, err);
273 		ret = nn_dec(out, out); EG(ret, err);
274 		ret = nn_mul(&b, &b, out); EG(ret, err);
275 	}
276 	else{
277 		ret = nn_mul(&b, &b, mod); EG(ret, err);
278 	}
279 
280 	ret = nn_add(out, e, &b);
281 
282 err:
283 	nn_uninit(&b);
284 
285 	return ret;
286 }
287 #endif
288 
289 /* The raw RSADP function as defined in RFC 8017 section 5.1.2
290  *     Input: an RSA private key 'priv' and a big int ciphertext 'c'
291  *     Output: a big int clear message 'm'
292  *     Assumption:  RSA private key 'priv' is valid
293  */
rsadp_crt_coeffs(const rsa_priv_key * priv,nn_src_t c,nn_t m,u8 u)294 ATTRIBUTE_WARN_UNUSED_RET static int rsadp_crt_coeffs(const rsa_priv_key *priv, nn_src_t c, nn_t m, u8 u)
295 {
296 	int ret;
297 	unsigned int i;
298 	nn_src_t r_i, d_i, t_i, r_i_1;
299 	nn m_i, h, R;
300 	m_i.magic = h.magic = R.magic = WORD(0);
301 
302 	/* Sanity check on u */
303 	MUST_HAVE((u < MAX_CRT_COEFFS), ret, err);
304 
305 	ret = nn_init(&m_i, 0); EG(ret, err);
306 	ret = nn_init(&h, 0); EG(ret, err);
307 	ret = nn_init(&R, 0); EG(ret, err);
308 	/* NOTE: this is an internal function, sanity checks on priv and u have
309 	 * been performed by the callers.
310 	 */
311 	/* R = r_1 */
312 	ret = nn_copy(&R, &(priv->key.crt.coeffs[0].r)); EG(ret, err);
313 	/* Loop  */
314 	for(i = 1; i < u; i++){
315 		r_i_1 = &(priv->key.crt.coeffs[i-1].r);
316 		r_i = &(priv->key.crt.coeffs[i].r);
317 		d_i = &(priv->key.crt.coeffs[i].d);
318 		t_i = &(priv->key.crt.coeffs[i].t);
319 
320 		/* Sanity checks */
321 		ret = nn_check_initialized(r_i_1); EG(ret, err);
322 		ret = nn_check_initialized(r_i); EG(ret, err);
323 		ret = nn_check_initialized(d_i); EG(ret, err);
324 		ret = nn_check_initialized(t_i); EG(ret, err);
325 
326 		/* m_i = c^(d_i) mod r_i */
327 #ifdef USE_SIG_BLINDING
328 		ret = _rsa_blind_exponent(d_i, r_i, &h, (bitcnt_t)RSA_EXPONENT_BLINDING_SIZE, 1); EG(ret, err);
329 		ret = nn_mod_pow(&m_i, c, &h, r_i); EG(ret, err);
330 #else
331 		ret = nn_mod_pow(&m_i, c, d_i, r_i); EG(ret, err);
332 #endif
333 		/* R = R * r_(i-1) */
334 		ret = nn_mul(&R, &R, r_i_1); EG(ret, err);
335 		/*  h = (m_i - m) * t_i mod r_i */
336 		ret = nn_mod(&h, m, r_i); EG(ret, err);
337 		ret = nn_mod_sub(&h, &m_i, &h, r_i); EG(ret, err);
338 		ret = nn_mod_mul(&h, &h, t_i, r_i); EG(ret, err);
339 		/* m = m + R * h */
340 		ret = nn_mul(&h, &R, &h); EG(ret, err);
341 		ret = nn_add(m, m, &h); EG(ret, err);
342 	}
343 
344 err:
345 	nn_uninit(&m_i);
346 	nn_uninit(&h);
347 	nn_uninit(&R);
348 
349 	PTR_NULLIFY(r_i);
350 	PTR_NULLIFY(d_i);
351 	PTR_NULLIFY(t_i);
352 	PTR_NULLIFY(r_i_1);
353 
354 	return ret;
355 }
356 
rsadp_crt(const rsa_priv_key * priv,nn_src_t c,nn_t m)357 ATTRIBUTE_WARN_UNUSED_RET static int rsadp_crt(const rsa_priv_key *priv, nn_src_t c, nn_t m)
358 {
359 	int ret;
360 	nn_src_t p, q, dP, dQ, qInv;
361 	nn m_1, m_2, h, msb_fixed;
362 	u8 u;
363 	m_1.magic = m_2.magic = h.magic = WORD(0);
364 
365 	ret = nn_init(&m_1, 0); EG(ret, err);
366 	ret = nn_init(&m_2, 0); EG(ret, err);
367 	ret = nn_init(&h, 0); EG(ret, err);
368 	ret = nn_init(&msb_fixed, 0); EG(ret, err);
369 
370 	/* Make things more readable */
371 	p    = &(priv->key.crt.p);
372 	q    = &(priv->key.crt.q);
373 	dP   = &(priv->key.crt.dP);
374 	dQ   = &(priv->key.crt.dQ);
375 	qInv = &(priv->key.crt.qInv);
376 	u    = priv->key.crt.u;
377 
378 	/* Sanity checks */
379 	ret = nn_check_initialized(p); EG(ret, err);
380 	ret = nn_check_initialized(q); EG(ret, err);
381 	ret = nn_check_initialized(dP); EG(ret, err);
382 	ret = nn_check_initialized(dQ); EG(ret, err);
383 	ret = nn_check_initialized(qInv); EG(ret, err);
384 
385 	/* m_1 = c^dP mod p */
386 #ifdef USE_SIG_BLINDING
387 	ret = _rsa_blind_exponent(dP, p, &h, (bitcnt_t)RSA_EXPONENT_BLINDING_SIZE, 1); EG(ret, err);
388 	ret = nn_mod_pow(&m_1, c, &h, p); EG(ret, err);
389 #else
390 	ret = nn_mod_pow(&m_1, c, dP, p); EG(ret, err);
391 #endif
392 	/* m_2 = c^dQ mod q */
393 #ifdef USE_SIG_BLINDING
394 	ret = _rsa_blind_exponent(dQ, q, &h, (bitcnt_t)RSA_EXPONENT_BLINDING_SIZE, 1); EG(ret, err);
395 	ret = nn_mod_pow(&m_2, c, &h, q); EG(ret, err);
396 #else
397 	ret = nn_mod_pow(&m_2, c, dQ, q); EG(ret, err);
398 #endif
399 	/* h = (m_1 - m_2) * qInv mod p */
400 	ret = nn_mod(&h, &m_2, p); EG(ret, err);
401 	ret = nn_mod_sub(&h, &m_1, &h, p); EG(ret, err);
402 	ret = nn_mod_mul(&h, &h, qInv, p); EG(ret, err);
403 	/* m = m_2 + q * h */
404 	ret = nn_mul(m, &h, q); EG(ret, err);
405 	ret = nn_add(m, &m_2, m); EG(ret, err);
406 
407 	if(u > 1){
408 		ret = rsadp_crt_coeffs(priv, c, m, u);
409 	}
410 
411 err:
412 	nn_uninit(&m_1);
413 	nn_uninit(&m_2);
414 	nn_uninit(&h);
415 
416 	PTR_NULLIFY(p);
417 	PTR_NULLIFY(q);
418 	PTR_NULLIFY(dP);
419 	PTR_NULLIFY(dQ);
420 	PTR_NULLIFY(qInv);
421 
422 	return ret;
423 }
424 
rsadp_nocrt(const rsa_priv_key * priv,nn_src_t c,nn_t m)425 ATTRIBUTE_WARN_UNUSED_RET static int rsadp_nocrt(const rsa_priv_key *priv, nn_src_t c, nn_t m)
426 {
427 	int ret, cmp;
428 	nn_src_t n, d, p, q;
429 #ifdef USE_SIG_BLINDING
430 	nn b1, b2;
431 	b1.magic = b2.magic = WORD(0);
432 #endif
433 	/* Make things more readable */
434 	if(priv->type == RSA_SIMPLE){
435 		n = &(priv->key.s.n);
436 		d = &(priv->key.s.d);
437 	}
438 	else if(priv->type == RSA_SIMPLE_PQ){
439 		n = &(priv->key.s_pq.n);
440 		d = &(priv->key.s_pq.d);
441 	}
442 	else{
443 		ret = -1;
444 		goto err;
445 	}
446 	/* Sanity checks */
447 	ret = nn_check_initialized(n); EG(ret, err);
448 	ret = nn_check_initialized(d); EG(ret, err);
449 	/* Check that c is indeed in [0, n-1], trigger an error if not */
450 	MUST_HAVE((!nn_cmp(c, n, &cmp)) && (cmp < 0), ret, err);
451 
452 	/* Compute m = c^d mod n */
453 #ifdef USE_SIG_BLINDING
454 	/* When we are asked to use exponent blinding, we MUST have a RSA_SIMPLE_PQ
455 	 * type key in order to be able to compute our Phi(n) = (p-1)(q-1) and perform
456 	 * the blinding.
457 	 */
458 	if(priv->type == RSA_SIMPLE_PQ){
459 		p = &(priv->key.s_pq.p);
460 		q = &(priv->key.s_pq.q);
461 		ret = nn_init(&b1, 0); EG(ret, err);
462 		ret = nn_init(&b2, 0); EG(ret, err);
463 		ret = nn_dec(&b1, p); EG(ret, err);
464 		ret = nn_dec(&b2, q); EG(ret, err);
465 		ret = nn_mul(&b1, &b1, &b2); EG(ret, err);
466 		ret = _rsa_blind_exponent(d, &b1, &b2, (bitcnt_t)RSA_EXPONENT_BLINDING_SIZE, 0); EG(ret, err);
467 		ret = nn_mod_pow(m, c, &b2, n); EG(ret, err);
468 	}
469 	else{
470 		ret = -1;
471 		goto err;
472 	}
473 #else
474 	FORCE_USED_VAR(p);
475 	FORCE_USED_VAR(q);
476 	ret = nn_mod_pow(m, c, d, n);
477 #endif
478 
479 err:
480 #ifdef USE_SIG_BLINDING
481 	nn_uninit(&b1);
482 	nn_uninit(&b2);
483 #endif
484 	PTR_NULLIFY(n);
485 	PTR_NULLIFY(d);
486 	PTR_NULLIFY(p);
487 	PTR_NULLIFY(q);
488 
489 	return ret;
490 }
491 
/*
 * RSADP entry point: dispatch the private key operation on the ciphertext
 * 'c' to the non-CRT or to the CRT implementation depending on the type
 * of the private key 'priv', and place the result in 'm'.
 *
 * An error is returned when 'priv' is NULL or when its type is unknown.
 */
int rsadp(const rsa_priv_key *priv, nn_src_t c, nn_t m)
{
	int ret;

	MUST_HAVE((priv != NULL), ret, err);

	switch(priv->type){
		case RSA_SIMPLE:
		case RSA_SIMPLE_PQ:{
			/* Keys holding (n, d), with or without (p, q) */
			ret = rsadp_nocrt(priv, c, m);
			break;
		}
		case RSA_CRT:{
			/* Keys holding the CRT values */
			ret = rsadp_crt(priv, c, m);
			break;
		}
		default:{
			/* Unknown key type */
			ret = -1;
			break;
		}
	}

err:
	return ret;
}
514 
515 /*
516  * The "hardened" version of rsadp that uses message blinding as well
517  * as output check for Bellcore style fault attacks.
518  *
519  */
rsadp_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,nn_src_t c,nn_t m)520 int rsadp_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub, nn_src_t c, nn_t m)
521 {
522 	int ret, check;
523 	nn_src_t n, e;
524 	nn b, binv;
525 	b.magic = binv.magic = WORD(0);
526 
527 	/* Make things more readable */
528 	n = &(pub->n);
529 	e = &(pub->e);
530 
531 	/* Sanity checks */
532 	MUST_HAVE((priv != NULL) && (pub != NULL), ret, err);
533 
534 	/* Blind the message: get a random value for b prime with n
535 	 * and compute its modular inverse.
536 	 */
537 	ret = nn_init(&b, 0); EG(ret, err);
538 	ret = nn_init(&binv, 0); EG(ret, err);
539 	ret = -1;
540 	while(ret){
541 		ret = nn_get_random_mod(&b, n); EG(ret, err);
542 		ret = nn_modinv(&binv, &b, n);
543 	}
544 	/* Exponentiate the blinder to the public value */
545 	ret = _nn_mod_pow_insecure(m, &b, e, n); EG(ret, err);
546 	/* Perform message blinding */
547 	ret = nn_mod_mul(&b, m, c, n); EG(ret, err);
548 
549 	/* Perform rsadp on the blinded message */
550 	ret = rsadp(priv, &b, m); EG(ret, err);
551 
552 	/* Unblind the result */
553 	ret = nn_mod_mul(m, m, &binv, n); EG(ret, err);
554 
555 	/* Now perform the public operation to check the result.
556 	 * This is useful against some fault attacks (Bellcore style).
557 	 */
558 	ret = rsaep(pub, m, &b); EG(ret, err);
559 	ret = nn_cmp(c, &b, &check); EG(ret, err);
560 	MUST_HAVE((check == 0), ret, err);
561 
562 err:
563 	nn_uninit(&b);
564 	nn_uninit(&binv);
565 
566 	PTR_NULLIFY(n);
567 	PTR_NULLIFY(e);
568 
569 	return ret;
570 }
571 
572 /* The raw RSASP1 function as defined in RFC 8017 section 5.2.1
573  *     Input: an RSA private key 'priv' and a big int message 'm'
574  *     Output: a big int signature 's'
575  *     Assumption:  RSA private key 'priv' is valid
576  */
rsasp1(const rsa_priv_key * priv,nn_src_t m,nn_t s)577 int rsasp1(const rsa_priv_key *priv, nn_src_t m, nn_t s)
578 {
579 	return rsadp(priv, m, s);
580 }
581 
582 /*
583  * The "hardened" version of rsasp1 that uses message blinding as well
584  * as optional exponent blinding.
585  *
586  */
rsasp1_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,nn_src_t m,nn_t s)587 int rsasp1_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub, nn_src_t m, nn_t s)
588 {
589 	return rsadp_hardened(priv, pub, m, s);
590 }
591 
592 
593 /* The raw RSAVP1 function as defined in RFC 8017 section 5.2.2
594  *     Input: an RSA public key 'pub' and a big int signature 's'
595  *     Output: a big int ciphertext 'm'
596  *     Assumption:  RSA public key 'pub' is valid
597  */
rsavp1(const rsa_pub_key * pub,nn_src_t s,nn_t m)598 int rsavp1(const rsa_pub_key *pub, nn_src_t s, nn_t m)
599 {
600 	return rsaep(pub, s, m);
601 }
602 
rsa_digestinfo_from_hash(gen_hash_alg_type gen_hash_type,u8 * digestinfo,u32 * digestinfo_len)603 ATTRIBUTE_WARN_UNUSED_RET static int rsa_digestinfo_from_hash(gen_hash_alg_type gen_hash_type, u8 *digestinfo, u32 *digestinfo_len)
604 {
605 	int ret;
606 
607 	/* Sanity check */
608 	MUST_HAVE((digestinfo_len != NULL), ret, err);
609 
610 	switch(gen_hash_type){
611 		case HASH_MD2:{
612 			const u8 _digestinfo[] = { 0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a,
613 						   0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x02,
614 						   0x05, 0x00, 0x04, 0x10 };
615 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
616 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
617 			(*digestinfo_len) = sizeof(_digestinfo);
618 			break;
619 		}
620 		case HASH_MD4:{
621 			const u8 _digestinfo[] = { 0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a,
622 						   0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x04,
623 						   0x05, 0x00, 0x04, 0x10 };
624 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
625 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
626 			(*digestinfo_len) = sizeof(_digestinfo);
627 			break;
628 		}
629 		case HASH_MD5:{
630 			const u8 _digestinfo[] = { 0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a,
631 						   0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05,
632 						   0x05, 0x00, 0x04, 0x10 };
633 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
634 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
635 			(*digestinfo_len) = sizeof(_digestinfo);
636 			break;
637 		}
638 		case HASH_SHA0:{
639 			const u8 _digestinfo[] = { 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b,
640 						   0x0e, 0x03, 0x02, 0x12, 0x05, 0x00, 0x04,
641 						   0x14 };
642 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
643 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
644 			(*digestinfo_len) = sizeof(_digestinfo);
645 			break;
646 		}
647 		case HASH_SHA1:{
648 			const u8 _digestinfo[] = { 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b,
649 						   0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04,
650 						   0x14 };
651 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
652 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
653 			(*digestinfo_len) = sizeof(_digestinfo);
654 			break;
655 		}
656 		case HASH_SHA224:{
657 			const u8 _digestinfo[] = { 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60,
658 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
659 						   0x04, 0x05, 0x00, 0x04, 0x1c };
660 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
661 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
662 			(*digestinfo_len) = sizeof(_digestinfo);
663 			break;
664 		}
665 		case HASH_SHA256:{
666 			const u8 _digestinfo[] = { 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60,
667 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
668 						   0x01, 0x05, 0x00, 0x04, 0x20 };
669 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
670 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
671 			(*digestinfo_len) = sizeof(_digestinfo);
672 			break;
673 		}
674 		case HASH_SHA384:{
675 			const u8 _digestinfo[] = { 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60,
676 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
677 						   0x02, 0x05, 0x00, 0x04, 0x30 };
678 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
679 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
680 			(*digestinfo_len) = sizeof(_digestinfo);
681 			break;
682 		}
683 		case HASH_SHA512:{
684 			const u8 _digestinfo[] = { 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60,
685 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
686 						   0x03, 0x05, 0x00, 0x04, 0x40 };
687 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
688 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
689 			(*digestinfo_len) = sizeof(_digestinfo);
690 			break;
691 		}
692 		case HASH_SHA512_224:{
693 			const u8 _digestinfo[] = { 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60,
694 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
695 						   0x05, 0x05, 0x00, 0x04, 0x1c };
696 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
697 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
698 			(*digestinfo_len) = sizeof(_digestinfo);
699 			break;
700 		}
701 		case HASH_SHA512_256:{
702 			const u8 _digestinfo[] = { 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60,
703 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
704 						   0x06, 0x05, 0x00, 0x04, 0x20 };
705 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
706 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
707 			(*digestinfo_len) = sizeof(_digestinfo);
708 			break;
709 		}
710 		case HASH_RIPEMD160:{
711 			const u8 _digestinfo[] = { 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b,
712 						   0x24, 0x03, 0x02, 0x01, 0x05, 0x00, 0x04,
713 						   0x14 };
714 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
715 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
716 			(*digestinfo_len) = sizeof(_digestinfo);
717 			break;
718 		}
719 		/* The following SHA-3 oids have been taken from
720 		 *     https://www.ietf.org/archive/id/draft-jivsov-openpgp-sha3-01.txt
721 		 *
722 		 * The specific case of SHA3-224 is infered from the OID of SHA3-224 although
723 		 * not standardized.
724 		 */
725 		case HASH_SHA3_224:{
726 			const u8 _digestinfo[] = { 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60,
727 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
728 						   0x07, 0x05, 0x00, 0x04, 0x1c };
729 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
730 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
731 			(*digestinfo_len) = sizeof(_digestinfo);
732 			break;
733 		}
734 		case HASH_SHA3_256:{
735 			const u8 _digestinfo[] = { 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60,
736 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
737 						   0x08, 0x05, 0x00, 0x04, 0x20 };
738 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
739 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
740 			(*digestinfo_len) = sizeof(_digestinfo);
741 			break;
742 		}
743 		case HASH_SHA3_384:{
744 			const u8 _digestinfo[] = { 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60,
745 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
746 						   0x09, 0x05, 0x00, 0x04, 0x30 };
747 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
748 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
749 			(*digestinfo_len) = sizeof(_digestinfo);
750 			break;
751 		}
752 		case HASH_SHA3_512:{
753 			const u8 _digestinfo[] = { 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60,
754 						   0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02,
755 						   0x0a ,0x05, 0x00, 0x04, 0x40 };
756 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
757 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
758 			(*digestinfo_len) = sizeof(_digestinfo);
759 			break;
760 		}
761 		/* For SM3, the "RSA Signing with SM3" OID is taken from:
762 		 *     http://gmssl.org/docs/oid.html
763 		 */
764 		case HASH_SM3:{
765 			const u8 _digestinfo[] = { 0x30, 0x30, 0x30, 0x0d, 0x06, 0x08, 0x2A,
766 						   0x81, 0x1c, 0xcf, 0x55, 0x01, 0x83, 0x78,
767 						   0x05, 0x00, 0x04, 0x20 };
768 			MUST_HAVE(((*digestinfo_len) >= sizeof(_digestinfo)), ret, err);
769 			ret = local_memcpy(digestinfo, _digestinfo, sizeof(_digestinfo)); EG(ret, err);
770 			(*digestinfo_len) = sizeof(_digestinfo);
771 			break;
772 		}
773 		default:{
774 			ret = -1;
775 			goto err;
776 		}
777 	}
778 
779 err:
780 	return ret;
781 }
782 
783 /* GF1 as a mask generation function as described in RFC 8017 Appendix B.2.1
784  *     z is the 'seed', and zlen its length
785  */
_mgf1(const u8 * z,u16 zlen,u8 * mask,u64 masklen,gen_hash_alg_type mgf_hash_type)786 ATTRIBUTE_WARN_UNUSED_RET static int _mgf1(const u8 *z, u16 zlen,
787 					   u8 *mask, u64 masklen,
788 					   gen_hash_alg_type mgf_hash_type)
789 {
790 	int ret;
791 	u8 hlen, block_size;
792 	u32 c, ceil;
793 	u8 C[4];
794 	const u8 *input[3] = { z, C, NULL };
795 	u32 ilens[3] = { zlen, 4, 0 };
796 	u8 digest[MAX_DIGEST_SIZE];
797 
798 	/* Zeroize local variables */
799 	ret = local_memset(C, 0, sizeof(C)); EG(ret, err);
800 	ret = local_memset(digest, 0, sizeof(digest)); EG(ret, err);
801 
802 	/* Sanity checks */
803 	MUST_HAVE((z != NULL) && (mask != NULL), ret, err);
804 
805 	ret = gen_hash_get_hash_sizes(mgf_hash_type, &hlen, &block_size); EG(ret, err);
806 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
807 
808 	/* masklen must be < 2**32 * hlen */
809 	MUST_HAVE((masklen < ((u64)hlen * ((u64)0x1 << 32))), ret, err);
810 	ceil = (u32)(masklen / hlen) + !!(masklen % hlen);
811 
812 	for(c = 0; c < ceil; c++){
813 		/* 3.A: C = I2OSP (counter, 4) */
814 		C[0] = (u8)((c >> 24) & 0xff);
815 		C[1] = (u8)((c >> 16) & 0xff);
816 		C[2] = (u8)((c >>  8) & 0xff);
817 		C[3] = (u8)((c >>  0) & 0xff);
818 
819 		/* 3.B + 4. */
820 		if ((masklen % hlen) && (c == (ceil - 1))) { /* need last chunk smaller than hlen */
821 			ret = gen_hash_hfunc_scattered(input, ilens, digest, mgf_hash_type); EG(ret, err);
822 			ret = local_memcpy(&mask[c * hlen], digest, (u32)(masklen % hlen)); EG(ret, err);
823 		} else {                                     /* common case, i.e. complete chunk */
824 			ret = gen_hash_hfunc_scattered(input, ilens, &mask[c * hlen], mgf_hash_type); EG(ret, err);
825 		}
826 	}
827 err:
828 	return ret;
829 }
830 
831 /* EMSA-PSS-ENCODE encoding as described in RFC 8017 section 9.1.1
832  * NOTE: we enforce MGF1 as a mask generation function
833  */
emsa_pss_encode(const u8 * m,u32 mlen,u8 * em,u32 embits,u16 * eminlen,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,u32 saltlen,const u8 * forced_salt)834 int emsa_pss_encode(const u8 *m, u32 mlen, u8 *em, u32 embits,
835                     u16 *eminlen, gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
836                     u32 saltlen, const u8 *forced_salt)
837 {
838 	int ret;
839 	u8 hlen, block_size;
840 	u8 mhash[MAX_DIGEST_SIZE];
841 	u8 h[MAX_DIGEST_SIZE];
842 	u8 zeroes[8];
843 	/* Reasonable sizes:
844 	 * NOTE: for the cases where the salt exceeds this size, we return an error
845 	 * alhough this should not happen if our underlying libecc supports the current
846 	 * modulus size.
847 	 */
848 	u8 salt[NN_USABLE_MAX_BYTE_LEN];
849 	u8 *dbmask = em;
850 	const u8 *input[2] = { m, NULL };
851 	u32 ilens[2] = { mlen, 0 };
852 	u32 emlen, dblen, pslen;
853 	unsigned int i;
854 	u8 mask;
855 	const u8 *input_[4] = { zeroes, mhash, salt, NULL };
856 	u32 ilens_[4];
857 
858 	/* Zeroize local variables */
859 	ret = local_memset(mhash, 0, sizeof(mhash)); EG(ret, err);
860 	ret = local_memset(h, 0, sizeof(h)); EG(ret, err);
861 	ret = local_memset(salt, 0, sizeof(salt)); EG(ret, err);
862 	ret = local_memset(zeroes, 0, sizeof(zeroes)); EG(ret, err);
863 	ret = local_memset(ilens_, 0, sizeof(ilens_)); EG(ret, err);
864 
865 	/* Sanity checks */
866 	MUST_HAVE((m != NULL) && (em != NULL) && (eminlen != NULL), ret, err);
867 
868 	/* We only allow salt up to a certain size */
869 	MUST_HAVE((saltlen <= sizeof(salt)), ret, err);
870 	emlen = BYTECEIL(embits);
871 	MUST_HAVE((emlen < (u32)((u32)0x1 << 16)), ret, err);
872 
873 	/* Check that we have enough room for the output */
874 	MUST_HAVE(((*eminlen) >= emlen), ret, err);
875 
876 	/* Get the used hash information */
877 	ret = gen_hash_get_hash_sizes(gen_hash_type, &hlen, &block_size); EG(ret, err);
878 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
879 
880 	/* emBits at least 8hLen + 8sLen + 9 */
881 	MUST_HAVE((embits >= ((8*(u32)hlen) + (8*(u32)saltlen) + 9)), ret, err);
882 
883 	/*  If emLen < hLen + sLen + 2, output "encoding error" and stop. */
884 	MUST_HAVE((emlen >= ((u32)hlen + (u32)saltlen + 2)), ret, err);
885 
886 	/* mHash = Hash(M) */
887 	ret = gen_hash_hfunc_scattered(input, ilens, mhash, gen_hash_type); EG(ret, err);
888 
889 	/*  Generate a random octet string salt of length sLen; if sLen = 0
890 	 *  then salt is the empty string.
891 	 *  M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
892 	 *  H = Hash(M'),
893 	 */
894 	if(forced_salt != NULL){
895 		/* We are given a forced salt, use it */
896 		ret = local_memcpy(salt, forced_salt, saltlen); EG(ret, err);
897 	}
898 	else{
899 		/* We only support generating salts of size <= 2**16 */
900 		MUST_HAVE((saltlen <= 0xffff), ret, err);
901 		/* Get random salt */
902 		ret = get_random(salt, (u16)saltlen); EG(ret, err);
903 	}
904 	ilens_[0] = sizeof(zeroes);
905 	ilens_[1] = hlen;
906 	ilens_[2] = saltlen;
907 	ilens_[3] = 0;
908 	ret = gen_hash_hfunc_scattered(input_, ilens_, h, gen_hash_type); EG(ret, err);
909 
910 	/* dbMask = MGF(H, emLen - hLen - 1)
911 	 * NOTE: dbmask points to &em[0]
912 	 */
913 	dblen = (emlen - hlen - 1);
914 	pslen = (dblen - saltlen - 1); /* padding string PS len */
915 	ret = _mgf1(h, hlen, dbmask, dblen, mgf_hash_type); EG(ret, err);
916 
917         /*
918          * maskedb = (PS || 0x01 || salt) xor dbmask. We compute maskeddb directly
919          * in dbmask.
920          */
921 
922         /* 1) PS is made of 0 so xoring it with first pslen bytes of dbmask is a NOP */
923 
924         /*
925          * 2) the byte after padding string is 0x01. Do the xor with the associated
926          *    byte in dbmask
927          */
928         dbmask[pslen] ^= 0x01;
929 
930         /* 3) xor the salt with the end of dbmask */
931         for (i = 0; i < saltlen; i++){
932                 dbmask[dblen - saltlen + i] ^= salt[i];
933         }
934 
935 	/* Set the leftmost 8emLen - emBits bits of the leftmost octet
936 	 * in maskedDB to zero.
937 	 */
938 	mask = 0;
939 	for(i = 0; i < (8 - ((8*emlen) - embits)); i++){
940 		mask = (u8)(mask | (0x1 << i));
941 	}
942 	dbmask[0] &= mask;
943 	/* EM = maskedDB || H || 0xbc */
944 	ret = local_memcpy(&em[dblen], h, hlen); EG(ret, err);
945 	em[emlen - 1] = 0xbc;
946 	(*eminlen) = (u16)emlen;
947 
948 err:
949 	return ret;
950 }
951 
952 /* EMSA-PSS-VERIFY verification as described in RFC 8017 section 9.1.2
953  * NOTE: we enforce MGF1 as a mask generation function
954  */
emsa_pss_verify(const u8 * m,u32 mlen,const u8 * em,u32 embits,u16 emlen,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,u32 saltlen)955 int emsa_pss_verify(const u8 *m, u32 mlen, const u8 *em,
956                     u32 embits, u16 emlen,
957 		    gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
958                     u32 saltlen)
959 {
960 	int ret, cmp;
961 	u8 hlen, block_size;
962 	u8 mhash[MAX_DIGEST_SIZE];
963 	u8 h_[MAX_DIGEST_SIZE];
964 	u8 zeroes[8];
965 	const u8 *input[2] = { m, NULL };
966 	u32 ilens[2] = { mlen, 0 };
967 	unsigned int i;
968 	u8 mask;
969 	u16 _emlen;
970 	/*
971 	 * NOTE: the NN_USABLE_MAX_BYTE_LEN should be a reasonable size here.
972 	 */
973 	u8 dbmask[NN_USABLE_MAX_BYTE_LEN];
974 	u8 *db;
975 	const u8 *h, *salt, *maskeddb = em;
976 	u32 dblen;
977 	const u8 *input_[4];
978 	u32 ilens_[4];
979 
980 	/* Zeroize local variables */
981 	ret = local_memset(mhash, 0, sizeof(mhash)); EG(ret, err);
982 	ret = local_memset(h_, 0, sizeof(h_)); EG(ret, err);
983 	ret = local_memset(dbmask, 0, sizeof(dbmask)); EG(ret, err);
984 	ret = local_memset(zeroes, 0, sizeof(zeroes)); EG(ret, err);
985 	ret = local_memset(input_, 0, sizeof(input_)); EG(ret, err);
986 	ret = local_memset(ilens_, 0, sizeof(ilens_)); EG(ret, err);
987 
988 	/* Sanity checks */
989 	MUST_HAVE((m != NULL) && (em != NULL), ret, err);
990 
991 	/* Get the used hash information */
992 	ret = gen_hash_get_hash_sizes(gen_hash_type, &hlen, &block_size); EG(ret, err);
993 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
994 
995 	/* Let mHash = Hash(M), an octet string of length hLen */
996 	ret = gen_hash_hfunc_scattered(input, ilens, mhash, gen_hash_type); EG(ret, err);
997 
998 	/* emBits at least 8hLen + 8sLen + 9 */
999 	MUST_HAVE((embits >= ((8*(u32)hlen) + (8*(u32)saltlen) + 9)), ret, err);
1000 
1001 	/* Check that emLen == \ceil(emBits/8) */
1002 	MUST_HAVE((((embits / 8) + 1) < (u32)((u32)0x1 << 16)), ret, err);
1003 	_emlen = ((embits % 8) == 0) ? (u16)(embits / 8) : (u16)((embits / 8) + 1);
1004 	MUST_HAVE((_emlen == emlen), ret, err);
1005 
1006 	/* If emLen < hLen + sLen + 2, output "inconsistent" and stop */
1007 	MUST_HAVE((emlen >= ((u32)hlen + (u32)saltlen + 2)), ret, err);
1008 
1009 	/* If the rightmost octet of EM does not have hexadecimal value 0xbc, output "inconsistent" and stop */
1010 	MUST_HAVE((em[emlen - 1] == 0xbc), ret, err);
1011 
1012 	/* If the leftmost 8emLen - emBits bits of the leftmost octet in maskedDB are not all equal to zero,
1013 	 * output "inconsistent" and stop
1014 	 * NOTE: maskeddb points to &em[0]
1015 	 */
1016 	mask = 0;
1017 	for(i = 0; i < (8 - ((unsigned int)(8*emlen) - embits)); i++){
1018 		mask = (u8)(mask | (0x1 << i));
1019 	}
1020 	MUST_HAVE(((maskeddb[0] & (~mask)) == 0), ret, err);
1021 
1022 	/* dbMask = MGF(H, emLen - hLen - 1) */
1023 	dblen = (u32)(emlen - hlen - 1);
1024 	h = &em[dblen];
1025 	MUST_HAVE(((u16)dblen <= sizeof(dbmask)), ret, err); /* sanity check for overflow */
1026 	ret = _mgf1(h, hlen, dbmask, dblen, mgf_hash_type); EG(ret, err);
1027 	/* DB = maskedDB \xor dbMask */
1028 	db = &dbmask[0];
1029 	for(i = 0; i < (u16)dblen; i++){
1030 		db[i] = (dbmask[i] ^ maskeddb[i]);
1031 	}
1032 	/* Set the leftmost 8emLen - emBits bits of the leftmost octet in DB to zero */
1033 	db[0] &= mask;
1034 
1035 	/*
1036 	 * If the emLen - hLen - sLen - 2 leftmost octets of DB are not
1037          * zero or if the octet at position emLen - hLen - sLen - 1 (the
1038          * leftmost position is "position 1") does not have hexadecimal
1039          * value 0x01, output "inconsistent" and stop.
1040 	 */
1041 	for(i = 0; i < (u16)(dblen - saltlen - 1); i++){
1042 		MUST_HAVE((db[i] == 0x00), ret, err);
1043 	}
1044 	MUST_HAVE((db[dblen - saltlen - 1] == 0x01), ret, err);
1045 
1046 	/* Let salt be the last sLen octets of DB */
1047 	salt = &db[dblen - saltlen];
1048 	/*
1049 	 * Let H' = Hash(M'), an octet string of length hLen with
1050 	 *     M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
1051 	 */
1052 	/* Fill input_ */
1053 	input_[0] = zeroes;
1054 	input_[1] = mhash;
1055 	input_[2] = salt;
1056 	input_[3] = NULL;
1057 	/* Fill ilens_ */
1058 	ilens_[0] = sizeof(zeroes);
1059 	ilens_[1] = hlen;
1060 	ilens_[2] = saltlen;
1061 	ilens_[3] = 0;
1062 	/* Hash */
1063 	ret = gen_hash_hfunc_scattered(input_, ilens_, h_, gen_hash_type); EG(ret, err);
1064 
1065 	/* If H = H', output "consistent".  Otherwise, output "inconsistent" */
1066 	ret = are_equal(h, h_, hlen, &cmp); EG(ret, err);
1067 	if(!cmp){
1068 		ret = -1;
1069 	}
1070 
1071 err:
1072 	return ret;
1073 }
1074 
1075 /* EMSA-PKCS1-v1_5 encoding as described in RFC 8017 section 9.2
1076  */
emsa_pkcs1_v1_5_encode(const u8 * m,u32 mlen,u8 * em,u16 emlen,gen_hash_alg_type gen_hash_type)1077 int emsa_pkcs1_v1_5_encode(const u8 *m, u32 mlen, u8 *em, u16 emlen,
1078                            gen_hash_alg_type gen_hash_type)
1079 {
1080 	int ret;
1081 	const u8 *input[2] = { m, NULL };
1082 	u32 ilens[2] = { mlen, 0 };
1083 	u8 digest_size, block_size;
1084 	u8 digest[MAX_DIGEST_SIZE];
1085 	u32 digestinfo_len = 0;
1086 	u32 tlen = 0;
1087 
1088 	/* Zeroize local variables */
1089 	ret = local_memset(digest, 0, sizeof(digest)); EG(ret, err);
1090 
1091 	/* Compute H = Hash(M) */
1092 	ret = gen_hash_get_hash_sizes(gen_hash_type, &digest_size, &block_size); EG(ret, err);
1093 	MUST_HAVE((digest_size <= MAX_DIGEST_SIZE), ret, err);
1094 	ret = gen_hash_hfunc_scattered(input, ilens, digest, gen_hash_type); EG(ret, err);
1095 
1096 	/* Now encode:
1097 	 *
1098          *     DigestInfo ::= SEQUENCE {
1099          *         digestAlgorithm AlgorithmIdentifier,
1100          *         digest OCTET STRING
1101          *     }
1102 	 */
1103 	digestinfo_len = emlen;
1104 	/* NOTE: the rsa_digestinfo_from_hash returns the size of DigestInfo *WITHOUT* the
1105 	 * appended raw hash, tlen is the real size of the complete encoded DigestInfo.
1106 	 */
1107 	ret = rsa_digestinfo_from_hash(gen_hash_type, em, &digestinfo_len); EG(ret, err);
1108 	tlen = (digestinfo_len + digest_size);
1109 
1110 	/* If emLen < tLen + 11, output "intended encoded message length too short" and stop */
1111 	MUST_HAVE((emlen >= (tlen + 11)), ret, err);
1112 
1113 	/* Copy T at the end of em */
1114 	digestinfo_len = emlen;
1115 	ret = rsa_digestinfo_from_hash(gen_hash_type, &em[emlen - tlen], &digestinfo_len); EG(ret, err);
1116 	ret = local_memcpy(&em[emlen - tlen + digestinfo_len], digest, digest_size); EG(ret, err);
1117 
1118 	/*
1119 	 * Format 0x00 || 0x01 || PS || 0x00 before
1120 	 */
1121 	em[0] = 0x00;
1122 	em[1] = 0x01;
1123 	em[emlen - tlen - 1] = 0x00;
1124 	ret = local_memset(&em[2], 0xff, emlen - tlen - 3);
1125 
1126 err:
1127 	return ret;
1128 }
1129 
1130 /****************************************************************/
1131 /******** Encryption schemes *************************************/
1132 /* The RSAES-PKCS1-V1_5-ENCRYPT algorithm as described in RFC 8017 section 7.2.1
1133  *
1134  */
rsaes_pkcs1_v1_5_encrypt(const rsa_pub_key * pub,const u8 * m,u32 mlen,u8 * c,u32 * clen,u32 modbits,const u8 * forced_seed,u32 seedlen)1135 int rsaes_pkcs1_v1_5_encrypt(const rsa_pub_key *pub, const u8 *m, u32 mlen,
1136                              u8 *c, u32 *clen, u32 modbits,
1137                              const u8 *forced_seed, u32 seedlen)
1138 {
1139 	int ret;
1140 	u32 k;
1141 	u8 *em = c;
1142 	unsigned int i;
1143 	nn m_, c_;
1144 	m_.magic = c_.magic = WORD(0);
1145 
1146 	MUST_HAVE((clen != NULL) && (c != NULL) && (m != NULL), ret, err);
1147 
1148 	k = BYTECEIL(modbits);
1149 
1150 	/* Check on lengths */
1151 	MUST_HAVE((k >= 11), ret, err);
1152 	MUST_HAVE((mlen <= (k - 11)), ret, err);
1153 	MUST_HAVE(((*clen) >= k), ret, err);
1154 
1155 	/* EME-PKCS1-v1_5 encoding EM = 0x00 || 0x02 || PS || 0x00 || M */
1156 	em[0] = 0x00;
1157 	em[1] = 0x02;
1158 	if(forced_seed == NULL){
1159 		for(i = 0; i < (k - mlen - 3); i++){
1160 			u8 rand_byte = 0;
1161 			while (!rand_byte) {
1162 				ret = get_random(&rand_byte, 1); EG(ret, err);
1163 			}
1164 			em[2 + i] = rand_byte;
1165 		}
1166 	}
1167 	else{
1168 		MUST_HAVE((seedlen == (k - mlen - 3)), ret, err);
1169 		/* Check that the forced seed does not contain 0x00 */
1170 		for(i = 0; i < seedlen; i++){
1171 			MUST_HAVE((forced_seed[i] != 0), ret, err);
1172 		}
1173 		ret = local_memcpy(&em[2], forced_seed, seedlen); EG(ret, err);
1174 	}
1175 	em[k - mlen - 1] = 0x00;
1176 	ret = local_memcpy(&em[k - mlen], m, mlen); EG(ret, err);
1177 
1178 	/* RSA encryption */
1179 	/*   m = OS2IP (EM) */
1180 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1181 	ret = rsa_os2ip(&m_, em, (u16)k); EG(ret, err);
1182 	/*   c = RSAEP ((n, e), m) */
1183 	ret = rsaep(pub, &m_, &c_); EG(ret, err);
1184 	/*   C = I2OSP (c, k) */
1185 	ret = rsa_i2osp(&c_, c, (u16)k); EG(ret, err);
1186 	(*clen) = (u16)k;
1187 
1188 err:
1189 	nn_uninit(&m_);
1190 	nn_uninit(&c_);
1191 	/* Zeroify in case of error */
1192 	if(ret && (clen != NULL)){
1193 		IGNORE_RET_VAL(local_memset(c, 0, (*clen)));
1194 	}
1195 
1196 	return ret;
1197 }
1198 
1199 /* The RSAES-PKCS1-V1_5-DECRYPT algorithm as described in RFC 8017 section 7.2.2
1200  *
1201  */
_rsaes_pkcs1_v1_5_decrypt(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * c,u32 clen,u8 * m,u32 * mlen,u32 modbits)1202 ATTRIBUTE_WARN_UNUSED_RET static int _rsaes_pkcs1_v1_5_decrypt(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *c, u32 clen,
1203                              u8 *m, u32 *mlen, u32 modbits)
1204 {
1205 	int ret;
1206 	unsigned int i, pos;
1207 	u32 k;
1208 	u8 r;
1209 	u8 *em = m;
1210 	nn m_, c_;
1211 	m_.magic = c_.magic = WORD(0);
1212 
1213 	MUST_HAVE((mlen != NULL) && (c != NULL) && (m != NULL), ret, err);
1214 
1215 	k = BYTECEIL(modbits);
1216 
1217 	/* Check on lengths */
1218 	MUST_HAVE((clen == k) && (k >= 11), ret, err);
1219 	MUST_HAVE(((*mlen) >= k), ret, err);
1220 
1221 	/* RSA decryption */
1222 	/*   c = OS2IP (C) */
1223 	ret = rsa_os2ip(&c_, c, clen); EG(ret, err);
1224 	/*   m = RSADP ((n, d), c) */
1225 	if(pub != NULL){
1226 		ret = rsadp_hardened(priv, pub, &c_, &m_); EG(ret, err);
1227 	}
1228 	else{
1229 		ret = rsadp(priv, &c_, &m_); EG(ret, err);
1230 	}
1231 	/*   EM = I2OSP (m, k) */
1232 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1233 	ret = rsa_i2osp(&m_, em, (u16)k); EG(ret, err);
1234 
1235 	/* EME-PKCS1-v1_5 decoding: EM = 0x00 || 0x02 || PS || 0x00 || M */
1236 	/* NOTE: we try our best to do the following in constant time to
1237 	 * limit padding oracles here (see Bleichenbacher attacks).
1238 	 */
1239 	ret = (1 - (!!(em[0] == 0x00) & !!(em[1] == 0x02)));
1240 	pos = 0;
1241 	/* Handle the first zero octet after PS in constant time */
1242 	for(i = 2; i < k; i++){
1243 		unsigned int mask = !!(em[i] == 0x00) & !!(pos == 0);
1244 		pos = (mask * i) + ((1 - mask) * pos);
1245 	}
1246 	ret |= !(pos >= (2 + 8)); /* PS length is at least 8 (also implying we found a 0x00) */
1247 	pos = (pos == 0) ? pos : (pos + 1);
1248 	/* We get a random value between 2 and k if we have an error so that
1249 	 * we put a random value in pos.
1250 	 */
1251         ret |= get_random((u8*)&i, 4);
1252 	/* Get a random value r for later loop dummy operations */
1253 	ret |= get_random(&r, 1);
1254 	/* Update pos with random value in case of error to progress
1255 	 * nominally with the algorithm
1256 	 */
1257 	pos = (ret) ? ((i % (k - 2)) + 2) : pos;
1258 	for(i = 2; i < k; i++){
1259 		u8 r_;
1260 		unsigned int idx;
1261 		/* Replace m by a random value in case of error */
1262 		idx = ((i < pos) ? 0x00 : (i - pos));
1263 		r ^= (u8)i;
1264 		r_ = (u8)((u8)(!!ret) * r);
1265 		m[idx] = (em[i] ^ r_);
1266 	}
1267 	(*mlen) = (u16)(k - pos);
1268 	/* Hide return value details to avoid information leak */
1269 	ret = -(!!ret);
1270 
1271 err:
1272 	nn_uninit(&m_);
1273 	nn_uninit(&c_);
1274 
1275 	return ret;
1276 }
1277 
1278 /*
1279  * Basic version without much SCA/faults protections.
1280  */
rsaes_pkcs1_v1_5_decrypt(const rsa_priv_key * priv,const u8 * c,u32 clen,u8 * m,u32 * mlen,u32 modbits)1281 int rsaes_pkcs1_v1_5_decrypt(const rsa_priv_key *priv, const u8 *c, u32 clen,
1282                              u8 *m, u32 *mlen, u32 modbits)
1283 {
1284 	return _rsaes_pkcs1_v1_5_decrypt(priv, NULL, c, clen, m, mlen, modbits);
1285 }
1286 
1287 /*
1288  * Hardened version with some SCA/faults protections.
1289  */
rsaes_pkcs1_v1_5_decrypt_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * c,u32 clen,u8 * m,u32 * mlen,u32 modbits)1290 int rsaes_pkcs1_v1_5_decrypt_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *c, u32 clen,
1291                              u8 *m, u32 *mlen, u32 modbits)
1292 {
1293 	return _rsaes_pkcs1_v1_5_decrypt(priv, pub, c, clen, m, mlen, modbits);
1294 }
1295 
1296 /* The RSAES-OAEP-ENCRYPT algorithm as described in RFC 8017 section 7.1.1
1297  *
1298  */
rsaes_oaep_encrypt(const rsa_pub_key * pub,const u8 * m,u32 mlen,u8 * c,u32 * clen,u32 modbits,const u8 * label,u32 label_len,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,const u8 * forced_seed,u32 seedlen)1299 int rsaes_oaep_encrypt(const rsa_pub_key *pub, const u8 *m, u32 mlen,
1300                        u8 *c, u32 *clen, u32 modbits, const u8 *label, u32 label_len,
1301                        gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
1302 		       const u8 *forced_seed, u32 seedlen)
1303 {
1304 	int ret;
1305 	u32 k, pslen, khlen;
1306 	unsigned int i;
1307 	u8 hlen, block_size;
1308 	u8 *em = c;
1309 	/* Reasonable sizes */
1310 	u8 seed[MAX_DIGEST_SIZE];
1311         /*
1312          * NOTE: the NN_USABLE_MAX_BYTE_LEN should be a reasonable size here.
1313          */
1314 	u8 dbmask[NN_USABLE_MAX_BYTE_LEN];
1315 	u8 db[NN_USABLE_MAX_BYTE_LEN];
1316 	u8 *seedmask = dbmask, *maskedseed = NULL, *maskeddb = NULL;
1317 	const u8 *input[2] = { c, NULL };
1318 	u32 ilens[2] = { 0, 0 };
1319 	nn m_, c_;
1320 	m_.magic = c_.magic = WORD(0);
1321 
1322 	/* Zeroize local variables */
1323 	ret = local_memset(seed, 0, sizeof(seed)); EG(ret, err);
1324 	ret = local_memset(db, 0, sizeof(db)); EG(ret, err);
1325 	ret = local_memset(dbmask, 0, sizeof(dbmask)); EG(ret, err);
1326 
1327 	MUST_HAVE((clen != NULL) && (c != NULL) && (m != NULL), ret, err);
1328 
1329 	k = BYTECEIL(modbits);
1330 
1331 	ret = gen_hash_get_hash_sizes(gen_hash_type, &hlen, &block_size); EG(ret, err);
1332 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
1333 
1334 	/* Check on lengths */
1335 	MUST_HAVE(((u32)k >= ((2 * (u32)hlen) + 2)), ret, err);
1336 	MUST_HAVE(((mlen ) <= ((u32)k - (2 * (u32)hlen) - 2)), ret, err);
1337 	MUST_HAVE(((*clen) >= k), ret, err);
1338 
1339 	/* EME-OAEP encoding: DB = lHash || PS || 0x01 || M */
1340 	/* and then EM = 0x00 || maskedSeed || maskedDB */
1341 	maskedseed = &em[1];
1342 	maskeddb   = &em[hlen + 1];
1343 	MUST_HAVE(((k - hlen - 1) <= sizeof(db)), ret, err);
1344 	if(label == NULL){
1345 		MUST_HAVE((label_len == 0), ret, err);
1346 	}
1347 	else{
1348 		input[0] = label;
1349 		ilens[0] = label_len;
1350 	}
1351 	ret = gen_hash_hfunc_scattered(input, ilens, &db[0], gen_hash_type); EG(ret, err);
1352 	/*
1353 	 * 2.b. Generate a padding string PS consisting of k - mLen - 2hLen -
1354 	 * 2 zero octets. The length of PS may be zero.
1355 	 *
1356 	 * DB = lHash || PS || 0x01 || M. Hence, PS starts at octet hlen in DB
1357 	 */
1358 	pslen = (k - mlen - (u32)(2 * hlen) - 2);
1359 	for(i = 0; i < pslen; i++){
1360 		db[hlen + i] = 0x00;
1361 	}
1362 	/* 0x01 || M */
1363 	db[hlen + pslen] = 0x01;
1364 	for(i = 0 ; i < mlen; i++){
1365 		db[hlen + pslen + 1 + i] = m[i];
1366 	}
1367 	/* Generate a random octet string seed of length hLen */
1368 	MUST_HAVE((hlen <= sizeof(seed)), ret, err);
1369 	if(forced_seed != NULL){
1370 		MUST_HAVE((seedlen == hlen), ret, err);
1371 		ret = local_memcpy(seed, forced_seed, seedlen); EG(ret, err);
1372 	}
1373 	else{
1374 		ret = get_random(seed, hlen); EG(ret, err);
1375 	}
1376 	/* Let dbMask = MGF(seed, k - hLen - 1)*/
1377 	khlen = (k - hlen - 1);
1378 	MUST_HAVE((khlen <= sizeof(dbmask)), ret, err);
1379 	ret = _mgf1(seed, hlen, dbmask, khlen, mgf_hash_type); EG(ret, err);
1380 	/* Let maskedDB = DB \xor dbMask */
1381 	for(i = 0; i < khlen; i++){
1382 		maskeddb[i] = (db[i] ^ dbmask[i]);
1383 	}
1384 	/* Let seedMask = MGF(maskedDB, hLen) */
1385 	MUST_HAVE((khlen < (u32)((u32)0x1 << 16)), ret, err);
1386 	ret = _mgf1(maskeddb, (u16)khlen, seedmask, hlen, mgf_hash_type); EG(ret, err);
1387 	/* Let maskedSeed = seed \xor seedMask */
1388 	for(i = 0; i < hlen; i++){
1389 		maskedseed[i] = (seed[i] ^ seedmask[i]);
1390 	}
1391 	/* EM = 0x00 || maskedSeed || maskedDB should be filled */
1392 	em[0] = 0x00;
1393 
1394 	/* RSA encryption */
1395 	/*   m = OS2IP (EM) */
1396 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1397 	ret = rsa_os2ip(&m_, em, (u16)k); EG(ret, err);
1398 	/*   c = RSAEP ((n, e), m) */
1399 	ret = rsaep(pub, &m_, &c_); EG(ret, err);
1400 	/*   C = I2OSP (c, k) */
1401 	ret = rsa_i2osp(&c_, c, (u16)k); EG(ret, err);
1402 	(*clen) = (u16)k;
1403 
1404 err:
1405 	nn_uninit(&m_);
1406 	nn_uninit(&c_);
1407 	/* Zeroify in case of error */
1408 	if(ret && (clen != NULL)){
1409 		IGNORE_RET_VAL(local_memset(c, 0, (*clen)));
1410 	}
1411 
1412 	return ret;
1413 }
1414 
1415 /* The RSAES-OAEP-DECRYPT algorithm as described in RFC 8017 section 7.1.2
1416  *
1417  */
_rsaes_oaep_decrypt(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * c,u32 clen,u8 * m,u32 * mlen,u32 modbits,const u8 * label,u32 label_len,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type)1418 ATTRIBUTE_WARN_UNUSED_RET static int _rsaes_oaep_decrypt(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *c, u32 clen,
1419                        u8 *m, u32 *mlen, u32 modbits,
1420                        const u8 *label, u32 label_len, gen_hash_alg_type gen_hash_type,
1421 		       gen_hash_alg_type mgf_hash_type)
1422 {
1423 	int ret, cmp;
1424 	u32 k, khlen;
1425 	unsigned int i, pos;
1426 	u8 hlen, block_size;
1427 	u8 *em = m;
1428 	u8 r;
1429 	/* Reasonable sizes */
1430 	u8 lhash[MAX_DIGEST_SIZE];
1431 	u8 seedmask[MAX_DIGEST_SIZE];
1432         /*
1433          * NOTE: the NN_USABLE_MAX_BYTE_LEN should be a reasonable size here.
1434          */
1435 	u8 dbmask[NN_USABLE_MAX_BYTE_LEN];
1436 	u8 *seed = seedmask, *maskedseed = NULL, *maskeddb = NULL, *db = NULL;
1437 	const u8 *input[2] = { c, NULL };
1438 	u32 ilens[2] = { 0, 0 };
1439 	nn m_, c_;
1440 	m_.magic = c_.magic = WORD(0);
1441 
1442 	/* Zeroize local variables */
1443 	ret = local_memset(lhash, 0, sizeof(lhash)); EG(ret, err);
1444 	ret = local_memset(seedmask, 0, sizeof(seedmask)); EG(ret, err);
1445 	ret = local_memset(dbmask, 0, sizeof(dbmask)); EG(ret, err);
1446 
1447 	MUST_HAVE((c != NULL) && (m != NULL), ret, err);
1448 
1449 	k = BYTECEIL(modbits);
1450 
1451 	ret = gen_hash_get_hash_sizes(gen_hash_type, &hlen, &block_size); EG(ret, err);
1452 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
1453 
1454 	/* Check on lengths */
1455 	MUST_HAVE((clen == k), ret, err);
1456 	MUST_HAVE(((u32)k >= ((2 * (u32)hlen) + 2)), ret, err);
1457 
1458 	/* RSA decryption */
1459 	/*   c = OS2IP (C) */
1460 	ret = rsa_os2ip(&c_, c, clen); EG(ret, err);
1461 	/*   m = RSADP ((n, d), c) */
1462 	if(pub != NULL){
1463 		ret = rsadp_hardened(priv, pub, &c_, &m_); EG(ret, err);
1464 	}
1465 	else{
1466 		ret = rsadp(priv, &c_, &m_); EG(ret, err);
1467 	}
1468 	/*   EM = I2OSP (m, k) */
1469 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1470 	ret = rsa_i2osp(&m_, em, (u16)k); EG(ret, err);
1471 
1472 	/* EME-OAEP decoding */
1473 	/* lHash = Hash(L) */
1474 	if(label == NULL){
1475 		MUST_HAVE((label_len == 0), ret, err);
1476 	}
1477 	else{
1478 		input[0] = label;
1479 		ilens[0] = label_len;
1480 	}
1481 	ret = gen_hash_hfunc_scattered(input, ilens, lhash, gen_hash_type); EG(ret, err);
1482 	/*  EM = Y || maskedSeed || maskedDB */
1483 	maskedseed = &em[1];
1484 	maskeddb   = &em[hlen + 1];
1485 	/* seedMask = MGF(maskedDB, hLen) */
1486 	khlen = (k - hlen - 1);
1487 	MUST_HAVE((khlen < (u32)((u32)0x1 << 16)), ret, err);
1488 	ret = _mgf1(maskeddb, (u16)khlen, seedmask, hlen, mgf_hash_type); EG(ret, err);
1489 	/* Let maskedSeed = seed \xor seedMask */
1490 	for(i = 0; i < hlen; i++){
1491 		seed[i] = (maskedseed[i] ^ seedmask[i]);
1492 	}
1493 	/* dbMask = MGF(seed, k - hLen - 1) */
1494 	MUST_HAVE((khlen <= sizeof(dbmask)), ret, err);
1495 	ret = _mgf1(seed, hlen, dbmask, khlen, mgf_hash_type); EG(ret, err);
1496 	/* Let DB = maskedDB \xor dbMask */
1497 	db = dbmask;
1498 	for(i = 0; i < khlen; i++){
1499 		db[i] = (maskeddb[i] ^ dbmask[i]);
1500 	}
1501 	/* DB = lHash' || PS || 0x01 || M */
1502 	/* NOTE: we try our best to do the following in constant time to
1503 	 * limit padding oracles here (see Manger attacks).
1504 	 */
1505 	/* Y must be != 0 */
1506 	ret = em[0];
1507 	/* Isolate and compare lHash' to lHash */
1508 	ret |= are_equal(&db[0], lhash, hlen, &cmp);
1509 	ret |= ((~cmp) & 0x1);
1510 	/* Find 0x01 separator in constant time */
1511 	pos = 0;
1512 	for(i = hlen; i < khlen; i++){
1513 		u8 r_;
1514 		pos = ((db[i] == 0x01) && (pos == 0)) ? i : pos;
1515 		r_ = (pos == 0) ? db[i] : 0;
1516 		ret |= r_; /* Capture non zero PS */
1517 	}
1518 	pos = (pos == 0) ? pos : (pos + 1);
1519 	/* We get a random value between 2 and k if we have an error so that
1520 	 * we put a random value in pos.
1521 	 */
1522         ret |= get_random((u8*)&i, 4);
1523 	/* Get a random value r for later loop dummy operations */
1524 	ret |= get_random(&r, 1);
1525 	/* Update pos with random value in case of error to progress
1526 	 * nominally with the algorithm
1527 	 */
1528 	pos = (ret) ? ((i % (khlen - hlen)) + hlen) : pos;
1529 	/* Copy the result */
1530 	for(i = hlen; i < khlen; i++){
1531 		u8 r_;
1532 		unsigned int idx;
1533 		/* Replace m by a random value in case of error */
1534 		idx = (i < pos) ? 0x00 : (i - pos);
1535 		r ^= (u8)i;
1536 		r_ = (u8)((u8)(!!ret) * r);
1537 		m[idx] = (db[i] ^ r_);
1538 	}
1539 	(*mlen) = (u16)(k - hlen - 1 - pos);
1540 	/* Hide return value details to avoid information leak */
1541 	ret = -(!!ret);
1542 
1543 err:
1544 	nn_uninit(&m_);
1545 	nn_uninit(&c_);
1546 
1547 	return ret;
1548 }
1549 
1550 /*
1551  * Basic version without much SCA/faults protections.
1552  */
rsaes_oaep_decrypt(const rsa_priv_key * priv,const u8 * c,u32 clen,u8 * m,u32 * mlen,u32 modbits,const u8 * label,u32 label_len,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type)1553 int rsaes_oaep_decrypt(const rsa_priv_key *priv, const u8 *c, u32 clen,
1554                        u8 *m, u32 *mlen, u32 modbits,
1555                        const u8 *label, u32 label_len, gen_hash_alg_type gen_hash_type,
1556 		       gen_hash_alg_type mgf_hash_type)
1557 {
1558 	return _rsaes_oaep_decrypt(priv, NULL, c, clen, m, mlen, modbits, label, label_len, gen_hash_type, mgf_hash_type);
1559 }
1560 
1561 /*
1562  * Hardened version with some SCA/faults protections.
1563  */
rsaes_oaep_decrypt_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * c,u32 clen,u8 * m,u32 * mlen,u32 modbits,const u8 * label,u32 label_len,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type)1564 int rsaes_oaep_decrypt_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *c, u32 clen,
1565                        u8 *m, u32 *mlen, u32 modbits,
1566                        const u8 *label, u32 label_len, gen_hash_alg_type gen_hash_type,
1567 		       gen_hash_alg_type mgf_hash_type)
1568 {
1569 	return _rsaes_oaep_decrypt(priv, pub, c, clen, m, mlen, modbits, label, label_len, gen_hash_type, mgf_hash_type);
1570 }
1571 
1572 /****************************************************************/
1573 /******** Signature schemes *************************************/
1574 /* The RSASSA-PKCS1-V1_5-SIGN signature algorithm as described in RFC 8017 section 8.2.1
1575  *
1576  */
_rsassa_pkcs1_v1_5_sign(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * m,u32 mlen,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type)1577 ATTRIBUTE_WARN_UNUSED_RET static int _rsassa_pkcs1_v1_5_sign(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *m, u32 mlen,
1578                            u8 *s, u16 *slen, u32 modbits, gen_hash_alg_type gen_hash_type)
1579 {
1580 	int ret;
1581 	u8 *em = s;
1582 	u32 k;
1583 	nn m_, s_;
1584 	m_.magic = s_.magic = WORD(0);
1585 
1586 	/* Checks on sizes */
1587 	MUST_HAVE((slen != NULL), ret, err);
1588 
1589 	k = BYTECEIL(modbits);
1590 
1591 	/* Only accept reasonable sizes */
1592 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1593 	/* Sanity check on size */
1594 	MUST_HAVE(((*slen) >= k), ret, err);
1595 
1596 	/* EM = EMSA-PKCS1-V1_5-ENCODE (M, k) */
1597 	ret = emsa_pkcs1_v1_5_encode(m, mlen, em, (u16)k, gen_hash_type); EG(ret, err);
1598 
1599 	/* m = OS2IP (EM) */
1600 	ret = rsa_os2ip(&m_, em, (u16)k); EG(ret, err);
1601 	/* s = RSASP1 (K, m) */
1602 	if(pub != NULL){
1603 		ret = rsasp1_hardened(priv, pub, &m_, &s_); EG(ret, err);
1604 	}
1605 	else{
1606 		ret = rsasp1(priv, &m_, &s_); EG(ret, err);
1607 	}
1608 	/* S = I2OSP (s, k) */
1609 	ret = rsa_i2osp(&s_, s, (u16)k);
1610 	(*slen) = (u16)k;
1611 
1612 err:
1613 	nn_uninit(&m_);
1614 	nn_uninit(&s_);
1615 	/* Zeroify in case of error */
1616 	if(ret && (slen != NULL)){
1617 		IGNORE_RET_VAL(local_memset(s, 0, (*slen)));
1618 	}
1619 
1620 	return ret;
1621 }
1622 
1623 /*
1624  * Basic version without much SCA/faults protections.
1625  */
rsassa_pkcs1_v1_5_sign(const rsa_priv_key * priv,const u8 * m,u32 mlen,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type)1626 int rsassa_pkcs1_v1_5_sign(const rsa_priv_key *priv, const u8 *m, u32 mlen,
1627                            u8 *s, u16 *slen, u32 modbits, gen_hash_alg_type gen_hash_type)
1628 {
1629 	return _rsassa_pkcs1_v1_5_sign(priv, NULL, m, mlen, s, slen, modbits, gen_hash_type);
1630 }
1631 
1632 /*
1633  * Hardened version with some SCA/faults protections.
1634  */
rsassa_pkcs1_v1_5_sign_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * m,u32 mlen,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type)1635 int rsassa_pkcs1_v1_5_sign_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *m, u32 mlen,
1636                            u8 *s, u16 *slen, u32 modbits, gen_hash_alg_type gen_hash_type)
1637 {
1638 	return _rsassa_pkcs1_v1_5_sign(priv, pub, m, mlen, s, slen, modbits, gen_hash_type);
1639 }
1640 
1641 /* The RSASSA-PKCS1-V1_5-VERIFY verification algorithm as described in RFC 8017 section 8.2.2
1642  *
1643  */
rsassa_pkcs1_v1_5_verify(const rsa_pub_key * pub,const u8 * m,u32 mlen,const u8 * s,u16 slen,u32 modbits,gen_hash_alg_type gen_hash_type)1644 int rsassa_pkcs1_v1_5_verify(const rsa_pub_key *pub, const u8 *m, u32 mlen,
1645                              const u8 *s, u16 slen, u32 modbits, gen_hash_alg_type gen_hash_type)
1646 {
1647 	int ret, cmp;
1648 	/* Get a large enough buffer to hold the result */
1649         /*
1650          * NOTE: the NN_USABLE_MAX_BYTE_LEN should be a reasonable size here.
1651          */
1652 	u8 em[NN_USABLE_MAX_BYTE_LEN];
1653 	u8 em_[NN_USABLE_MAX_BYTE_LEN];
1654 	u32 k;
1655 	nn m_, s_;
1656 	m_.magic = s_.magic = WORD(0);
1657 
1658 	/* Zeroize local variables */
1659 	ret = local_memset(em, 0, sizeof(em)); EG(ret, err);
1660 	ret = local_memset(em_, 0, sizeof(em_)); EG(ret, err);
1661 
1662 	k = BYTECEIL(modbits);
1663 	/* Only accept reasonable sizes */
1664 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1665 
1666 	/* Length checking: If the length of the signature S is not k
1667          * octets, output "invalid signature" and stop.
1668 	 */
1669 	MUST_HAVE(((u16)k == slen), ret, err);
1670 
1671 	/* s = OS2IP (S) */
1672 	ret = rsa_os2ip(&s_, s, slen); EG(ret, err);
1673 	/* m = RSAVP1 ((n, e), s) */
1674 	ret = rsavp1(pub, &s_, &m_); EG(ret, err);
1675 	/* EM = I2OSP (m, k) */
1676 	MUST_HAVE((slen <= sizeof(em)), ret, err);
1677 	ret = rsa_i2osp(&m_, em, slen); EG(ret, err);
1678 	/* EM' = EMSA-PKCS1-V1_5-ENCODE (M, k) */
1679 	MUST_HAVE((k <= sizeof(em_)), ret, err);
1680 	ret = emsa_pkcs1_v1_5_encode(m, mlen, em_, (u16)k, gen_hash_type); EG(ret, err);
1681 
1682 	/* Compare */
1683 	ret = are_equal(em, em_, (u16)k, &cmp); EG(ret, err);
1684 	if(!cmp){
1685 		ret = -1;
1686 	}
1687 err:
1688 	nn_uninit(&m_);
1689 	nn_uninit(&s_);
1690 
1691 	return ret;
1692 }
1693 
1694 /* The RSASSA-PSS-SIGN signature algorithm as described in RFC 8017 section 8.1.1
1695  *
1696  */
_rsassa_pss_sign(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * m,u32 mlen,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,u32 saltlen,const u8 * forced_salt)1697 ATTRIBUTE_WARN_UNUSED_RET static int _rsassa_pss_sign(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *m, u32 mlen,
1698                     u8 *s, u16 *slen, u32 modbits,
1699 		    gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
1700                     u32 saltlen, const u8 *forced_salt)
1701 {
1702 	int ret;
1703 	u8 *em = s;
1704 	u16 emsize;
1705 	u32 k;
1706 	nn m_, s_;
1707 	m_.magic = s_.magic = WORD(0);
1708 
1709 	MUST_HAVE((slen != NULL), ret, err);
1710 
1711 	MUST_HAVE((modbits > 1), ret, err);
1712 
1713 	k = BYTECEIL(modbits);
1714 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1715 
1716 	/* Sanity check on size */
1717 	MUST_HAVE(((*slen) >= k), ret, err);
1718 
1719 	/* EM = EMSA-PSS-ENCODE (M, modBits - 1) */
1720 	emsize = (*slen);
1721 	ret = emsa_pss_encode(m, mlen, em, (modbits - 1), &emsize, gen_hash_type, mgf_hash_type, saltlen, forced_salt); EG(ret, err);
1722 
1723 	/* Note that the octet length of EM will be one less than k if modBits - 1 is divisible by 8 and equal to k otherwise */
1724 	MUST_HAVE(emsize == BYTECEIL(modbits - 1), ret, err);
1725 
1726 	/* m = OS2IP (EM) */
1727 	ret = rsa_os2ip(&m_, em, (u16)emsize); EG(ret, err);
1728 	/* s = RSASP1 (K, m) */
1729 	if(pub != NULL){
1730 		ret = rsasp1_hardened(priv, pub, &m_, &s_); EG(ret, err);
1731 	}
1732 	else{
1733 		ret = rsasp1(priv, &m_, &s_); EG(ret, err);
1734 	}
1735 	/* S = I2OSP (s, k) */
1736 	MUST_HAVE((k < ((u32)0x1 << 16)), ret, err);
1737 	ret = rsa_i2osp(&s_, s, (u16)k);
1738 	(*slen) = (u16)k;
1739 
1740 err:
1741 	nn_uninit(&m_);
1742 	nn_uninit(&s_);
1743 	/* Zeroify in case of error */
1744 	if(ret && (slen != NULL)){
1745 		IGNORE_RET_VAL(local_memset(s, 0, (*slen)));
1746 	}
1747 
1748 	return ret;
1749 }
1750 
1751 /*
1752  * Basic version without much SCA/faults protections.
1753  */
rsassa_pss_sign(const rsa_priv_key * priv,const u8 * m,u32 mlen,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,u32 saltlen,const u8 * forced_salt)1754 int rsassa_pss_sign(const rsa_priv_key *priv, const u8 *m, u32 mlen,
1755                     u8 *s, u16 *slen, u32 modbits,
1756 		    gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
1757                     u32 saltlen, const u8 *forced_salt)
1758 {
1759 	return _rsassa_pss_sign(priv, NULL, m, mlen, s, slen, modbits, gen_hash_type, mgf_hash_type, saltlen, forced_salt);
1760 }
1761 
1762 /*
1763  * Hardened version with some SCA/faults protections.
1764  */
rsassa_pss_sign_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * m,u32 mlen,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,u32 saltlen,const u8 * forced_salt)1765 int rsassa_pss_sign_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub, const u8 *m, u32 mlen,
1766                     u8 *s, u16 *slen, u32 modbits,
1767 		    gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
1768                     u32 saltlen, const u8 *forced_salt)
1769 {
1770 	return _rsassa_pss_sign(priv, pub, m, mlen, s, slen, modbits, gen_hash_type, mgf_hash_type, saltlen, forced_salt);
1771 }
1772 
1773 
1774 /* The RSASSA-PSS-VERIFY verification algorithm as described in RFC 8017 section 8.1.2
1775  *
1776  */
rsassa_pss_verify(const rsa_pub_key * pub,const u8 * m,u32 mlen,const u8 * s,u16 slen,u32 modbits,gen_hash_alg_type gen_hash_type,gen_hash_alg_type mgf_hash_type,u32 saltlen)1777 int rsassa_pss_verify(const rsa_pub_key *pub, const u8 *m, u32 mlen,
1778                       const u8 *s, u16 slen, u32 modbits,
1779                       gen_hash_alg_type gen_hash_type, gen_hash_alg_type mgf_hash_type,
1780 		      u32 saltlen)
1781 {
1782 	int ret;
1783 	/* Get a large enough buffer to hold the result */
1784         /*
1785          * NOTE: the NN_USABLE_MAX_BYTE_LEN should be a reasonable size here.
1786          */
1787 	u8 em[NN_USABLE_MAX_BYTE_LEN];
1788 	u16 emlen;
1789 	u32 k;
1790 	nn m_, s_;
1791 	m_.magic = s_.magic = WORD(0);
1792 
1793 	/* Zeroize local variables */
1794 	ret = local_memset(em, 0, sizeof(em)); EG(ret, err);
1795 
1796 	MUST_HAVE((modbits > 1), ret, err);
1797 	k = BYTECEIL(modbits);
1798 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1799 
1800 	/* s = OS2IP (S) */
1801 	ret = rsa_os2ip(&s_, s, slen); EG(ret, err);
1802 	/* m = RSAVP1 ((n, e), s) */
1803 	ret = rsavp1(pub, &s_, &m_); EG(ret, err);
1804 	/* emLen = \ceil ((modBits - 1)/8) */
1805 	MUST_HAVE((((modbits - 1) / 8) + 1) < (u32)((u32)0x1 << 16), ret, err);
1806 	emlen = (((modbits - 1) % 8) == 0) ? (u16)((modbits - 1) / 8) : (u16)(((modbits - 1) / 8) + 1);
1807 
1808 	/* Note that emLen will be one less than k if modBits - 1 is divisible by 8 and equal to k otherwise */
1809 	MUST_HAVE(emlen == BYTECEIL(modbits - 1), ret, err);
1810 
1811 	/* EM = I2OSP (m, emLen) */
1812 	MUST_HAVE((emlen <= sizeof(em)), ret, err);
1813 	ret = rsa_i2osp(&m_, em, (u16)emlen); EG(ret, err);
1814 	/*  Result = EMSA-PSS-VERIFY (M, EM, modBits - 1) */
1815 	ret = emsa_pss_verify(m, mlen, em, (modbits - 1), emlen, gen_hash_type, mgf_hash_type, saltlen);
1816 
1817 err:
1818 	nn_uninit(&m_);
1819 	nn_uninit(&s_);
1820 
1821 	return ret;
1822 }
1823 
1824 /* The RSA signature algorithm using ISO/IEC 9796-2 padding scheme 1.
1825  * This is a signature with recovery.
1826  *
1827  * XXX: beware that this scheme is here for completeness, but is considered fragile
1828  * since practical attacks exist when the hash function is of relatively "small" size
1829  * (see http://www.crypto-uni.lu/jscoron/publications/iso97962joc.pdf).
1830  *
1831  * The ISO/IEC 9796-2 is also described in EMV Book 2 in the A.2.1 section:
1832  * "Digital Signature Scheme Giving Message Recovery".
1833  *
1834  */
_rsa_iso9796_2_sign_recover(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * m,u32 mlen,u32 * m1len,u32 * m2len,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type)1835 ATTRIBUTE_WARN_UNUSED_RET static int _rsa_iso9796_2_sign_recover(const rsa_priv_key *priv, const rsa_pub_key *pub,
1836 								 const u8 *m, u32 mlen, u32 *m1len, u32 *m2len, u8 *s, u16 *slen,
1837 								 u32 modbits, gen_hash_alg_type gen_hash_type)
1838 {
1839 	int ret;
1840 	u32 k, m1len_, m2len_;
1841 	u8 hlen, block_size;
1842 	gen_hash_context hctx;
1843 	nn m_, s_;
1844 	m_.magic = s_.magic = WORD(0);
1845 
1846 	MUST_HAVE((priv != NULL) && (m != NULL), ret, err);
1847 
1848 	MUST_HAVE((slen != NULL), ret, err);
1849 
1850 	MUST_HAVE((modbits > 1), ret, err);
1851 
1852 	k = BYTECEIL(modbits);
1853 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1854 
1855 	/* Get hash parameters */
1856 	ret = gen_hash_get_hash_sizes(gen_hash_type, &hlen, &block_size); EG(ret, err);
1857 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
1858 
1859 	/* Sanity check on sizes */
1860 	MUST_HAVE(((*slen) >= k), ret, err);
1861 	MUST_HAVE(k >= (u32)(2 + hlen), ret, err);
1862 
1863 	/* Compute our recoverable and non-recoverable parts */
1864 	m1len_ = (mlen >= (k - 2 - hlen)) ? (k - 2 - hlen) : mlen;
1865 	m2len_ = (mlen - m1len_);
1866 
1867 	/* Now hash the message */
1868 	ret = gen_hash_init(&hctx, gen_hash_type); EG(ret, err);
1869 	ret = gen_hash_update(&hctx, m, mlen, gen_hash_type); EG(ret, err);
1870 	ret = gen_hash_final(&hctx, &s[k - 1 - hlen], gen_hash_type); EG(ret, err);
1871 
1872 	/* Put M1 */
1873 	ret = local_memcpy(&s[1], m, m1len_); EG(ret, err);
1874 	if(m1len != NULL){
1875 		(*m1len) = m1len_;
1876 	}
1877 	if(m2len != NULL){
1878 		(*m2len) = m2len_;
1879 	}
1880 
1881 	/* Put the constants */
1882 	s[0]     = 0x6a;
1883 	s[k - 1] = 0xbc;
1884 
1885 	/* m = OS2IP (X) */
1886 	ret = rsa_os2ip(&m_, s, k); EG(ret, err);
1887 	/* s = RSASP1 (K, m) */
1888 	if(pub != NULL){
1889 		ret = rsasp1_hardened(priv, pub, &m_, &s_); EG(ret, err);
1890 	}
1891 	else{
1892 		ret = rsasp1(priv, &m_, &s_); EG(ret, err);
1893 	}
1894 	/* S = I2OSP (s, k) */
1895 	MUST_HAVE((k < ((u32)0x1 << 16)), ret, err);
1896 	ret = rsa_i2osp(&s_, s, (u16)k);
1897 	(*slen) = (u16)k;
1898 
1899 err:
1900 	nn_uninit(&m_);
1901 	nn_uninit(&s_);
1902 
1903 	if(ret && (m1len != 0)){
1904 		(*m1len) = 0;
1905 	}
1906 	if(ret && (m2len != 0)){
1907 		(*m2len) = 0;
1908 	}
1909 
1910 	return ret;
1911 }
1912 
1913 /*
1914  * Basic version without much SCA/faults protections.
1915  */
rsa_iso9796_2_sign_recover(const rsa_priv_key * priv,const u8 * m,u32 mlen,u32 * m1len,u32 * m2len,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type)1916 int rsa_iso9796_2_sign_recover(const rsa_priv_key *priv, const u8 *m, u32 mlen, u32 *m1len,
1917 			       u32 *m2len, u8 *s, u16 *slen,
1918 			       u32 modbits, gen_hash_alg_type gen_hash_type)
1919 {
1920 	return _rsa_iso9796_2_sign_recover(priv, NULL, m, mlen, m1len, m2len, s, slen, modbits, gen_hash_type);
1921 }
1922 
1923 /*
1924  * Hardened version with some SCA/faults protections.
1925  */
rsa_iso9796_2_sign_recover_hardened(const rsa_priv_key * priv,const rsa_pub_key * pub,const u8 * m,u32 mlen,u32 * m1len,u32 * m2len,u8 * s,u16 * slen,u32 modbits,gen_hash_alg_type gen_hash_type)1926 int rsa_iso9796_2_sign_recover_hardened(const rsa_priv_key *priv, const rsa_pub_key *pub,
1927 				        const u8 *m, u32 mlen, u32 *m1len, u32 *m2len, u8 *s, u16 *slen,
1928 				        u32 modbits, gen_hash_alg_type gen_hash_type)
1929 {
1930 	return _rsa_iso9796_2_sign_recover(priv, pub, m, mlen, m1len, m2len, s, slen, modbits, gen_hash_type);
1931 }
1932 
1933 /* The RSA verification algorithm using ISO/IEC 9796-2 padding scheme 1.
1934  * This is a verification with recovery.
1935  *
1936  * XXX: beware that this scheme is here for completeness, but is considered fragile
1937  * since practical attacks exist when the hash function is of relatively "small" size
1938  * (see http://www.crypto-uni.lu/jscoron/publications/iso97962joc.pdf).
1939  *
1940  * The ISO/IEC 9796-2 is also described in EMV Book 2 in the A.2.1 section:
1941  * "Digital Signature Scheme Giving Message Recovery".
1942  *
1943  */
rsa_iso9796_2_verify_recover(const rsa_pub_key * pub,const u8 * m2,u32 m2len,u8 * m1,u32 * m1len,const u8 * s,u16 slen,u32 modbits,gen_hash_alg_type gen_hash_type)1944 int rsa_iso9796_2_verify_recover(const rsa_pub_key *pub, const u8 *m2, u32 m2len, u8 *m1, u32 *m1len,
1945                                  const u8 *s, u16 slen, u32 modbits, gen_hash_alg_type gen_hash_type)
1946 {
1947 	int ret, cmp;
1948 	/* Get a large enough buffer to hold the result */
1949         /*
1950          * NOTE: the NN_USABLE_MAX_BYTE_LEN should be a reasonable size here.
1951          */
1952 	u8 X[NN_USABLE_MAX_BYTE_LEN];
1953 	u8 H[MAX_DIGEST_SIZE];
1954 	u32 k, m1len_;
1955 	u8 hlen, block_size;
1956 	gen_hash_context hctx;
1957 	nn m_, s_;
1958 	m_.magic = s_.magic = WORD(0);
1959 
1960 	MUST_HAVE((pub != NULL) && (m2 != NULL), ret, err);
1961 
1962 	/* Zeroize local variables */
1963 	ret = local_memset(X, 0, sizeof(X)); EG(ret, err);
1964 	ret = local_memset(H, 0, sizeof(H)); EG(ret, err);
1965 
1966 	k = BYTECEIL(modbits);
1967 	/* Only accept reasonable sizes */
1968 	MUST_HAVE((k < (u32)((u32)0x1 << 16)), ret, err);
1969 
1970 	ret = gen_hash_get_hash_sizes(gen_hash_type, &hlen, &block_size); EG(ret, err);
1971 	MUST_HAVE((hlen <= MAX_DIGEST_SIZE), ret, err);
1972 
1973 	/* Length checking: If the length of the signature S is not k
1974          * octets, output "invalid signature" and stop.
1975 	 */
1976 	MUST_HAVE(((u16)k == slen), ret, err);
1977 	MUST_HAVE((slen >= (hlen + 2)), ret, err);
1978 	m1len_ = (u32)(slen - (hlen + 2));
1979 
1980 	/* s = OS2IP (S) */
1981 	ret = rsa_os2ip(&s_, s, slen); EG(ret, err);
1982 	/* m = RSAVP1 ((n, e), s) */
1983 	ret = rsavp1(pub, &s_, &m_); EG(ret, err);
1984 	/* EM = I2OSP (m, k) */
1985 	MUST_HAVE((slen <= sizeof(X)), ret, err);
1986 	ret = rsa_i2osp(&m_, X, slen); EG(ret, err);
1987 
1988 	/* Split the message in B || m1 || H || E with
1989 	 * B = '6A', E = 'BC', and H the hash value */
1990 	if(m1len != NULL){
1991 		MUST_HAVE((*m1len) >= m1len_, ret, err);
1992 		(*m1len) = m1len_;
1993 	}
1994 	if((X[0] != 0x6a) || (X[slen - 1] != 0xbc)){
1995 		ret = -1;
1996 		goto err;
1997 	}
1998 
1999 	/* Compute the hash of m1 || m2 */
2000 	ret = gen_hash_init(&hctx, gen_hash_type); EG(ret, err);
2001 	ret = gen_hash_update(&hctx, &X[1], m1len_, gen_hash_type); EG(ret, err);
2002 	ret = gen_hash_update(&hctx, m2, m2len, gen_hash_type); EG(ret, err);
2003 	ret = gen_hash_final(&hctx, H, gen_hash_type); EG(ret, err);
2004 
2005 	/* Compare */
2006 	ret = are_equal(H, &X[1 + m1len_], (u16)hlen, &cmp); EG(ret, err);
2007 	if(!cmp){
2008 		ret = -1;
2009 	}
2010 	/* If comparison is OK, copy data */
2011 	if(m1 != NULL){
2012 		MUST_HAVE((m1len != NULL), ret, err);
2013 		ret = local_memcpy(m1, &X[1], (*m1len)); EG(ret, err);
2014 	}
2015 
2016 err:
2017 	nn_uninit(&m_);
2018 	nn_uninit(&s_);
2019 
2020 	if(ret && (m1len != 0)){
2021 		(*m1len) = 0;
2022 	}
2023 
2024 	return ret;
2025 }
2026 
2027 #ifdef RSA
2028 /* RSA PKCS#1 test vectors taken from:
2029  *     https://github.com/bdauvergne/python-pkcs1/tree/master/tests/data
2030  */
2031 #include "rsa_pkcs1_tests.h"
2032 
main(int argc,char * argv[])2033 int main(int argc, char *argv[])
2034 {
2035 	int ret = 0;
2036 	FORCE_USED_VAR(argc);
2037 	FORCE_USED_VAR(argv);
2038 
2039 	/* Sanity check on size for RSA.
2040 	 * NOTE: the double parentheses are here to handle -Wunreachable-code
2041 	 */
2042 	if((NN_USABLE_MAX_BIT_LEN) < (4096)){
2043 		ext_printf("Error: you seem to have compiled libecc with usable NN size < 4096, not suitable for RSA.\n");
2044 		ext_printf("  => Please recompile libecc with EXTRA_CFLAGS=\"-DUSER_NN_BIT_LEN=4096\"\n");
2045 		ext_printf("     This will increase usable NN for proper RSA up to 4096 bits.\n");
2046 		ext_printf("     Then recompile the current examples with the same EXTRA_CFLAGS=\"-DUSER_NN_BIT_LEN=4096\" flag and execute again!\n");
2047 		/* NOTE: ret = 0 here to pass self tests even if the library is not compatible */
2048 		ret = 0;
2049 		goto err;
2050 	}
2051 
2052 	ret = perform_rsa_tests(all_rsa_tests, sizeof(all_rsa_tests) / sizeof(rsa_test*));
2053 
2054 err:
2055 	return ret;
2056 }
2057 #endif
2058