xref: /freebsd/crypto/libecc/src/examples/sig/rsa/rsa_tests.h (revision f0865ec9906d5a18fa2a3b61381f22ce16e606ad)
1 /*
2  *  Copyright (C) 2021 - This file is part of libecc project
3  *
4  *  Authors:
5  *      Ryad BENADJILA <ryadbenadjila@gmail.com>
6  *      Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr>
7  *
8  *  This software is licensed under a dual BSD and GPL v2 license.
9  *  See LICENSE file at the root folder of the project.
10  */
11 #ifndef __RSA_TESTS_H__
12 #define __RSA_TESTS_H__
13 
14 /* Test suite for RSA PKCS#1 algorithms */
15 #include "rsa.h"
16 
/* RSA scheme exercised by a test vector (stored in rsa_test.type). */
typedef enum {
	RSA_PKCS1_v1_5_ENC = 0, /* RSAES-PKCS1-v1_5 encryption / decryption */
	RSA_PKCS1_v1_5_SIG = 1, /* RSASSA-PKCS1-v1_5 signature / verification */
	RSA_OAEP_ENC       = 2, /* RSAES-OAEP encryption / decryption */
	RSA_PSS_SIG        = 3  /* RSASSA-PSS signature / verification */
} rsa_alg_type;
23 
24 typedef struct {
25 	const char *name;
26 	rsa_alg_type type;
27 	u32 modbits;
28         gen_hash_alg_type hash;
29 	const u8 *n;
30 	u16 nlen;
31 	const u8 *d;
32 	u16 dlen;
33 	const u8 *e;
34 	u16 elen;
35 	const u8 *p;
36 	u16 plen;
37 	const u8 *q;
38 	u16 qlen;
39 	const u8 *dP;
40 	u16 dPlen;
41 	const u8 *dQ;
42 	u16 dQlen;
43 	const u8 *qInv;
44 	u16 qInvlen;
45 	const u8 *m;
46 	u32 mlen;
47 	const u8 *res;
48 	u32 reslen;
49 	const u8 *salt;
50 	u32 saltlen;
51 } rsa_test;
52 
53 
perform_rsa_tests(const rsa_test ** tests,u32 num_tests)54 ATTRIBUTE_WARN_UNUSED_RET static inline int perform_rsa_tests(const rsa_test **tests, u32 num_tests)
55 {
56 	int ret = 0, cmp;
57 	unsigned int i;
58 
59 	for(i = 0; i < num_tests; i++){
60 		const rsa_test *t = tests[i];
61 		u32 modbits = t->modbits;
62 		rsa_pub_key pub;
63 		rsa_priv_key priv;
64 		rsa_priv_key priv_pq;
65 
66 		/* Import the keys */
67 		ret = rsa_import_pub_key(&pub, t->n, (u16)t->nlen, t->e, (u16)t->elen); EG(ret, err1);
68 		if(t->dP == NULL){
69 			const rsa_test *t_ = NULL;
70 			MUST_HAVE((num_tests > 1) && (i < (num_tests - 1)), ret, err);
71 			/* NOTE: we use the "next" CRT test to extract p and q */
72 			t_ = tests[i + 1];
73 			MUST_HAVE((t_->dP != NULL), ret, err);
74 			/* Import the RSA_SIMPLE private key with only d and n */
75 			ret = rsa_import_simple_priv_key(&priv, t->n, (u16)t->nlen, t->d, (u16)t->dlen, NULL, 0, NULL, 0); EG(ret, err1);
76 			/* Import the RSA_SIMPLE_PQ with d, n, p and q */
77 			ret = rsa_import_simple_priv_key(&priv_pq, t->n, (u16)t->nlen, t->d, (u16)t->dlen, t_->p, (u16)t_->plen, t_->q, (u16)t_->qlen); EG(ret, err1);
78 		}
79 		else{
80 			/* Import the RSA_CRT CRT key */
81 			ret = rsa_import_crt_priv_key(&priv, t->p, (u16)t->plen, t->q, (u16)t->qlen, t->dP, (u16)t->dPlen, t->dQ, (u16)t->dQlen, t->qInv, (u16)t->qInvlen, NULL, NULL, 0); EG(ret, err1);
82 		}
83 #ifdef USE_SIG_BLINDING
84 		/* We using exponent blinding, only RSA_SIMPLE_PQ are usable. We hence overwrite the key */
85 		ret = local_memcpy(&priv, &priv_pq, sizeof(rsa_priv_key)); EG(ret, err);
86 #endif
87 		/* Perform our operation */
88 		switch(t->type){
89 			case RSA_PKCS1_v1_5_ENC:{
90 				u8 cipher[NN_USABLE_MAX_BYTE_LEN];
91 				u32 clen;
92 				if(t->salt != NULL){
93 					clen = sizeof(cipher);
94 					ret = rsaes_pkcs1_v1_5_encrypt(&pub, t->m, t->mlen, cipher, &clen, modbits, t->salt, t->saltlen); EG(ret, err1);
95 					/* Check the result */
96 					MUST_HAVE((clen == t->reslen), ret, err1);
97 					ret = are_equal(t->res, cipher, t->reslen, &cmp); EG(ret, err1);
98 					MUST_HAVE(cmp, ret, err1);
99 				}
100 				/* Try to decrypt */
101 				clen = sizeof(cipher);
102 				ret = rsaes_pkcs1_v1_5_decrypt(&priv, t->res, t->reslen, cipher, &clen, modbits); EG(ret, err1);
103 				/* Check the result */
104 				MUST_HAVE((clen == t->mlen), ret, err1);
105 				ret = are_equal(t->m, cipher, t->mlen, &cmp); EG(ret, err1);
106 				MUST_HAVE(cmp, ret, err1);
107 				/* Try to decrypt with the hardened version */
108 				clen = sizeof(cipher);
109 				ret = rsaes_pkcs1_v1_5_decrypt_hardened(&priv, &pub, t->res, t->reslen, cipher, &clen, modbits); EG(ret, err1);
110 				/* Check the result */
111 				MUST_HAVE((clen == t->mlen), ret, err1);
112 				ret = are_equal(t->m, cipher, t->mlen, &cmp); EG(ret, err1);
113 				MUST_HAVE(cmp, ret, err1);
114 				break;
115 			}
116 			case RSA_OAEP_ENC:{
117 				u8 cipher[NN_USABLE_MAX_BYTE_LEN];
118 				u32 clen;
119 				if(t->salt != NULL){
120 					clen = sizeof(cipher);
121 					ret = rsaes_oaep_encrypt(&pub, t->m, t->mlen, cipher, &clen, modbits, NULL, 0, t->hash, t->hash, t->salt, t->saltlen); EG(ret, err1);
122 					/* Check the result */
123 					MUST_HAVE((clen == t->reslen), ret, err1);
124 					ret = are_equal(t->res, cipher, t->reslen, &cmp); EG(ret, err1);
125 					MUST_HAVE(cmp, ret, err1);
126 				}
127 				/* Try to decrypt */
128 				clen = sizeof(cipher);
129 				ret = rsaes_oaep_decrypt(&priv, t->res, t->reslen, cipher, &clen, modbits, NULL, 0, t->hash, t->hash); EG(ret, err1);
130 				/* Check the result */
131 				MUST_HAVE((clen == t->mlen), ret, err1);
132 				ret = are_equal(t->m, cipher, t->mlen, &cmp); EG(ret, err1);
133 				MUST_HAVE(cmp, ret, err1);
134 				/* Try to decrypt with the hardened version */
135 				clen = sizeof(cipher);
136 				ret = rsaes_oaep_decrypt_hardened(&priv, &pub, t->res, t->reslen, cipher, &clen, modbits, NULL, 0, t->hash, t->hash); EG(ret, err1);
137 				/* Check the result */
138 				MUST_HAVE((clen == t->mlen), ret, err1);
139 				ret = are_equal(t->m, cipher, t->mlen, &cmp); EG(ret, err1);
140 				MUST_HAVE(cmp, ret, err1);
141 				break;
142 			}
143 			case RSA_PKCS1_v1_5_SIG:{
144 				u8 sig[NN_USABLE_MAX_BYTE_LEN];
145 				u16 siglen = sizeof(sig);
146 				MUST_HAVE((t->reslen) <= 0xffff, ret, err1);
147 				ret = rsassa_pkcs1_v1_5_verify(&pub, t->m, t->mlen, t->res, (u16)(t->reslen), modbits, t->hash); EG(ret, err1);
148 				/* Try to sign */
149 				ret = rsassa_pkcs1_v1_5_sign(&priv, t->m, t->mlen, sig, &siglen, modbits, t->hash); EG(ret, err1);
150 				/* Check the result */
151 				MUST_HAVE((siglen == t->reslen), ret, err1);
152 				ret = are_equal(t->res, sig, t->reslen, &cmp); EG(ret, err1);
153 				MUST_HAVE(cmp, ret, err1);
154 				/* Try to sign with the hardened version */
155 				ret = rsassa_pkcs1_v1_5_sign_hardened(&priv, &pub, t->m, t->mlen, sig, &siglen, modbits, t->hash); EG(ret, err1);
156 				/* Check the result */
157 				MUST_HAVE((siglen == t->reslen), ret, err1);
158 				ret = are_equal(t->res, sig, t->reslen, &cmp); EG(ret, err1);
159 				MUST_HAVE(cmp, ret, err1);
160 				break;
161 			}
162 			case RSA_PSS_SIG:{
163 				if(t->salt == NULL){
164 					/* In case of NULL salt, default saltlen value is the digest size */
165 					u8 digestsize, blocksize;
166 					ret = gen_hash_get_hash_sizes(t->hash, &digestsize, &blocksize); EG(ret, err1);
167 					MUST_HAVE((t->reslen) <= 0xffff, ret, err1);
168 					ret = rsassa_pss_verify(&pub, t->m, t->mlen, t->res, (u16)(t->reslen), modbits, t->hash, t->hash, digestsize); EG(ret, err1);
169 				}
170 				else{
171 					MUST_HAVE((t->reslen) <= 0xffff, ret, err1);
172 					ret = rsassa_pss_verify(&pub, t->m, t->mlen, t->res, (u16)(t->reslen), modbits, t->hash, t->hash, t->saltlen); EG(ret, err1);
173 				}
174 				if(t->salt != NULL){
175 					/* Try to sign */
176 					u8 sig[NN_USABLE_MAX_BYTE_LEN];
177 					u16 siglen = sizeof(sig);
178 					ret = rsassa_pss_sign(&priv, t->m, t->mlen, sig, &siglen, modbits, t->hash, t->hash, t->saltlen, t->salt); EG(ret, err1);
179 					/* Check the result */
180 					MUST_HAVE((t->reslen) <= 0xffff, ret, err1);
181 					MUST_HAVE((siglen == (u16)(t->reslen)), ret, err1);
182 					ret = are_equal(t->res, sig, t->reslen, &cmp); EG(ret, err1);
183 					MUST_HAVE(cmp, ret, err1);
184 					/* Try to sign with the hardened version */
185 					ret = rsassa_pss_sign_hardened(&priv, &pub, t->m, t->mlen, sig, &siglen, modbits, t->hash, t->hash, t->saltlen, t->salt); EG(ret, err1);
186 					/* Check the result */
187 					MUST_HAVE((siglen == (u16)(t->reslen)), ret, err1);
188 					ret = are_equal(t->res, sig, t->reslen, &cmp); EG(ret, err1);
189 					MUST_HAVE(cmp, ret, err1);
190 				}
191 				break;
192 			}
193 			default:{
194 				ret = -1;
195 				break;
196 			}
197 		}
198 err1:
199 		if(ret){
200 			ext_printf("[-] Test %s failed (modbits = %" PRIu32 ")\n", t->name, t->modbits);
201 			goto err;
202 		}
203 		else{
204 			ext_printf("[+] Test %s passed (modbits = %" PRIu32 ")\n", t->name, t->modbits);
205 		}
206 	}
207 
208 	if(!ret){
209 		ext_printf("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\t=== [+] All RSA tests went OK! ===\n");
210 	}
211 err:
212 	return ret;
213 }
214 
215 #endif /* __RSA_TESTS_H__ */
216