xref: /freebsd/contrib/bearssl/test/test_math.c (revision dbfb4063ae95b956a2b0021c37c9a8be4c2e4393)
1 /*
2  * Copyright (c) 2016 Thomas Pornin <pornin@bolet.org>
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining
5  * a copy of this software and associated documentation files (the
6  * "Software"), to deal in the Software without restriction, including
7  * without limitation the rights to use, copy, modify, merge, publish,
8  * distribute, sublicense, and/or sell copies of the Software, and to
9  * permit persons to whom the Software is furnished to do so, subject to
10  * the following conditions:
11  *
12  * The above copyright notice and this permission notice shall be
13  * included in all copies or substantial portions of the Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
19  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
20  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <stdarg.h>
29 #include <time.h>
30 
31 #include <gmp.h>
32 
33 #include "bearssl.h"
34 #include "inner.h"
35 
/*
 * Dispatch table describing one big-integer implementation (i31 or i32).
 * All test code reaches the arithmetic through the global 'impl' pointer,
 * so a single test body exercises both implementations.
 *
 * The implementation tables may be initialized positionally elsewhere in
 * this file: keep the member order below unchanged.
 */
typedef struct {
	/* Limb size in bits (31 or 32); test_modint() uses it to derive
	   the Montgomery factor R. */
	uint32_t word_size;

	/* Creation, decoding and encoding. The tests pass big-endian byte
	   strings produced by mpz_export(). */
	void (*zero)(uint32_t *dst, uint32_t bitlen);
	void (*decode)(uint32_t *dst, const void *buf, size_t buflen);
	uint32_t (*decode_mod)(uint32_t *dst,
		const void *buf, size_t buflen, const uint32_t *mod);
	void (*reduce)(uint32_t *dst, const uint32_t *src, const uint32_t *mod);
	void (*decode_reduce)(uint32_t *dst,
		const void *buf, size_t buflen, const uint32_t *mod);
	void (*encode)(void *buf, size_t buflen, const uint32_t *src);

	/* In-place addition / subtraction of 'op' into 'acc'; test_modint()
	   relies on ctl = 0 yielding the carry/borrow without changing acc. */
	uint32_t (*add)(uint32_t *acc, const uint32_t *op, uint32_t ctl);
	uint32_t (*sub)(uint32_t *acc, const uint32_t *op, uint32_t ctl);

	/* Montgomery arithmetic and modular exponentiation. */
	uint32_t (*ninv)(uint32_t w);
	void (*montymul)(uint32_t *dst, const uint32_t *u, const uint32_t *v,
		const uint32_t *mod, uint32_t mod0i);
	void (*to_monty)(uint32_t *val, const uint32_t *mod);
	void (*from_monty)(uint32_t *val, const uint32_t *mod, uint32_t mod0i);
	void (*modpow)(uint32_t *val, const unsigned char *exp, size_t explen,
		const uint32_t *mod, uint32_t mod0i,
		uint32_t *tmp1, uint32_t *tmp2);
} int_impl;
59 
60 static const int_impl i31_impl = {
61 	31,
62 	&br_i31_zero,
63 	&br_i31_decode,
64 	&br_i31_decode_mod,
65 	&br_i31_reduce,
66 	&br_i31_decode_reduce,
67 	&br_i31_encode,
68 	&br_i31_add,
69 	&br_i31_sub,
70 	&br_i31_ninv31,
71 	&br_i31_montymul,
72 	&br_i31_to_monty,
73 	&br_i31_from_monty,
74 	&br_i31_modpow
75 };
76 static const int_impl i32_impl = {
77 	32,
78 	&br_i32_zero,
79 	&br_i32_decode,
80 	&br_i32_decode_mod,
81 	&br_i32_reduce,
82 	&br_i32_decode_reduce,
83 	&br_i32_encode,
84 	&br_i32_add,
85 	&br_i32_sub,
86 	&br_i32_ninv32,
87 	&br_i32_montymul,
88 	&br_i32_to_monty,
89 	&br_i32_from_monty,
90 	&br_i32_modpow
91 };
92 
93 static const int_impl *impl;
94 
95 static gmp_randstate_t RNG;
96 
97 /*
98  * Get a random prime of length 'size' bits. This function also guarantees
99  * that x-1 is not a multiple of 65537.
100  */
101 static void
102 rand_prime(mpz_t x, int size)
103 {
104 	for (;;) {
105 		mpz_urandomb(x, RNG, size - 1);
106 		mpz_setbit(x, 0);
107 		mpz_setbit(x, size - 1);
108 		if (mpz_probab_prime_p(x, 50)) {
109 			mpz_sub_ui(x, x, 1);
110 			if (mpz_divisible_ui_p(x, 65537)) {
111 				continue;
112 			}
113 			mpz_add_ui(x, x, 1);
114 			return;
115 		}
116 	}
117 }
118 
119 /*
120  * Print out a GMP integer (for debug).
121  */
122 static void
123 print_z(mpz_t z)
124 {
125 	unsigned char zb[1000];
126 	size_t zlen, k;
127 
128 	mpz_export(zb, &zlen, 1, 1, 0, 0, z);
129 	if (zlen == 0) {
130 		printf(" 00");
131 		return;
132 	}
133 	if ((zlen & 3) != 0) {
134 		k = 4 - (zlen & 3);
135 		memmove(zb + k, zb, zlen);
136 		memset(zb, 0, k);
137 		zlen += k;
138 	}
139 	for (k = 0; k < zlen; k += 4) {
140 		printf(" %02X%02X%02X%02X",
141 			zb[k], zb[k + 1], zb[k + 2], zb[k + 3]);
142 	}
143 }
144 
/*
 * Dump an i31 or i32 integer to stdout (for debug), without a newline.
 *
 * The header word x[0] determines how many value words are shown:
 * (x[0] + 31) >> 5 of them. They are printed most significant first as
 * " %08lX", followed by " (q, r)" with q = x[0] >> 5 and r = x[0] & 31.
 * A zero header prints " 00000000 (0, 0)".
 */
