xref: /freebsd/lib/libmp/mpasbn.c (revision 6990ffd8a95caaba6858ad44ff1b3157d1efba8f)
1 /*
2  * Copyright (c) 2001 Dima Dorfman.
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  */
26 
27 /*
28  * This is the traditional Berkeley MP library implemented in terms of
29  * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
30  * is meant to be as compatible with the latter as feasible.
31  *
32  * There seems to be a lack of documentation for the Berkeley MP
33  * interface.  All I could find was libgmp documentation (which didn't
34  * talk about the semantics of the functions) and an old SunOS 4.1
35  * manual page from 1989.  The latter wasn't very detailed, either,
 * but at least described what the functions' arguments were.  In
37  * general the interface seems to be archaic, somewhat poorly
38  * designed, and poorly, if at all, documented.  It is considered
39  * harmful.
40  *
41  * Miscellaneous notes on this implementation:
42  *
43  *  - The SunOS manual page mentioned above indicates that if an error
44  *  occurs, the library should "produce messages and core images."
45  *  Given that most of the functions don't have return values (and
46  *  thus no sane way of alerting the caller to an error), this seems
47  *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
48  *  respectively, then abort().
49  *
50  *  - All the functions which take an argument to be "filled in"
51  *  assume that the argument has been initialized by one of the *tom()
52  *  routines before being passed to it.  I never saw this documented
53  *  anywhere, but this seems to be consistent with the way this
54  *  library is used.
55  *
56  *  - msqrt() is the only routine which had to be implemented which
57  *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
58  *  It was implemented by hand using Newton's recursive formula.
59  *  Doing it this way, although more error-prone, has the positive
 *  side effect of testing a lot of other functions; if msqrt()
61  *  produces the correct results, most of the other routines will as
62  *  well.
63  *
64  *  - Internal-use-only routines (i.e., those defined here statically
65  *  and not in mp.h) have an underscore prepended to their name (this
 *  is more for aesthetic reasons than technical ones).  All such
67  *  routines take an extra argument, 'msg', that denotes what they
68  *  should call themselves in an error message.  This is so a user
69  *  doesn't get an error message from a function they didn't call.
70  */
71 
#ifndef lint
/* Version-control identification string; compiled out when 'lint' is defined. */
static const char rcsid[] =
  "$FreeBSD$";
#endif /* not lint */
76 
77 #include <ctype.h>
78 #include <err.h>
79 #include <errno.h>
80 #include <stdio.h>
81 #include <stdlib.h>
82 #include <string.h>
83 
84 #include <openssl/bn.h>
85 #include <openssl/crypto.h>
86 #include <openssl/err.h>
87 
88 #include "mp.h"
89 
/*
 * Fatal-error macros.  The argument is a complete, parenthesized
 * warn(3)/warnx(3) argument list, e.g. MPERR(("%s", msg)).  MPERR
 * reports via warn(3) (appending the errno message) and MPERRX via
 * warnx(3); both then call abort().
 */
#define MPERR(s)	do { warn s; abort(); } while (0)
#define MPERRX(s)	do { warnx s; abort(); } while (0)
/*
 * Evaluate a BN_* call; if it yields zero/NULL, report the failure
 * under the name msg via _bnerr(), which aborts.
 */
#define BN_ERRCHECK(msg, expr) do {		\
	if (!(expr)) _bnerr(msg);		\
} while (0)
95 
/*
 * Internal helpers.  Each takes a 'msg' argument naming the public
 * routine to mention in error messages.
 */