static void
print_u(uint32_t *x)
{
	uint32_t hdr = x[0];
	size_t nw;

	if (hdr != 0) {
		nw = (hdr + 31) >> 5;
		while (nw > 0) {
			printf(" %08lX", (unsigned long)x[nw]);
			nw --;
		}
		printf(" (%u, %u)", (unsigned)(hdr >> 5), (unsigned)(hdr & 31));
	} else {
		printf(" 00000000 (0, 0)");
	}
}
162 
163 /*
164  * Check that an i31/i32 number and a GMP number are equal.
165  */
166 static void
167 check_eqz(uint32_t *x, mpz_t z)
168 {
169 	unsigned char xb[1000];
170 	unsigned char zb[1000];
171 	size_t xlen, zlen;
172 	int good;
173 
174 	xlen = ((x[0] + 31) & ~(uint32_t)31) >> 3;
175 	impl->encode(xb, xlen, x);
176 	mpz_export(zb, &zlen, 1, 1, 0, 0, z);
177 	good = 1;
178 	if (xlen < zlen) {
179 		good = 0;
180 	} else if (xlen > zlen) {
181 		size_t u;
182 
183 		for (u = xlen; u > zlen; u --) {
184 			if (xb[xlen - u] != 0) {
185 				good = 0;
186 				break;
187 			}
188 		}
189 	}
190 	good = good && memcmp(xb + xlen - zlen, zb, zlen) == 0;
191 	if (!good) {
192 		size_t u;
193 
194 		printf("Mismatch:\n");
195 		printf("  x = ");
196 		print_u(x);
197 		printf("\n");
198 		printf("  ex = ");
199 		for (u = 0; u < xlen; u ++) {
200 			printf("%02X", xb[u]);
201 		}
202 		printf("\n");
203 		printf("  z = ");
204 		print_z(z);
205 		printf("\n");
206 		exit(EXIT_FAILURE);
207 	}
208 }
209 
210 /* obsolete
211 static void
212 mp_to_br(uint32_t *mx, uint32_t x_bitlen, mpz_t x)
213 {
214 	uint32_t x_ebitlen;
215 	size_t xlen;
216 
217 	if (mpz_sizeinbase(x, 2) > x_bitlen) {
218 		abort();
219 	}
220 	x_ebitlen = ((x_bitlen / 31) << 5) + (x_bitlen % 31);
221 	br_i31_zero(mx, x_ebitlen);
222 	mpz_export(mx + 1, &xlen, -1, sizeof *mx, 0, 1, x);
223 }
224 */
225 
226 static void
227 test_modint(void)
228 {
229 	int i, j, k;
230 	mpz_t p, a, b, v, t1;
231 
232 	printf("Test modular integers: ");
233 	fflush(stdout);
234 
235 	gmp_randinit_mt(RNG);
236 	mpz_init(p);
237 	mpz_init(a);
238 	mpz_init(b);
239 	mpz_init(v);
240 	mpz_init(t1);
241 	mpz_set_ui(t1, (unsigned long)time(NULL));
242 	gmp_randseed(RNG, t1);
243 	for (k = 2; k <= 128; k ++) {
244 		for (i = 0; i < 10; i ++) {
245 			unsigned char ep[100], ea[100], eb[100], ev[100];
246 			size_t plen, alen, blen, vlen;
247 			uint32_t mp[40], ma[40], mb[40], mv[60], mx[100];
248 			uint32_t mt1[40], mt2[40], mt3[40];
249 			uint32_t ctl;
250 			uint32_t mp0i;
251 
252 			rand_prime(p, k);
253 			mpz_urandomm(a, RNG, p);
254 			mpz_urandomm(b, RNG, p);
255 			mpz_urandomb(v, RNG, k + 60);
256 			if (mpz_sgn(b) == 0) {
257 				mpz_set_ui(b, 1);
258 			}
259 			mpz_export(ep, &plen, 1, 1, 0, 0, p);
260 			mpz_export(ea, &alen, 1, 1, 0, 0, a);
261 			mpz_export(eb, &blen, 1, 1, 0, 0, b);
262 			mpz_export(ev, &vlen, 1, 1, 0, 0, v);
263 
264 			impl->decode(mp, ep, plen);
265 			if (impl->decode_mod(ma, ea, alen, mp) != 1) {
266 				printf("Decode error\n");
267 				printf("  ea = ");
268 				print_z(a);
269 				printf("\n");
270 				printf("  p = ");
271 				print_u(mp);
272 				printf("\n");
273 				exit(EXIT_FAILURE);
274 			}
275 			mp0i = impl->ninv(mp[1]);
276 			if (impl->decode_mod(mb, eb, blen, mp) != 1) {
277 				printf("Decode error\n");
278 				printf("  eb = ");
279 				print_z(b);
280 				printf("\n");
281 				printf("  p = ");
282 				print_u(mp);
283 				printf("\n");
284 				exit(EXIT_FAILURE);
285 			}
286 			impl->decode(mv, ev, vlen);
287 			check_eqz(mp, p);
288 			check_eqz(ma, a);
289 			check_eqz(mb, b);
290 			check_eqz(mv, v);
291 
292 			impl->decode_mod(ma, ea, alen, mp);
293 			impl->decode_mod(mb, eb, blen, mp);
294 			ctl = impl->add(ma, mb, 1);
295 			ctl |= impl->sub(ma, mp, 0) ^ (uint32_t)1;
296 			impl->sub(ma, mp, ctl);
297 			mpz_add(t1, a, b);
298 			mpz_mod(t1, t1, p);
299 			check_eqz(ma, t1);
300 
301 			impl->decode_mod(ma, ea, alen, mp);
302 			impl->decode_mod(mb, eb, blen, mp);
303 			impl->add(ma, mp, impl->sub(ma, mb, 1));
304 			mpz_sub(t1, a, b);
305 			mpz_mod(t1, t1, p);
306 			check_eqz(ma, t1);
307 
308 			impl->decode_reduce(ma, ev, vlen, mp);
309 			mpz_mod(t1, v, p);
310 			check_eqz(ma, t1);
311 
312 			impl->decode(mv, ev, vlen);
313 			impl->reduce(ma, mv, mp);
314 			mpz_mod(t1, v, p);
315 			check_eqz(ma, t1);
316 
317 			impl->decode_mod(ma, ea, alen, mp);
318 			impl->to_monty(ma, mp);
319 			mpz_mul_2exp(t1, a, ((k + impl->word_size - 1)
320 				/ impl->word_size) * impl->word_size);
321 			mpz_mod(t1, t1, p);
322 			check_eqz(ma, t1);
323 			impl->from_monty(ma, mp, mp0i);
324 			check_eqz(ma, a);
325 
326 			impl->decode_mod(ma, ea, alen, mp);
327 			impl->decode_mod(mb, eb, blen, mp);
328 			impl->to_monty(ma, mp);
329 			impl->montymul(mt1, ma, mb, mp, mp0i);
330 			mpz_mul(t1, a, b);
331 			mpz_mod(t1, t1, p);
332 			check_eqz(mt1, t1);
333 
334 			impl->decode_mod(ma, ea, alen, mp);
335 			impl->modpow(ma, ev, vlen, mp, mp0i, mt1, mt2);
336 			mpz_powm(t1, a, v, p);
337 			check_eqz(ma, t1);
338 
339 			/*
340 			br_modint_decode(ma, mp, ea, alen);
341 			br_modint_decode(mb, mp, eb, blen);
342 			if (!br_modint_div(ma, mb, mp, mt1, mt2, mt3)) {
343 				fprintf(stderr, "division failed\n");
344 				exit(EXIT_FAILURE);
345 			}
346 			mpz_sub_ui(t1, p, 2);
347 			mpz_powm(t1, b, t1, p);
348 			mpz_mul(t1, a, t1);
349 			mpz_mod(t1, t1, p);
350 			check_eqz(ma, t1);
351 
352 			br_modint_decode(ma, mp, ea, alen);
353 			br_modint_decode(mb, mp, eb, blen);
354 			for (j = 0; j <= (2 * k + 5); j ++) {
355 				br_int_add(mx, j, ma, mb);
356 				mpz_add(t1, a, b);
357 				mpz_tdiv_r_2exp(t1, t1, j);
358 				check_eqz(mx, t1);
359 
360 				br_int_mul(mx, j, ma, mb);
361 				mpz_mul(t1, a, b);
362 				mpz_tdiv_r_2exp(t1, t1, j);
363 				check_eqz(mx, t1);
364 			}
365 			*/
366 		}
367 		printf(".");
368 		fflush(stdout);
369 	}
370 	mpz_clear(p);
371 	mpz_clear(a);
372 	mpz_clear(b);
373 	mpz_clear(v);
374 	mpz_clear(t1);
375 
376 	printf(" done.\n");
377 	fflush(stdout);
378 }
379 
#if 0
/*
 * Disabled RSA private-key core test, kept for reference.
 *
 * It cannot be re-enabled as is. It calls mp_to_br(), which exists in
 * this file only as a commented-out "obsolete" helper. It also calls a
 * br_rsa_private_core() entry point whose signature may no longer match
 * the library (TODO confirm before re-enabling).
 */
static void
test_RSA_core(void)
{
	int i, j, k;
	mpz_t n, e, d, p, q, dp, dq, iq, t1, t2, phi;

	printf("Test RSA core: ");
	fflush(stdout);

	gmp_randinit_mt(RNG);
	mpz_init(n);
	mpz_init(e);
	mpz_init(d);
	mpz_init(p);
	mpz_init(q);
	mpz_init(dp);
	mpz_init(dq);
	mpz_init(iq);
	mpz_init(t1);
	mpz_init(t2);
	mpz_init(phi);
	mpz_set_ui(t1, (unsigned long)time(NULL));
	gmp_randseed(RNG, t1);

	/*
	 * To test corner cases, we want to try RSA keys such that the bit
	 * lengths of both factors can be arbitrary modulo 32. Factors p and
	 * q need not be of the same length; p can be greater than q and q
	 * can be greater than p.
	 *
	 * To keep computation time reasonable, p has 64 to 96 bits and q
	 * is within 33 bits of p, so at most 129 bits. This is way too small
	 * for secure RSA, but enough to exercise all code paths, since we
	 * work only with 32-bit words.
	 */
	for (i = 64; i <= 96; i ++) {
		rand_prime(p, i);
		for (j = i - 33; j <= i + 33; j ++) {
			uint32_t mp[40], mq[40], mdp[40], mdq[40], miq[40];

			/*
			 * Generate a RSA key pair, with p of length i bits,
			 * and q of length j bits.
			 */
			do {
				rand_prime(q, j);
			} while (mpz_cmp(p, q) == 0);
			mpz_mul(n, p, q);
			mpz_set_ui(e, 65537);
			mpz_sub_ui(t1, p, 1);
			mpz_sub_ui(t2, q, 1);
			mpz_mul(phi, t1, t2);
			mpz_invert(d, e, phi);
			mpz_mod(dp, d, t1);
			mpz_mod(dq, d, t2);
			mpz_invert(iq, q, p);

			/*
			 * Convert the key pair elements to BearSSL arrays.
			 *
			 * NOTE(review): miq receives mp[0] (an encoded header
			 * word) as its bit length. The commented-out mp_to_br()
			 * re-encodes its x_bitlen argument, so this looks
			 * inconsistent. Confirm before re-enabling.
			 */
			mp_to_br(mp, mpz_sizeinbase(p, 2), p);
			mp_to_br(mq, mpz_sizeinbase(q, 2), q);
			mp_to_br(mdp, mpz_sizeinbase(dp, 2), dp);
			mp_to_br(mdq, mpz_sizeinbase(dq, 2), dq);
			mp_to_br(miq, mp[0], iq);

			/*
			 * Compute and check ten public/private operations.
			 */
			for (k = 0; k < 10; k ++) {
				uint32_t mx[40];

				mpz_urandomm(t1, RNG, n);
				mpz_powm(t2, t1, e, n);
				mp_to_br(mx, mpz_sizeinbase(n, 2), t2);
				br_rsa_private_core(mx, mp, mq, mdp, mdq, miq);
				check_eqz(mx, t1);
			}
		}
		printf(".");
		fflush(stdout);
	}

	/* Release every GMP object initialized above. */
	mpz_clear(n);
	mpz_clear(e);
	mpz_clear(d);
	mpz_clear(p);
	mpz_clear(q);
	mpz_clear(dp);
	mpz_clear(dq);
	mpz_clear(iq);
	mpz_clear(t1);
	mpz_clear(t2);
	mpz_clear(phi);
	gmp_randclear(RNG);

	printf(" done.\n");
	fflush(stdout);
}
#endif
468 
469 int
470 main(void)
471 {
472 	printf("===== i32 ======\n");
473 	impl = &i32_impl;
474 	test_modint();
475 	printf("===== i31 ======\n");
476 	impl = &i31_impl;
477 	test_modint();
478 	/*
479 	test_RSA_core();
480 	*/
481 	return 0;
482 }
483