static void _bnerr(const char *);
static MINT *_dtom(const char *, const char *);
static MINT *_itom(const char *, short);
static void _madd(const char *, const MINT *, const MINT *, MINT *);
static int _mcmpa(const char *, const MINT *, const MINT *);
static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *);
static void _mfree(const char *, MINT *);
static void _moveb(const char *, const BIGNUM *, MINT *);
static void _movem(const char *, const MINT *, MINT *);
static void _msub(const char *, const MINT *, const MINT *, MINT *);
static char *_mtod(const char *, const MINT *);
static char *_mtox(const char *, const MINT *);
static void _mult(const char *, const MINT *, const MINT *, MINT *);
static void _sdiv(const char *, const MINT *, short, MINT *, short *);
static MINT *_xtom(const char *, const char *);
111 
/*
 * Report an error from one of the BN_* functions using MPERRX, which
 * aborts.  msg names the routine the user called.
 *
 * ERR_reason_error_string() returns NULL when no string is registered
 * for the code, including when the error queue is empty (code 0).
 * Passing NULL to "%s" is undefined, so substitute a generic reason.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
122 
123 /*
124  * Convert a decimal string to an MINT.
125  */
126 static MINT *
127 _dtom(const char *msg, const char *s)
128 {
129 	MINT *mp;
130 
131 	mp = malloc(sizeof(*mp));
132 	if (mp == NULL)
133 		MPERR(("%s", msg));
134 	mp->bn = BN_new();
135 	if (mp->bn == NULL)
136 		_bnerr(msg);
137 	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
138 	return (mp);
139 }
140 
141 /*
142  * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
143  */
144 void
145 gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
146 {
147 	BIGNUM b;
148 	BN_CTX c;
149 
150 	BN_CTX_init(&c);
151 	BN_init(&b);
152 	BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, &c));
153 	_moveb("gcd", &b, rmp);
154 	BN_free(&b);
155 	BN_CTX_free(&c);
156 }
157 
158 /*
159  * Make an MINT out of a short integer.  Return value must be mfree()'d.
160  */
161 static MINT *
162 _itom(const char *msg, short n)
163 {
164 	MINT *mp;
165 	char *s;
166 
167 	asprintf(&s, "%x", n);
168 	if (s == NULL)
169 		MPERR(("%s", msg));
170 	mp = _xtom(msg, s);
171 	free(s);
172 	return (mp);
173 }
174 
175 MINT *
176 itom(short n)
177 {
178 
179 	return (_itom("itom", n));
180 }
181 
182 /*
183  * Compute rmp=mp1+mp2.
184  */
185 static void
186 _madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
187 {
188 	BIGNUM b;
189 
190 	BN_init(&b);
191 	BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
192 	_moveb(msg, &b, rmp);
193 	BN_free(&b);
194 }
195 
196 void
197 madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
198 {
199 
200 	_madd("madd", mp1, mp2, rmp);
201 }
202 
203 /*
204  * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
205  */
206 int
207 mcmp(const MINT *mp1, const MINT *mp2)
208 {
209 
210 	return (BN_cmp(mp1->bn, mp2->bn));
211 }
212 
213 /*
214  * Same as mcmp but compares absolute values.
215  */
216 static int
217 _mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
218 {
219 
220 	return (BN_ucmp(mp1->bn, mp2->bn));
221 }
222 
223 /*
224  * Compute qmp=nmp/dmp and rmp=nmp%dmp.
225  */
226 static void
227 _mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
228 {
229 	BIGNUM q, r;
230 	BN_CTX c;
231 
232 	BN_CTX_init(&c);
233 	BN_init(&r);
234 	BN_init(&q);
235 	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, &c));
236 	_moveb(msg, &q, qmp);
237 	_moveb(msg, &r, rmp);
238 	BN_free(&q);
239 	BN_free(&r);
240 	BN_CTX_free(&c);
241 }
242 
243 void
244 mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
245 {
246 
247 	_mdiv("mdiv", nmp, dmp, qmp, rmp);
248 }
249 
250 /*
251  * Free memory associated with an MINT.
252  */
253 static void
254 _mfree(const char *msg __unused, MINT *mp)
255 {
256 
257 	BN_clear(mp->bn);
258 	BN_free(mp->bn);
259 	free(mp);
260 }
261 
262 void
263 mfree(MINT *mp)
264 {
265 
266 	_mfree("mfree", mp);
267 }
268 
269 /*
270  * Read an integer from standard input and stick the result in mp.
271  * The input is treated to be in base 10.  This must be the silliest
272  * API in existence; why can't the program read in a string and call
273  * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
274  * exported.)
275  */
276 void
277 min(MINT *mp)
278 {
279 	MINT *rmp;
280 	char *line, *nline;
281 	size_t linelen;
282 
283 	line = fgetln(stdin, &linelen);
284 	if (line == NULL)
285 		MPERR(("min"));
286 	nline = malloc(linelen);
287 	if (nline == NULL)
288 		MPERR(("min"));
289 	strncpy(nline, line, linelen);
290 	nline[linelen] = '\0';
291 	rmp = _dtom("min", nline);
292 	_movem("min", rmp, mp);
293 	_mfree("min", rmp);
294 	free(nline);
295 }
296 
297 /*
298  * Print the value of mp to standard output in base 10.  See blurb
299  * above min() for why this is so useless.
300  */
301 void
302 mout(const MINT *mp)
303 {
304 	char *s;
305 
306 	s = _mtod("mout", mp);
307 	printf("%s", s);
308 	free(s);
309 }
310 
311 /*
312  * Set the value of tmp to the value of smp (i.e., tmp=smp).
313  */
314 void
315 move(const MINT *smp, MINT *tmp)
316 {
317 
318 	_movem("move", smp, tmp);
319 }
320 
321 
322 /*
323  * Internal routine to set the value of tmp to that of sbp.
324  */
325 static void
326 _moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
327 {
328 
329 	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
330 }
331 
332 /*
333  * Internal routine to set the value of tmp to that of smp.
334  */
335 static void
336 _movem(const char *msg, const MINT *smp, MINT *tmp)
337 {
338 
339 	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
340 }
341 
342 /*
343  * Compute the square root of nmp and put the result in xmp.  The
344  * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
345  *
346  * Note that the OpenSSL BIGNUM library does not have a square root
347  * function, so this had to be implemented by hand using Newton's
348  * recursive formula:
349  *
350  *		x = (x + (n / x)) / 2
351  *
352  * where x is the square root of the positive number n.  In the
353  * beginning, x should be a reasonable guess, but the value 1,
354  * although suboptimal, works, too; this is that is used below.
355  */
356 void
357 msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
358 {
359 	MINT *tolerance;
360 	MINT *ox, *x;
361 	MINT *z1, *z2, *z3;
362 	short i;
363 
364 	tolerance = _itom("msqrt", 1);
365 	x = _itom("msqrt", 1);
366 	ox = _itom("msqrt", 0);
367 	z1 = _itom("msqrt", 0);
368 	z2 = _itom("msqrt", 0);
369 	z3 = _itom("msqrt", 0);
370 	do {
371 		_movem("msqrt", x, ox);
372 		_mdiv("msqrt", nmp, x, z1, z2);
373 		_madd("msqrt", x, z1, z2);
374 		_sdiv("msqrt", z2, 2, x, &i);
375 		_msub("msqrt", ox, x, z3);
376 	} while (_mcmpa("msqrt", z3, tolerance) == 1);
377 	_movem("msqrt", x, xmp);
378 	_mult("msqrt", x, x, z1);
379 	_msub("msqrt", nmp, z1, z2);
380 	_movem("msqrt", z2, rmp);
381 	_mfree("msqrt", tolerance);
382 	_mfree("msqrt", ox);
383 	_mfree("msqrt", x);
384 	_mfree("msqrt", z1);
385 	_mfree("msqrt", z2);
386 	_mfree("msqrt", z3);
387 }
388 
389 /*
390  * Compute rmp=mp1-mp2.
391  */
392 static void
393 _msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
394 {
395 	BIGNUM b;
396 
397 	BN_init(&b);
398 	BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
399 	_moveb(msg, &b, rmp);
400 	BN_free(&b);
401 }
402 
403 void
404 msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
405 {
406 
407 	_msub("msub", mp1, mp2, rmp);
408 }
409 
410 /*
411  * Return a decimal representation of mp.  Return value must be
412  * free()'d.
413  */
414 static char *
415 _mtod(const char *msg, const MINT *mp)
416 {
417 	char *s, *s2;
418 
419 	s = BN_bn2dec(mp->bn);
420 	if (s == NULL)
421 		_bnerr(msg);
422 	asprintf(&s2, "%s", s);
423 	if (s2 == NULL)
424 		MPERR(("%s", msg));
425 	OPENSSL_free(s);
426 	return (s2);
427 }
428 
429 /*
430  * Return a hexadecimal representation of mp.  Return value must be
431  * free()'d.
432  */
433 static char *
434 _mtox(const char *msg, const MINT *mp)
435 {
436 	char *p, *s, *s2;
437 	int len;
438 
439 	s = BN_bn2hex(mp->bn);
440 	if (s == NULL)
441 		_bnerr(msg);
442 	asprintf(&s2, "%s", s);
443 	if (s2 == NULL)
444 		MPERR(("%s", msg));
445 	OPENSSL_free(s);
446 
447 	/*
448 	 * This is a kludge for libgmp compatibility.  The latter's
449 	 * implementation of this function returns lower-case letters,
450 	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
451 	 * newkey(1)) are sensitive to this.  Although it's probably
452 	 * their fault, it's nice to be compatible.
453 	 */
454 	len = strlen(s2);
455 	for (p = s2; p < s2 + len; p++)
456 		*p = tolower(*p);
457 
458 	return (s2);
459 }
460 
461 char *
462 mtox(const MINT *mp)
463 {
464 
465 	return (_mtox("mtox", mp));
466 }
467 
468 /*
469  * Compute rmp=mp1*mp2.
470  */
471 static void
472 _mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
473 {
474 	BIGNUM b;
475 	BN_CTX c;
476 
477 	BN_CTX_init(&c);
478 	BN_init(&b);
479 	BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, &c));
480 	_moveb(msg, &b, rmp);
481 	BN_free(&b);
482 	BN_CTX_free(&c);
483 }
484 
485 void
486 mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
487 {
488 
489 	_mult("mult", mp1, mp2, rmp);
490 }
491 
492 /*
493  * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
494  * means 'raise to power', not 'bitwise XOR'.)
495  */
496 void
497 pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
498 {
499 	BIGNUM b;
500 	BN_CTX c;
501 
502 	BN_CTX_init(&c);
503 	BN_init(&b);
504 	BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, &c));
505 	_moveb("pow", &b, rmp);
506 	BN_free(&b);
507 	BN_CTX_free(&c);
508 }
509 
510 /*
511  * Compute rmp=bmp^e.  (See note above pow().)
512  */
513 void
514 rpow(const MINT *bmp, short e, MINT *rmp)
515 {
516 	MINT *emp;
517 	BIGNUM b;
518 	BN_CTX c;
519 
520 	BN_CTX_init(&c);
521 	BN_init(&b);
522 	emp = _itom("rpow", e);
523 	BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, &c));
524 	_moveb("rpow", &b, rmp);
525 	_mfree("rpow", emp);
526 	BN_free(&b);
527 	BN_CTX_free(&c);
528 }
529 
530 /*
531  * Compute qmp=nmp/d and ro=nmp%d.
532  */
533 static void
534 _sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro)
535 {
536 	MINT *dmp, *rmp;
537 	BIGNUM q, r;
538 	BN_CTX c;
539 	char *s;
540 
541 	BN_CTX_init(&c);
542 	BN_init(&q);
543 	BN_init(&r);
544 	dmp = _itom(msg, d);
545 	rmp = _itom(msg, 0);
546 	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, &c));
547 	_moveb(msg, &q, qmp);
548 	_moveb(msg, &r, rmp);
549 	s = _mtox(msg, rmp);
550 	errno = 0;
551 	*ro = strtol(s, NULL, 16);
552 	if (errno != 0)
553 		MPERR(("%s underflow or overflow", msg));
554 	free(s);
555 	_mfree(msg, dmp);
556 	_mfree(msg, rmp);
557 	BN_free(&r);
558 	BN_free(&q);
559 	BN_CTX_free(&c);
560 }
561 
562 void
563 sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
564 {
565 
566 	_sdiv("sdiv", nmp, d, qmp, ro);
567 }
568 
569 /*
570  * Convert a hexadecimal string to an MINT.
571  */
572 static MINT *
573 _xtom(const char *msg, const char *s)
574 {
575 	MINT *mp;
576 
577 	mp = malloc(sizeof(*mp));
578 	if (mp == NULL)
579 		MPERR(("%s", msg));
580 	mp->bn = BN_new();
581 	if (mp->bn == NULL)
582 		_bnerr(msg);
583 	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
584 	return (mp);
585 }
586 
587 MINT *
588 xtom(const char *s)
589 {
590 
591 	return (_xtom("xtom", s));
592 }
593