xref: /freebsd/contrib/wpa/src/tls/libtommath.c (revision a90b9d0159070121c221b966469c3e36d912bf82)
1 /*
2  * Minimal code for RSA support from LibTomMath 0.41
3  * http://libtom.org/
4  * http://libtom.org/files/ltm-0.41.tar.bz2
5  * This library was released in public domain by Tom St Denis.
6  *
7  * The combination in this file may not use all of the optimized algorithms
8  * from LibTomMath and may be considerably slower than LibTomMath with its
9  * default settings. The main purpose of having this version here is to make it
10  * easier to build the bignum.c wrapper without having to install and build an
11  * external library.
12  *
13  * If CONFIG_INTERNAL_LIBTOMMATH is defined, bignum.c includes this
14  * libtommath.c file instead of using the external LibTomMath library.
15  */
16 
/* Fallback for CHAR_BIT in case <limits.h> has not provided it; it is used
 * below in MP_WARRAY and in the digit/word width arithmetic. */
#ifndef CHAR_BIT
#define CHAR_BIT 8
#endif

/* Routine selection. Code further down is wrapped in #ifdef BN_..._C guards
 * keyed on these names, so only the routines named here are compiled in. */
#define BN_MP_INVMOD_C
#define BN_S_MP_EXPTMOD_C /* Note: #undef in tommath_superclass.h; this would
			   * require BN_MP_EXPTMOD_FAST_C instead */
#define BN_S_MP_MUL_DIGS_C
#define BN_MP_INVMOD_SLOW_C
#define BN_S_MP_SQR_C
#define BN_S_MP_MUL_HIGH_DIGS_C /* Note: #undef in tommath_superclass.h; this
				 * would require other than mp_reduce */

/* Defining LTM_FAST before this file is included trades code size (and, for
 * the comba multiplier, stack space) for speed; see the per-group notes. */
#ifdef LTM_FAST

/* Use faster div at the cost of about 1 kB */
#define BN_MP_MUL_D_C

/* Include faster exptmod (Montgomery) at the cost of about 2.5 kB in code */
#define BN_MP_EXPTMOD_FAST_C
#define BN_MP_MONTGOMERY_SETUP_C
#define BN_FAST_MP_MONTGOMERY_REDUCE_C
#define BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
#define BN_MP_MUL_2_C

/* Include faster sqr at the cost of about 0.5 kB in code */
#define BN_FAST_S_MP_SQR_C

/* About 0.25 kB of code, but ~1.7kB of stack space! */
#define BN_FAST_S_MP_MUL_DIGS_C

#else /* LTM_FAST */

/* Default (small) build. NOTE(review): BN_MP_DIV_SMALL presumably selects a
 * compact mp_div variant; its use is not visible in this part of the file. */
#define BN_MP_DIV_SMALL
#define BN_MP_INIT_MULTI_C
#define BN_MP_CLEAR_MULTI_C
#define BN_MP_ABS_C
#endif /* LTM_FAST */

/* Current uses do not require support for negative exponent in exptmod, so we
 * can save about 1.5 kB in leaving out invmod. With this defined, mp_exptmod
 * returns MP_VAL for a negative exponent. */
#define LTM_NO_NEG_EXP
59 
/* from tommath.h */

/* OPT_CAST(x) expands to nothing, so "OPT_CAST(mp_digit) XMALLOC(...)" is a
 * plain allocation call without a cast. */
#define  OPT_CAST(x)

/* Digit and double-width word types. In both configurations 2*DIGIT_BIT is
 * below the mp_word width (120 < 128 and 56 < 64), which leaves the carry
 * headroom that MP_WARRAY below is derived from. */
#ifdef __x86_64__
typedef unsigned long mp_digit;
/* mode(TI) is a GCC attribute that makes this a 128-bit integer type. */
typedef unsigned long mp_word __attribute__((mode(TI)));

#define DIGIT_BIT 60
#define MP_64BIT
#else
typedef unsigned long mp_digit;
/* NOTE(review): u64 is not defined in this file; presumably a 64-bit unsigned
 * type from the including wrapper's headers - confirm. */
typedef u64 mp_word;

#define DIGIT_BIT          28
#define MP_28BIT
#endif


/* Allocator hooks, mapped to the os_* wrappers (defined outside this file). */
#define XMALLOC  os_malloc
#define XFREE    os_free
#define XREALLOC os_realloc


/* Low DIGIT_BIT bits set; used to strip the carry out of a digit. */
#define MP_MASK          ((((mp_digit)1)<<((mp_digit)DIGIT_BIT))-((mp_digit)1))

#define MP_LT        -1   /* less than */
#define MP_EQ         0   /* equal to */
#define MP_GT         1   /* greater than */

#define MP_ZPOS       0   /* positive integer */
#define MP_NEG        1   /* negative */

#define MP_OKAY       0   /* ok result */
#define MP_MEM        -2  /* out of mem */
#define MP_VAL        -3  /* invalid input */

#define MP_YES        1   /* yes response */
#define MP_NO         0   /* no response */

typedef int           mp_err;

/* define this to use lower memory usage routines (exptmods mostly) */
#define MP_LOW_MEM

/* default precision: the number of digits mp_init allocates. Because
 * MP_LOW_MEM is defined just above, this is 8 unless MP_PREC is overridden. */
#ifndef MP_PREC
   #ifndef MP_LOW_MEM
      #define MP_PREC                 32     /* default digits of precision */
   #else
      #define MP_PREC                 8      /* default digits of precision */
   #endif
#endif

/* size of comba arrays, should be at least 2 * 2**(BITS_PER_WORD - BITS_PER_DIGIT*2)
 * (evaluates to 1 << 9 = 512 for both digit configurations above) */
#define MP_WARRAY               (1 << (sizeof(mp_word) * CHAR_BIT - 2 * DIGIT_BIT + 1))

/* the infamous mp_int structure
 *   used  - number of digits currently in use (0 represents zero)
 *   alloc - number of digits allocated in dp
 *   sign  - MP_ZPOS or MP_NEG
 *   dp    - digit array, least significant digit in dp[0]
 */
typedef struct  {
    int used, alloc, sign;
    mp_digit *dp;
} mp_int;
122 
123 
/* ---> Basic Manipulations <--- */
/* Predicates returning MP_YES/MP_NO. Zero is represented by used == 0, so
 * mp_iseven() and mp_isodd() both report MP_NO for zero. */
#define mp_iszero(a) (((a)->used == 0) ? MP_YES : MP_NO)
#define mp_iseven(a) (((a)->used > 0 && (((a)->dp[0] & 1) == 0)) ? MP_YES : MP_NO)
#define mp_isodd(a)  (((a)->used > 0 && (((a)->dp[0] & 1) == 1)) ? MP_YES : MP_NO)


/* prototypes for copied functions */
/* s_mp_mul: full unsigned product, requesting a->used + b->used + 1 digits. */
#define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
static int s_mp_exptmod(mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode);
static int s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs);
static int s_mp_sqr(mp_int * a, mp_int * b);
static int s_mp_mul_high_digs(mp_int * a, mp_int * b, mp_int * c, int digs);

/* Only present in LTM_FAST builds. */
#ifdef BN_FAST_S_MP_MUL_DIGS_C
static int fast_s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs);
#endif

/* Variadic helpers (NULL-terminated argument lists); non-LTM_FAST only. */
#ifdef BN_MP_INIT_MULTI_C
static int mp_init_multi(mp_int *mp, ...);
#endif
#ifdef BN_MP_CLEAR_MULTI_C
static void mp_clear_multi(mp_int *mp, ...);
#endif
static int mp_lshd(mp_int * a, int b);
static void mp_set(mp_int * a, mp_digit b);
static void mp_clamp(mp_int * a);
static void mp_exch(mp_int * a, mp_int * b);
static void mp_rshd(mp_int * a, int b);
static void mp_zero(mp_int * a);
static int mp_mod_2d(mp_int * a, int b, mp_int * c);
static int mp_div_2d(mp_int * a, int b, mp_int * c, mp_int * d);
static int mp_init_copy(mp_int * a, mp_int * b);
static int mp_mul_2d(mp_int * a, int b, mp_int * c);
/* Modular inverse support, needed only for negative exponents in exptmod. */
#ifndef LTM_NO_NEG_EXP
static int mp_div_2(mp_int * a, mp_int * b);
static int mp_invmod(mp_int * a, mp_int * b, mp_int * c);
static int mp_invmod_slow(mp_int * a, mp_int * b, mp_int * c);
#endif /* LTM_NO_NEG_EXP */
static int mp_copy(mp_int * a, mp_int * b);
static int mp_count_bits(mp_int * a);
static int mp_div(mp_int * a, mp_int * b, mp_int * c, mp_int * d);
static int mp_mod(mp_int * a, mp_int * b, mp_int * c);
static int mp_grow(mp_int * a, int size);
static int mp_cmp_mag(mp_int * a, mp_int * b);
#ifdef BN_MP_ABS_C
static int mp_abs(mp_int * a, mp_int * b);
#endif
static int mp_sqr(mp_int * a, mp_int * b);
static int mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d);
static int mp_reduce_2k_setup_l(mp_int *a, mp_int *d);
static int mp_2expt(mp_int * a, int b);
static int mp_reduce_setup(mp_int * a, mp_int * b);
static int mp_reduce(mp_int * x, mp_int * m, mp_int * mu);
static int mp_init_size(mp_int * a, int size);
/* Optional LTM_FAST routines. */
#ifdef BN_MP_EXPTMOD_FAST_C
static int mp_exptmod_fast (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode);
#endif /* BN_MP_EXPTMOD_FAST_C */
#ifdef BN_FAST_S_MP_SQR_C
static int fast_s_mp_sqr (mp_int * a, mp_int * b);
#endif /* BN_FAST_S_MP_SQR_C */
#ifdef BN_MP_MUL_D_C
static int mp_mul_d (mp_int * a, mp_digit b, mp_int * c);
#endif /* BN_MP_MUL_D_C */
187 
188 
189 
190 /* functions from bn_<func name>.c */
191 
192 
/* Reverse the first len bytes of s in place. mp_to_unsigned_bin uses this to
 * turn its least-significant-first byte output into big-endian order.
 * A len of 0 or 1 leaves s untouched. */
static void bn_reverse (unsigned char *s, int len)
{
  int left, right;

  for (left = 0, right = len - 1; left < right; left++, right--) {
    unsigned char held = s[left];

    s[left]  = s[right];
    s[right] = held;
  }
}
209 
210 
211 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
s_mp_add(mp_int * a,mp_int * b,mp_int * c)212 static int s_mp_add (mp_int * a, mp_int * b, mp_int * c)
213 {
214   mp_int *x;
215   int     olduse, res, min, max;
216 
217   /* find sizes, we let |a| <= |b| which means we have to sort
218    * them.  "x" will point to the input with the most digits
219    */
220   if (a->used > b->used) {
221     min = b->used;
222     max = a->used;
223     x = a;
224   } else {
225     min = a->used;
226     max = b->used;
227     x = b;
228   }
229 
230   /* init result */
231   if (c->alloc < max + 1) {
232     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
233       return res;
234     }
235   }
236 
237   /* get old used digit count and set new one */
238   olduse = c->used;
239   c->used = max + 1;
240 
241   {
242     register mp_digit u, *tmpa, *tmpb, *tmpc;
243     register int i;
244 
245     /* alias for digit pointers */
246 
247     /* first input */
248     tmpa = a->dp;
249 
250     /* second input */
251     tmpb = b->dp;
252 
253     /* destination */
254     tmpc = c->dp;
255 
256     /* zero the carry */
257     u = 0;
258     for (i = 0; i < min; i++) {
259       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
260       *tmpc = *tmpa++ + *tmpb++ + u;
261 
262       /* U = carry bit of T[i] */
263       u = *tmpc >> ((mp_digit)DIGIT_BIT);
264 
265       /* take away carry bit from T[i] */
266       *tmpc++ &= MP_MASK;
267     }
268 
269     /* now copy higher words if any, that is in A+B
270      * if A or B has more digits add those in
271      */
272     if (min != max) {
273       for (; i < max; i++) {
274         /* T[i] = X[i] + U */
275         *tmpc = x->dp[i] + u;
276 
277         /* U = carry bit of T[i] */
278         u = *tmpc >> ((mp_digit)DIGIT_BIT);
279 
280         /* take away carry bit from T[i] */
281         *tmpc++ &= MP_MASK;
282       }
283     }
284 
285     /* add carry */
286     *tmpc++ = u;
287 
288     /* clear digits above oldused */
289     for (i = c->used; i < olduse; i++) {
290       *tmpc++ = 0;
291     }
292   }
293 
294   mp_clamp (c);
295   return MP_OKAY;
296 }
297 
298 
299 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
s_mp_sub(mp_int * a,mp_int * b,mp_int * c)300 static int s_mp_sub (mp_int * a, mp_int * b, mp_int * c)
301 {
302   int     olduse, res, min, max;
303 
304   /* find sizes */
305   min = b->used;
306   max = a->used;
307 
308   /* init result */
309   if (c->alloc < max) {
310     if ((res = mp_grow (c, max)) != MP_OKAY) {
311       return res;
312     }
313   }
314   olduse = c->used;
315   c->used = max;
316 
317   {
318     register mp_digit u, *tmpa, *tmpb, *tmpc;
319     register int i;
320 
321     /* alias for digit pointers */
322     tmpa = a->dp;
323     tmpb = b->dp;
324     tmpc = c->dp;
325 
326     /* set carry to zero */
327     u = 0;
328     for (i = 0; i < min; i++) {
329       /* T[i] = A[i] - B[i] - U */
330       *tmpc = *tmpa++ - *tmpb++ - u;
331 
332       /* U = carry bit of T[i]
333        * Note this saves performing an AND operation since
334        * if a carry does occur it will propagate all the way to the
335        * MSB.  As a result a single shift is enough to get the carry
336        */
337       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
338 
339       /* Clear carry from T[i] */
340       *tmpc++ &= MP_MASK;
341     }
342 
343     /* now copy higher words if any, e.g. if A has more digits than B  */
344     for (; i < max; i++) {
345       /* T[i] = A[i] - U */
346       *tmpc = *tmpa++ - u;
347 
348       /* U = carry bit of T[i] */
349       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
350 
351       /* Clear carry from T[i] */
352       *tmpc++ &= MP_MASK;
353     }
354 
355     /* clear digits above used (since we may not have grown result above) */
356     for (i = c->used; i < olduse; i++) {
357       *tmpc++ = 0;
358     }
359   }
360 
361   mp_clamp (c);
362   return MP_OKAY;
363 }
364 
365 
366 /* init a new mp_int */
mp_init(mp_int * a)367 static int mp_init (mp_int * a)
368 {
369   int i;
370 
371   /* allocate memory required and clear it */
372   a->dp = OPT_CAST(mp_digit) XMALLOC (sizeof (mp_digit) * MP_PREC);
373   if (a->dp == NULL) {
374     return MP_MEM;
375   }
376 
377   /* set the digits to zero */
378   for (i = 0; i < MP_PREC; i++) {
379       a->dp[i] = 0;
380   }
381 
382   /* set the used to zero, allocated digits to the default precision
383    * and sign to positive */
384   a->used  = 0;
385   a->alloc = MP_PREC;
386   a->sign  = MP_ZPOS;
387 
388   return MP_OKAY;
389 }
390 
391 
392 /* clear one (frees)  */
mp_clear(mp_int * a)393 static void mp_clear (mp_int * a)
394 {
395   int i;
396 
397   /* only do anything if a hasn't been freed previously */
398   if (a->dp != NULL) {
399     /* first zero the digits */
400     for (i = 0; i < a->used; i++) {
401         a->dp[i] = 0;
402     }
403 
404     /* free ram */
405     XFREE(a->dp);
406 
407     /* reset members to make debugging easier */
408     a->dp    = NULL;
409     a->alloc = a->used = 0;
410     a->sign  = MP_ZPOS;
411   }
412 }
413 
414 
415 /* high level addition (handles signs) */
mp_add(mp_int * a,mp_int * b,mp_int * c)416 static int mp_add (mp_int * a, mp_int * b, mp_int * c)
417 {
418   int     sa, sb, res;
419 
420   /* get sign of both inputs */
421   sa = a->sign;
422   sb = b->sign;
423 
424   /* handle two cases, not four */
425   if (sa == sb) {
426     /* both positive or both negative */
427     /* add their magnitudes, copy the sign */
428     c->sign = sa;
429     res = s_mp_add (a, b, c);
430   } else {
431     /* one positive, the other negative */
432     /* subtract the one with the greater magnitude from */
433     /* the one of the lesser magnitude.  The result gets */
434     /* the sign of the one with the greater magnitude. */
435     if (mp_cmp_mag (a, b) == MP_LT) {
436       c->sign = sb;
437       res = s_mp_sub (b, a, c);
438     } else {
439       c->sign = sa;
440       res = s_mp_sub (a, b, c);
441     }
442   }
443   return res;
444 }
445 
446 
447 /* high level subtraction (handles signs) */
mp_sub(mp_int * a,mp_int * b,mp_int * c)448 static int mp_sub (mp_int * a, mp_int * b, mp_int * c)
449 {
450   int     sa, sb, res;
451 
452   sa = a->sign;
453   sb = b->sign;
454 
455   if (sa != sb) {
456     /* subtract a negative from a positive, OR */
457     /* subtract a positive from a negative. */
458     /* In either case, ADD their magnitudes, */
459     /* and use the sign of the first number. */
460     c->sign = sa;
461     res = s_mp_add (a, b, c);
462   } else {
463     /* subtract a positive from a positive, OR */
464     /* subtract a negative from a negative. */
465     /* First, take the difference between their */
466     /* magnitudes, then... */
467     if (mp_cmp_mag (a, b) != MP_LT) {
468       /* Copy the sign from the first */
469       c->sign = sa;
470       /* The first has a larger or equal magnitude */
471       res = s_mp_sub (a, b, c);
472     } else {
473       /* The result has the *opposite* sign from */
474       /* the first number. */
475       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
476       /* The second has a larger magnitude */
477       res = s_mp_sub (b, a, c);
478     }
479   }
480   return res;
481 }
482 
483 
/* high level multiplication (handles sign): c = a * b.
 * Returns the status of the selected multiplier. A zero product is always
 * given sign MP_ZPOS. */
static int mp_mul (mp_int * a, mp_int * b, mp_int * c)
{
  int     res, neg;
  /* The result sign is computed before multiplying; presumably because c may
   * alias a or b and the multiply overwrites it. */
  neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;

  /* Algorithm selection. The if/else chain is split across #ifdef blocks:
   * each enabled branch ends in a dangling "else" that attaches to the next
   * enabled branch. Neither BN_MP_TOOM_MUL_C nor BN_MP_KARATSUBA_MUL_C is
   * defined in this file's configuration block. */

  /* use Toom-Cook? */
#ifdef BN_MP_TOOM_MUL_C
  if (MIN (a->used, b->used) >= TOOM_MUL_CUTOFF) {
    res = mp_toom_mul(a, b, c);
  } else
#endif
#ifdef BN_MP_KARATSUBA_MUL_C
  /* use Karatsuba? */
  if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
    res = mp_karatsuba_mul (a, b, c);
  } else
#endif
  {
    /* can we use the fast multiplier?
     *
     * The fast multiplier can be used if the output will
     * have less than MP_WARRAY digits and the number of
     * digits won't affect carry propagation
     */
#ifdef BN_FAST_S_MP_MUL_DIGS_C
    int     digs = a->used + b->used + 1;

    if ((digs < MP_WARRAY) &&
        MIN(a->used, b->used) <=
        (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
      res = fast_s_mp_mul_digs (a, b, c, digs);
    } else
#endif
#ifdef BN_S_MP_MUL_DIGS_C
      res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
#else
#error mp_mul could fail
      res = MP_VAL;
#endif

  }
  /* Never leave a negative zero. */
  c->sign = (c->used > 0) ? neg : MP_ZPOS;
  return res;
}
529 
530 
531 /* d = a * b (mod c) */
mp_mulmod(mp_int * a,mp_int * b,mp_int * c,mp_int * d)532 static int mp_mulmod (mp_int * a, mp_int * b, mp_int * c, mp_int * d)
533 {
534   int     res;
535   mp_int  t;
536 
537   if ((res = mp_init (&t)) != MP_OKAY) {
538     return res;
539   }
540 
541   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
542     mp_clear (&t);
543     return res;
544   }
545   res = mp_mod (&t, c, d);
546   mp_clear (&t);
547   return res;
548 }
549 
550 
551 /* c = a mod b, 0 <= c < b */
mp_mod(mp_int * a,mp_int * b,mp_int * c)552 static int mp_mod (mp_int * a, mp_int * b, mp_int * c)
553 {
554   mp_int  t;
555   int     res;
556 
557   if ((res = mp_init (&t)) != MP_OKAY) {
558     return res;
559   }
560 
561   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
562     mp_clear (&t);
563     return res;
564   }
565 
566   if (t.sign != b->sign) {
567     res = mp_add (b, &t, c);
568   } else {
569     res = MP_OKAY;
570     mp_exch (&t, c);
571   }
572 
573   mp_clear (&t);
574   return res;
575 }
576 
577 
/* this is a shell function that calls either the normal or Montgomery
 * exptmod functions.  Originally the call to the montgomery code was
 * embedded in the normal function but that wasted a lot of stack space
 * for nothing (since 99% of the time the Montgomery code would be called)
 *
 * Computes Y = G**X mod P.
 * Returns MP_VAL if P is negative. With LTM_NO_NEG_EXP defined (as it is
 * near the top of this file), a negative X also gives MP_VAL. Otherwise the
 * result is the status of the selected exptmod routine.
 */
static int mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y)
{
  /* reduction mode passed to the fast path (0 = none detected) */
  int dr;

  /* modulus P must be positive */
  if (P->sign == MP_NEG) {
     return MP_VAL;
  }

  /* if exponent X is negative we have to recurse */
  if (X->sign == MP_NEG) {
#ifdef LTM_NO_NEG_EXP
        return MP_VAL;
#else /* LTM_NO_NEG_EXP */
#ifdef BN_MP_INVMOD_C
     mp_int tmpG, tmpX;
     int err;

     /* first compute 1/G mod P */
     if ((err = mp_init(&tmpG)) != MP_OKAY) {
        return err;
     }
     if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
        mp_clear(&tmpG);
        return err;
     }

     /* now get |X| */
     if ((err = mp_init(&tmpX)) != MP_OKAY) {
        mp_clear(&tmpG);
        return err;
     }
     if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
        mp_clear_multi(&tmpG, &tmpX, NULL);
        return err;
     }

     /* and now compute (1/G)**|X| instead of G**X [X < 0] */
     err = mp_exptmod(&tmpG, &tmpX, P, Y);
     mp_clear_multi(&tmpG, &tmpX, NULL);
     return err;
#else
#error mp_exptmod would always fail
     /* no invmod */
     return MP_VAL;
#endif
#endif /* LTM_NO_NEG_EXP */
  }

/* modified diminished radix reduction */
#if defined(BN_MP_REDUCE_IS_2K_L_C) && defined(BN_MP_REDUCE_2K_L_C) && defined(BN_S_MP_EXPTMOD_C)
  if (mp_reduce_is_2k_l(P) == MP_YES) {
     return s_mp_exptmod(G, X, P, Y, 1);
  }
#endif

#ifdef BN_MP_DR_IS_MODULUS_C
  /* is it a DR modulus? */
  dr = mp_dr_is_modulus(P);
#else
  /* default to no */
  dr = 0;
#endif

#ifdef BN_MP_REDUCE_IS_2K_C
  /* if not, is it a unrestricted DR modulus? */
  if (dr == 0) {
     dr = mp_reduce_is_2k(P) << 1;
  }
#endif

  /* if the modulus is odd or dr != 0 use the montgomery method
   * (the fast path exists only in LTM_FAST builds) */
#ifdef BN_MP_EXPTMOD_FAST_C
  if (mp_isodd (P) == 1 || dr !=  0) {
    return mp_exptmod_fast (G, X, P, Y, dr);
  } else {
#endif
#ifdef BN_S_MP_EXPTMOD_C
    /* otherwise use the generic Barrett reduction technique */
    return s_mp_exptmod (G, X, P, Y, 0);
#else
#error mp_exptmod could fail
    /* no exptmod for evens */
    return MP_VAL;
#endif
#ifdef BN_MP_EXPTMOD_FAST_C
  }
#endif
  /* Unreachable: every enabled path above returns. This statement exists only
   * so that 'dr' counts as used in configurations that never read it. */
  if (dr == 0) {
    /* avoid compiler warnings about possibly unused variable */
  }
}
675 
676 
677 /* compare two ints (signed)*/
mp_cmp(mp_int * a,mp_int * b)678 static int mp_cmp (mp_int * a, mp_int * b)
679 {
680   /* compare based on sign */
681   if (a->sign != b->sign) {
682      if (a->sign == MP_NEG) {
683         return MP_LT;
684      } else {
685         return MP_GT;
686      }
687   }
688 
689   /* compare digits */
690   if (a->sign == MP_NEG) {
691      /* if negative compare opposite direction */
692      return mp_cmp_mag(b, a);
693   } else {
694      return mp_cmp_mag(a, b);
695   }
696 }
697 
698 
699 /* compare a digit */
mp_cmp_d(mp_int * a,mp_digit b)700 static int mp_cmp_d(mp_int * a, mp_digit b)
701 {
702   /* compare based on sign */
703   if (a->sign == MP_NEG) {
704     return MP_LT;
705   }
706 
707   /* compare based on magnitude */
708   if (a->used > 1) {
709     return MP_GT;
710   }
711 
712   /* compare the only digit of a to b */
713   if (a->dp[0] > b) {
714     return MP_GT;
715   } else if (a->dp[0] < b) {
716     return MP_LT;
717   } else {
718     return MP_EQ;
719   }
720 }
721 
722 
#ifndef LTM_NO_NEG_EXP
/* hac 14.61, pp608
 *
 * Modular inverse: c = a**-1 mod b. Returns MP_VAL when b is negative or
 * zero, or when no enabled routine can compute the inverse. Compiled only
 * when LTM_NO_NEG_EXP is not defined.
 */
static int mp_invmod (mp_int * a, mp_int * b, mp_int * c)
{
  /* b cannot be negative */
  if (b->sign == MP_NEG || mp_iszero(b) == 1) {
    return MP_VAL;
  }

#ifdef BN_FAST_MP_INVMOD_C
  /* if the modulus is odd we can use a faster routine instead */
  if (mp_isodd (b) == 1) {
    return fast_mp_invmod (a, b, c);
  }
#endif

#ifdef BN_MP_INVMOD_SLOW_C
  return mp_invmod_slow(a, b, c);
#endif

  /* Build-time guard: refuse to compile if no inverse routine is enabled. */
#ifndef BN_FAST_MP_INVMOD_C
#ifndef BN_MP_INVMOD_SLOW_C
#error mp_invmod would always fail
#endif
#endif
  return MP_VAL;
}
#endif /* LTM_NO_NEG_EXP */
751 
752 
753 /* get the size for an unsigned equivalent */
mp_unsigned_bin_size(mp_int * a)754 static int mp_unsigned_bin_size (mp_int * a)
755 {
756   int     size = mp_count_bits (a);
757   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
758 }
759 
760 
#ifndef LTM_NO_NEG_EXP
/* hac 14.61, pp608
 *
 * Binary extended-GCD modular inverse: c = a**-1 mod b, for any b > 0.
 * Returns MP_VAL if b is not positive, if a mod b and b are both even, or if
 * gcd(a, b) != 1 (no inverse exists). Otherwise returns MP_OKAY or the first
 * arithmetic error. All eight temporaries are released on every path through
 * LBL_ERR. Compiled only when LTM_NO_NEG_EXP is not defined, and it is
 * defined near the top of this file.
 */
static int mp_invmod_slow (mp_int * a, mp_int * b, mp_int * c)
{
  mp_int  x, y, u, v, A, B, C, D;
  int     res;

  /* b cannot be negative */
  if (b->sign == MP_NEG || mp_iszero(b) == 1) {
    return MP_VAL;
  }

  /* init temps */
  if ((res = mp_init_multi(&x, &y, &u, &v,
                           &A, &B, &C, &D, NULL)) != MP_OKAY) {
     return res;
  }

  /* x = a, y = b */
  if ((res = mp_mod(a, b, &x)) != MP_OKAY) {
      goto LBL_ERR;
  }
  if ((res = mp_copy (b, &y)) != MP_OKAY) {
    goto LBL_ERR;
  }

  /* 2. [modified] if x,y are both even then return an error!
   * (gcd would be at least 2, so no inverse exists) */
  if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
    res = MP_VAL;
    goto LBL_ERR;
  }

  /* 3. u=x, v=y, A=1, B=0, C=0,D=1 (B and C are already 0 from init) */
  if ((res = mp_copy (&x, &u)) != MP_OKAY) {
    goto LBL_ERR;
  }
  if ((res = mp_copy (&y, &v)) != MP_OKAY) {
    goto LBL_ERR;
  }
  mp_set (&A, 1);
  mp_set (&D, 1);

top:
  /* 4.  while u is even do */
  while (mp_iseven (&u) == 1) {
    /* 4.1 u = u/2 */
    if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
      goto LBL_ERR;
    }
    /* 4.2 if A or B is odd then */
    if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
      /* A = (A+y)/2, B = (B-x)/2 */
      if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
         goto LBL_ERR;
      }
      if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
         goto LBL_ERR;
      }
    }
    /* A = A/2, B = B/2 */
    if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
      goto LBL_ERR;
    }
    if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
      goto LBL_ERR;
    }
  }

  /* 5.  while v is even do */
  while (mp_iseven (&v) == 1) {
    /* 5.1 v = v/2 */
    if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
      goto LBL_ERR;
    }
    /* 5.2 if C or D is odd then */
    if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
      /* C = (C+y)/2, D = (D-x)/2 */
      if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
         goto LBL_ERR;
      }
      if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
         goto LBL_ERR;
      }
    }
    /* C = C/2, D = D/2 */
    if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
      goto LBL_ERR;
    }
    if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
      goto LBL_ERR;
    }
  }

  /* 6.  if u >= v then */
  if (mp_cmp (&u, &v) != MP_LT) {
    /* u = u - v, A = A - C, B = B - D */
    if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
      goto LBL_ERR;
    }

    if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
      goto LBL_ERR;
    }

    if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
      goto LBL_ERR;
    }
  } else {
    /* v = v - u, C = C - A, D = D - B */
    if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
      goto LBL_ERR;
    }

    if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
      goto LBL_ERR;
    }

    if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
      goto LBL_ERR;
    }
  }

  /* if not zero goto step 4 */
  if (mp_iszero (&u) == 0)
    goto top;

  /* now a = C, b = D, gcd == g*v
   * (HAC notation: C is the candidate inverse of x modulo y) */

  /* if v != 1 then there is no inverse */
  if (mp_cmp_d (&v, 1) != MP_EQ) {
    res = MP_VAL;
    goto LBL_ERR;
  }

  /* if its too low: bring C up into [0, b) */
  while (mp_cmp_d(&C, 0) == MP_LT) {
      if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
         goto LBL_ERR;
      }
  }

  /* too big */
  while (mp_cmp_mag(&C, b) != MP_LT) {
      if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
         goto LBL_ERR;
      }
  }

  /* C is now the inverse */
  mp_exch (&C, c);
  res = MP_OKAY;
  /* Shared exit: success and error paths both release all temporaries. */
LBL_ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
  return res;
}
#endif /* LTM_NO_NEG_EXP */
916 
917 
918 /* compare maginitude of two ints (unsigned) */
mp_cmp_mag(mp_int * a,mp_int * b)919 static int mp_cmp_mag (mp_int * a, mp_int * b)
920 {
921   int     n;
922   mp_digit *tmpa, *tmpb;
923 
924   /* compare based on # of non-zero digits */
925   if (a->used > b->used) {
926     return MP_GT;
927   }
928 
929   if (a->used < b->used) {
930     return MP_LT;
931   }
932 
933   /* alias for a */
934   tmpa = a->dp + (a->used - 1);
935 
936   /* alias for b */
937   tmpb = b->dp + (a->used - 1);
938 
939   /* compare based on digits  */
940   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
941     if (*tmpa > *tmpb) {
942       return MP_GT;
943     }
944 
945     if (*tmpa < *tmpb) {
946       return MP_LT;
947     }
948   }
949   return MP_EQ;
950 }
951 
952 
953 /* reads a unsigned char array, assumes the msb is stored first [big endian] */
mp_read_unsigned_bin(mp_int * a,const unsigned char * b,int c)954 static int mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
955 {
956   int     res;
957 
958   /* make sure there are at least two digits */
959   if (a->alloc < 2) {
960      if ((res = mp_grow(a, 2)) != MP_OKAY) {
961         return res;
962      }
963   }
964 
965   /* zero the int */
966   mp_zero (a);
967 
968   /* read the bytes in */
969   while (c-- > 0) {
970     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
971       return res;
972     }
973 
974 #ifndef MP_8BIT
975       a->dp[0] |= *b++;
976       a->used += 1;
977 #else
978       a->dp[0] = (*b & MP_MASK);
979       a->dp[1] |= ((*b++ >> 7U) & 1);
980       a->used += 2;
981 #endif
982   }
983   mp_clamp (a);
984   return MP_OKAY;
985 }
986 
987 
988 /* store in unsigned [big endian] format */
mp_to_unsigned_bin(mp_int * a,unsigned char * b)989 static int mp_to_unsigned_bin (mp_int * a, unsigned char *b)
990 {
991   int     x, res;
992   mp_int  t;
993 
994   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
995     return res;
996   }
997 
998   x = 0;
999   while (mp_iszero (&t) == 0) {
1000 #ifndef MP_8BIT
1001       b[x++] = (unsigned char) (t.dp[0] & 255);
1002 #else
1003       b[x++] = (unsigned char) (t.dp[0] | ((t.dp[1] & 0x01) << 7));
1004 #endif
1005     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
1006       mp_clear (&t);
1007       return res;
1008     }
1009   }
1010   bn_reverse (b, x);
1011   mp_clear (&t);
1012   return MP_OKAY;
1013 }
1014 
1015 
1016 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
mp_div_2d(mp_int * a,int b,mp_int * c,mp_int * d)1017 static int mp_div_2d (mp_int * a, int b, mp_int * c, mp_int * d)
1018 {
1019   mp_digit D, r, rr;
1020   int     x, res;
1021   mp_int  t;
1022 
1023 
1024   /* if the shift count is <= 0 then we do no work */
1025   if (b <= 0) {
1026     res = mp_copy (a, c);
1027     if (d != NULL) {
1028       mp_zero (d);
1029     }
1030     return res;
1031   }
1032 
1033   if ((res = mp_init (&t)) != MP_OKAY) {
1034     return res;
1035   }
1036 
1037   /* get the remainder */
1038   if (d != NULL) {
1039     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1040       mp_clear (&t);
1041       return res;
1042     }
1043   }
1044 
1045   /* copy */
1046   if ((res = mp_copy (a, c)) != MP_OKAY) {
1047     mp_clear (&t);
1048     return res;
1049   }
1050 
1051   /* shift by as many digits in the bit count */
1052   if (b >= (int)DIGIT_BIT) {
1053     mp_rshd (c, b / DIGIT_BIT);
1054   }
1055 
1056   /* shift any bit count < DIGIT_BIT */
1057   D = (mp_digit) (b % DIGIT_BIT);
1058   if (D != 0) {
1059     register mp_digit *tmpc, mask, shift;
1060 
1061     /* mask */
1062     mask = (((mp_digit)1) << D) - 1;
1063 
1064     /* shift for lsb */
1065     shift = DIGIT_BIT - D;
1066 
1067     /* alias */
1068     tmpc = c->dp + (c->used - 1);
1069 
1070     /* carry */
1071     r = 0;
1072     for (x = c->used - 1; x >= 0; x--) {
1073       /* get the lower  bits of this word in a temp */
1074       rr = *tmpc & mask;
1075 
1076       /* shift the current word and mix in the carry bits from the previous word */
1077       *tmpc = (*tmpc >> D) | (r << shift);
1078       --tmpc;
1079 
1080       /* set the carry to the carry bits of the current word found above */
1081       r = rr;
1082     }
1083   }
1084   mp_clamp (c);
1085   if (d != NULL) {
1086     mp_exch (&t, d);
1087   }
1088   mp_clear (&t);
1089   return MP_OKAY;
1090 }
1091 
1092 
/* Initialize a and set it to a copy of b (a = b).
 * Returns MP_OKAY, or the error from mp_init() or mp_copy(). On any failure
 * a is left with dp == NULL, so the caller has nothing to release. Calling
 * mp_clear() on it anyway is harmless. */
static int mp_init_copy (mp_int * a, mp_int * b)
{
  int     res;

  if ((res = mp_init (a)) != MP_OKAY) {
    return res;
  }

  /* mp_init() has already allocated a's digits. If the copy fails (e.g.
   * mp_grow() runs out of memory), free them here. Callers such as
   * mp_to_unsigned_bin() return immediately on failure without clearing. */
  if ((res = mp_copy (b, a)) != MP_OKAY) {
    mp_clear (a);
  }
  return res;
}
1102 
1103 
1104 /* set to zero */
mp_zero(mp_int * a)1105 static void mp_zero (mp_int * a)
1106 {
1107   int       n;
1108   mp_digit *tmp;
1109 
1110   a->sign = MP_ZPOS;
1111   a->used = 0;
1112 
1113   tmp = a->dp;
1114   for (n = 0; n < a->alloc; n++) {
1115      *tmp++ = 0;
1116   }
1117 }
1118 
1119 
1120 /* copy, b = a */
mp_copy(mp_int * a,mp_int * b)1121 static int mp_copy (mp_int * a, mp_int * b)
1122 {
1123   int     res, n;
1124 
1125   /* if dst == src do nothing */
1126   if (a == b) {
1127     return MP_OKAY;
1128   }
1129 
1130   /* grow dest */
1131   if (b->alloc < a->used) {
1132      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1133         return res;
1134      }
1135   }
1136 
1137   /* zero b and copy the parameters over */
1138   {
1139     register mp_digit *tmpa, *tmpb;
1140 
1141     /* pointer aliases */
1142 
1143     /* source */
1144     tmpa = a->dp;
1145 
1146     /* destination */
1147     tmpb = b->dp;
1148 
1149     /* copy all the digits */
1150     for (n = 0; n < a->used; n++) {
1151       *tmpb++ = *tmpa++;
1152     }
1153 
1154     /* clear high digits */
1155     for (; n < b->used; n++) {
1156       *tmpb++ = 0;
1157     }
1158   }
1159 
1160   /* copy used count and sign */
1161   b->used = a->used;
1162   b->sign = a->sign;
1163   return MP_OKAY;
1164 }
1165 
1166 
1167 /* shift right a certain amount of digits */
mp_rshd(mp_int * a,int b)1168 static void mp_rshd (mp_int * a, int b)
1169 {
1170   int     x;
1171 
1172   /* if b <= 0 then ignore it */
1173   if (b <= 0) {
1174     return;
1175   }
1176 
1177   /* if b > used then simply zero it and return */
1178   if (a->used <= b) {
1179     mp_zero (a);
1180     return;
1181   }
1182 
1183   {
1184     register mp_digit *bottom, *top;
1185 
1186     /* shift the digits down */
1187 
1188     /* bottom */
1189     bottom = a->dp;
1190 
1191     /* top [offset into digits] */
1192     top = a->dp + b;
1193 
1194     /* this is implemented as a sliding window where
1195      * the window is b-digits long and digits from
1196      * the top of the window are copied to the bottom
1197      *
1198      * e.g.
1199 
1200      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
1201                  /\                   |      ---->
1202                   \-------------------/      ---->
1203      */
1204     for (x = 0; x < (a->used - b); x++) {
1205       *bottom++ = *top++;
1206     }
1207 
1208     /* zero the top digits */
1209     for (; x < a->used; x++) {
1210       *bottom++ = 0;
1211     }
1212   }
1213 
1214   /* remove excess digits */
1215   a->used -= b;
1216 }
1217 
1218 
1219 /* swap the elements of two integers, for cases where you can't simply swap the
1220  * mp_int pointers around
1221  */
mp_exch(mp_int * a,mp_int * b)1222 static void mp_exch (mp_int * a, mp_int * b)
1223 {
1224   mp_int  t;
1225 
1226   t  = *a;
1227   *a = *b;
1228   *b = t;
1229 }
1230 
1231 
1232 /* trim unused digits
1233  *
1234  * This is used to ensure that leading zero digits are
1235  * trimed and the leading "used" digit will be non-zero
1236  * Typically very fast.  Also fixes the sign if there
1237  * are no more leading digits
1238  */
mp_clamp(mp_int * a)1239 static void mp_clamp (mp_int * a)
1240 {
1241   /* decrease used while the most significant digit is
1242    * zero.
1243    */
1244   while (a->used > 0 && a->dp[a->used - 1] == 0) {
1245     --(a->used);
1246   }
1247 
1248   /* reset the sign flag if used == 0 */
1249   if (a->used == 0) {
1250     a->sign = MP_ZPOS;
1251   }
1252 }
1253 
1254 
1255 /* grow as required */
mp_grow(mp_int * a,int size)1256 static int mp_grow (mp_int * a, int size)
1257 {
1258   int     i;
1259   mp_digit *tmp;
1260 
1261   /* if the alloc size is smaller alloc more ram */
1262   if (a->alloc < size) {
1263     /* ensure there are always at least MP_PREC digits extra on top */
1264     size += (MP_PREC * 2) - (size % MP_PREC);
1265 
1266     /* reallocate the array a->dp
1267      *
1268      * We store the return in a temporary variable
1269      * in case the operation failed we don't want
1270      * to overwrite the dp member of a.
1271      */
1272     tmp = OPT_CAST(mp_digit) XREALLOC (a->dp, sizeof (mp_digit) * size);
1273     if (tmp == NULL) {
1274       /* reallocation failed but "a" is still valid [can be freed] */
1275       return MP_MEM;
1276     }
1277 
1278     /* reallocation succeeded so set a->dp */
1279     a->dp = tmp;
1280 
1281     /* zero excess digits */
1282     i        = a->alloc;
1283     a->alloc = size;
1284     for (; i < a->alloc; i++) {
1285       a->dp[i] = 0;
1286     }
1287   }
1288   return MP_OKAY;
1289 }
1290 
1291 
#ifdef BN_MP_ABS_C
/* b = |a|
 *
 * Copies a into b (unless they are the same object) and then forces b's
 * sign to positive.  Returns MP_OKAY or the error from mp_copy().
 */
static int mp_abs (mp_int * a, mp_int * b)
{
  int err;

  if (a != b && (err = mp_copy (a, b)) != MP_OKAY) {
    return err;
  }

  b->sign = MP_ZPOS;
  return MP_OKAY;
}
#endif
1314 
1315 
1316 /* set to a digit */
mp_set(mp_int * a,mp_digit b)1317 static void mp_set (mp_int * a, mp_digit b)
1318 {
1319   mp_zero (a);
1320   a->dp[0] = b & MP_MASK;
1321   a->used  = (a->dp[0] != 0) ? 1 : 0;
1322 }
1323 
1324 
1325 #ifndef LTM_NO_NEG_EXP
1326 /* b = a/2 */
mp_div_2(mp_int * a,mp_int * b)1327 static int mp_div_2(mp_int * a, mp_int * b)
1328 {
1329   int     x, res, oldused;
1330 
1331   /* copy */
1332   if (b->alloc < a->used) {
1333     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1334       return res;
1335     }
1336   }
1337 
1338   oldused = b->used;
1339   b->used = a->used;
1340   {
1341     register mp_digit r, rr, *tmpa, *tmpb;
1342 
1343     /* source alias */
1344     tmpa = a->dp + b->used - 1;
1345 
1346     /* dest alias */
1347     tmpb = b->dp + b->used - 1;
1348 
1349     /* carry */
1350     r = 0;
1351     for (x = b->used - 1; x >= 0; x--) {
1352       /* get the carry for the next iteration */
1353       rr = *tmpa & 1;
1354 
1355       /* shift the current digit, add in carry and store */
1356       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
1357 
1358       /* forward carry to next iteration */
1359       r = rr;
1360     }
1361 
1362     /* zero excess digits */
1363     tmpb = b->dp + b->used;
1364     for (x = b->used; x < oldused; x++) {
1365       *tmpb++ = 0;
1366     }
1367   }
1368   b->sign = a->sign;
1369   mp_clamp (b);
1370   return MP_OKAY;
1371 }
1372 #endif /* LTM_NO_NEG_EXP */
1373 
1374 
1375 /* shift left by a certain bit count */
mp_mul_2d(mp_int * a,int b,mp_int * c)1376 static int mp_mul_2d (mp_int * a, int b, mp_int * c)
1377 {
1378   mp_digit d;
1379   int      res;
1380 
1381   /* copy */
1382   if (a != c) {
1383      if ((res = mp_copy (a, c)) != MP_OKAY) {
1384        return res;
1385      }
1386   }
1387 
1388   if (c->alloc < (int)(c->used + b/DIGIT_BIT + 1)) {
1389      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1390        return res;
1391      }
1392   }
1393 
1394   /* shift by as many digits in the bit count */
1395   if (b >= (int)DIGIT_BIT) {
1396     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1397       return res;
1398     }
1399   }
1400 
1401   /* shift any bit count < DIGIT_BIT */
1402   d = (mp_digit) (b % DIGIT_BIT);
1403   if (d != 0) {
1404     register mp_digit *tmpc, shift, mask, r, rr;
1405     register int x;
1406 
1407     /* bitmask for carries */
1408     mask = (((mp_digit)1) << d) - 1;
1409 
1410     /* shift for msbs */
1411     shift = DIGIT_BIT - d;
1412 
1413     /* alias */
1414     tmpc = c->dp;
1415 
1416     /* carry */
1417     r    = 0;
1418     for (x = 0; x < c->used; x++) {
1419       /* get the higher bits of the current word */
1420       rr = (*tmpc >> shift) & mask;
1421 
1422       /* shift the current word and OR in the carry */
1423       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1424       ++tmpc;
1425 
1426       /* set the carry to the carry bits of the current word */
1427       r = rr;
1428     }
1429 
1430     /* set final carry */
1431     if (r != 0) {
1432        c->dp[(c->used)++] = r;
1433     }
1434   }
1435   mp_clamp (c);
1436   return MP_OKAY;
1437 }
1438 
1439 
#ifdef BN_MP_INIT_MULTI_C
/* Initialize a NULL-terminated list of mp_ints: mp_init_multi(&a, &b, NULL).
 *
 * All or nothing: if any mp_init() fails, the ones already initialized are
 * cleared again and MP_MEM is returned.  Otherwise returns MP_OKAY.
 */
static int mp_init_multi(mp_int *mp, ...)
{
    mp_int *cur;
    int     done = 0;   /* how many entries have been initialized so far */
    va_list ap;

    va_start(ap, mp);
    for (cur = mp; cur != NULL; cur = va_arg(ap, mp_int*), done++) {
        if (mp_init(cur) != MP_OKAY) {
            va_list undo;

            va_end(ap);

            /* rewind the argument list and release what we created */
            va_start(undo, mp);
            for (cur = mp; done > 0; done--) {
                mp_clear(cur);
                cur = va_arg(undo, mp_int*);
            }
            va_end(undo);
            return MP_MEM;
        }
    }
    va_end(ap);
    return MP_OKAY;
}
#endif
1476 
1477 
#ifdef BN_MP_CLEAR_MULTI_C
/* Clear every mp_int in a NULL-terminated argument list. */
static void mp_clear_multi(mp_int *mp, ...)
{
    va_list ap;
    mp_int *cur;

    va_start(ap, mp);
    for (cur = mp; cur != NULL; cur = va_arg(ap, mp_int*)) {
        mp_clear(cur);
    }
    va_end(ap);
}
#endif
1491 
1492 
1493 /* shift left a certain amount of digits */
mp_lshd(mp_int * a,int b)1494 static int mp_lshd (mp_int * a, int b)
1495 {
1496   int     x, res;
1497 
1498   /* if its less than zero return */
1499   if (b <= 0) {
1500     return MP_OKAY;
1501   }
1502 
1503   /* grow to fit the new digits */
1504   if (a->alloc < a->used + b) {
1505      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1506        return res;
1507      }
1508   }
1509 
1510   {
1511     register mp_digit *top, *bottom;
1512 
1513     /* increment the used by the shift amount then copy upwards */
1514     a->used += b;
1515 
1516     /* top */
1517     top = a->dp + a->used - 1;
1518 
1519     /* base */
1520     bottom = a->dp + a->used - 1 - b;
1521 
1522     /* much like mp_rshd this is implemented using a sliding window
1523      * except the window goes the otherway around.  Copying from
1524      * the bottom to the top.  see bn_mp_rshd.c for more info.
1525      */
1526     for (x = a->used - 1; x >= b; x--) {
1527       *top-- = *bottom--;
1528     }
1529 
1530     /* zero the lower digits */
1531     top = a->dp;
1532     for (x = 0; x < b; x++) {
1533       *top++ = 0;
1534     }
1535   }
1536   return MP_OKAY;
1537 }
1538 
1539 
1540 /* returns the number of bits in an int */
mp_count_bits(mp_int * a)1541 static int mp_count_bits (mp_int * a)
1542 {
1543   int     r;
1544   mp_digit q;
1545 
1546   /* shortcut */
1547   if (a->used == 0) {
1548     return 0;
1549   }
1550 
1551   /* get number of digits and add that */
1552   r = (a->used - 1) * DIGIT_BIT;
1553 
1554   /* take the last digit and count the bits in it */
1555   q = a->dp[a->used - 1];
1556   while (q > ((mp_digit) 0)) {
1557     ++r;
1558     q >>= ((mp_digit) 1);
1559   }
1560   return r;
1561 }
1562 
1563 
1564 /* calc a value mod 2**b */
mp_mod_2d(mp_int * a,int b,mp_int * c)1565 static int mp_mod_2d (mp_int * a, int b, mp_int * c)
1566 {
1567   int     x, res;
1568 
1569   /* if b is <= 0 then zero the int */
1570   if (b <= 0) {
1571     mp_zero (c);
1572     return MP_OKAY;
1573   }
1574 
1575   /* if the modulus is larger than the value than return */
1576   if (b >= (int) (a->used * DIGIT_BIT)) {
1577     res = mp_copy (a, c);
1578     return res;
1579   }
1580 
1581   /* copy */
1582   if ((res = mp_copy (a, c)) != MP_OKAY) {
1583     return res;
1584   }
1585 
1586   /* zero digits above the last digit of the modulus */
1587   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1588     c->dp[x] = 0;
1589   }
1590   /* clear the digit that is not completely outside/inside the modulus */
1591   c->dp[b / DIGIT_BIT] &=
1592     (mp_digit) ((((mp_digit) 1) << (((mp_digit) b) % DIGIT_BIT)) - ((mp_digit) 1));
1593   mp_clamp (c);
1594   return MP_OKAY;
1595 }
1596 
1597 
1598 #ifdef BN_MP_DIV_SMALL
1599 
1600 /* slower bit-bang division... also smaller */
mp_div(mp_int * a,mp_int * b,mp_int * c,mp_int * d)1601 static int mp_div(mp_int * a, mp_int * b, mp_int * c, mp_int * d)
1602 {
1603    mp_int ta, tb, tq, q;
1604    int    res, n, n2;
1605 
1606   /* is divisor zero ? */
1607   if (mp_iszero (b) == 1) {
1608     return MP_VAL;
1609   }
1610 
1611   /* if a < b then q=0, r = a */
1612   if (mp_cmp_mag (a, b) == MP_LT) {
1613     if (d != NULL) {
1614       res = mp_copy (a, d);
1615     } else {
1616       res = MP_OKAY;
1617     }
1618     if (c != NULL) {
1619       mp_zero (c);
1620     }
1621     return res;
1622   }
1623 
1624   /* init our temps */
1625   if ((res = mp_init_multi(&ta, &tb, &tq, &q, NULL)) != MP_OKAY) {
1626      return res;
1627   }
1628 
1629 
1630   mp_set(&tq, 1);
1631   n = mp_count_bits(a) - mp_count_bits(b);
1632   if (((res = mp_abs(a, &ta)) != MP_OKAY) ||
1633       ((res = mp_abs(b, &tb)) != MP_OKAY) ||
1634       ((res = mp_mul_2d(&tb, n, &tb)) != MP_OKAY) ||
1635       ((res = mp_mul_2d(&tq, n, &tq)) != MP_OKAY)) {
1636       goto LBL_ERR;
1637   }
1638 
1639   while (n-- >= 0) {
1640      if (mp_cmp(&tb, &ta) != MP_GT) {
1641         if (((res = mp_sub(&ta, &tb, &ta)) != MP_OKAY) ||
1642             ((res = mp_add(&q, &tq, &q)) != MP_OKAY)) {
1643            goto LBL_ERR;
1644         }
1645      }
1646      if (((res = mp_div_2d(&tb, 1, &tb, NULL)) != MP_OKAY) ||
1647          ((res = mp_div_2d(&tq, 1, &tq, NULL)) != MP_OKAY)) {
1648            goto LBL_ERR;
1649      }
1650   }
1651 
1652   /* now q == quotient and ta == remainder */
1653   n  = a->sign;
1654   n2 = (a->sign == b->sign ? MP_ZPOS : MP_NEG);
1655   if (c != NULL) {
1656      mp_exch(c, &q);
1657      c->sign  = (mp_iszero(c) == MP_YES) ? MP_ZPOS : n2;
1658   }
1659   if (d != NULL) {
1660      mp_exch(d, &ta);
1661      d->sign = (mp_iszero(d) == MP_YES) ? MP_ZPOS : n;
1662   }
1663 LBL_ERR:
1664    mp_clear_multi(&ta, &tb, &tq, &q, NULL);
1665    return res;
1666 }
1667 
1668 #else
1669 
1670 /* integer signed division.
1671  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1672  * HAC pp.598 Algorithm 14.20
1673  *
1674  * Note that the description in HAC is horribly
1675  * incomplete.  For example, it doesn't consider
1676  * the case where digits are removed from 'x' in
1677  * the inner loop.  It also doesn't consider the
1678  * case that y has fewer than three digits, etc..
1679  *
1680  * The overall algorithm is as described as
1681  * 14.20 from HAC but fixed to treat these cases.
1682 */
mp_div(mp_int * a,mp_int * b,mp_int * c,mp_int * d)1683 static int mp_div (mp_int * a, mp_int * b, mp_int * c, mp_int * d)
1684 {
1685   mp_int  q, x, y, t1, t2;
1686   int     res, n, t, i, norm, neg;
1687 
1688   /* is divisor zero ? */
1689   if (mp_iszero (b) == 1) {
1690     return MP_VAL;
1691   }
1692 
1693   /* if a < b then q=0, r = a */
1694   if (mp_cmp_mag (a, b) == MP_LT) {
1695     if (d != NULL) {
1696       res = mp_copy (a, d);
1697     } else {
1698       res = MP_OKAY;
1699     }
1700     if (c != NULL) {
1701       mp_zero (c);
1702     }
1703     return res;
1704   }
1705 
1706   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1707     return res;
1708   }
1709   q.used = a->used + 2;
1710 
1711   if ((res = mp_init (&t1)) != MP_OKAY) {
1712     goto LBL_Q;
1713   }
1714 
1715   if ((res = mp_init (&t2)) != MP_OKAY) {
1716     goto LBL_T1;
1717   }
1718 
1719   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1720     goto LBL_T2;
1721   }
1722 
1723   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1724     goto LBL_X;
1725   }
1726 
1727   /* fix the sign */
1728   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1729   x.sign = y.sign = MP_ZPOS;
1730 
1731   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1732   norm = mp_count_bits(&y) % DIGIT_BIT;
1733   if (norm < (int)(DIGIT_BIT-1)) {
1734      norm = (DIGIT_BIT-1) - norm;
1735      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1736        goto LBL_Y;
1737      }
1738      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1739        goto LBL_Y;
1740      }
1741   } else {
1742      norm = 0;
1743   }
1744 
1745   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1746   n = x.used - 1;
1747   t = y.used - 1;
1748 
1749   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1750   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1751     goto LBL_Y;
1752   }
1753 
1754   while (mp_cmp (&x, &y) != MP_LT) {
1755     ++(q.dp[n - t]);
1756     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1757       goto LBL_Y;
1758     }
1759   }
1760 
1761   /* reset y by shifting it back down */
1762   mp_rshd (&y, n - t);
1763 
1764   /* step 3. for i from n down to (t + 1) */
1765   for (i = n; i >= (t + 1); i--) {
1766     if (i > x.used) {
1767       continue;
1768     }
1769 
1770     /* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1771      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1772     if (x.dp[i] == y.dp[t]) {
1773       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1774     } else {
1775       mp_word tmp;
1776       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1777       tmp |= ((mp_word) x.dp[i - 1]);
1778       tmp /= ((mp_word) y.dp[t]);
1779       if (tmp > (mp_word) MP_MASK)
1780         tmp = MP_MASK;
1781       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1782     }
1783 
1784     /* while (q{i-t-1} * (yt * b + y{t-1})) >
1785              xi * b**2 + xi-1 * b + xi-2
1786 
1787        do q{i-t-1} -= 1;
1788     */
1789     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1790     do {
1791       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1792 
1793       /* find left hand */
1794       mp_zero (&t1);
1795       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1796       t1.dp[1] = y.dp[t];
1797       t1.used = 2;
1798       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1799         goto LBL_Y;
1800       }
1801 
1802       /* find right hand */
1803       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1804       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1805       t2.dp[2] = x.dp[i];
1806       t2.used = 3;
1807     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1808 
1809     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1810     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1811       goto LBL_Y;
1812     }
1813 
1814     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1815       goto LBL_Y;
1816     }
1817 
1818     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1819       goto LBL_Y;
1820     }
1821 
1822     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1823     if (x.sign == MP_NEG) {
1824       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1825         goto LBL_Y;
1826       }
1827       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1828         goto LBL_Y;
1829       }
1830       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1831         goto LBL_Y;
1832       }
1833 
1834       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1835     }
1836   }
1837 
1838   /* now q is the quotient and x is the remainder
1839    * [which we have to normalize]
1840    */
1841 
1842   /* get sign before writing to c */
1843   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1844 
1845   if (c != NULL) {
1846     mp_clamp (&q);
1847     mp_exch (&q, c);
1848     c->sign = neg;
1849   }
1850 
1851   if (d != NULL) {
1852     mp_div_2d (&x, norm, &x, NULL);
1853     mp_exch (&x, d);
1854   }
1855 
1856   res = MP_OKAY;
1857 
1858 LBL_Y:mp_clear (&y);
1859 LBL_X:mp_clear (&x);
1860 LBL_T2:mp_clear (&t2);
1861 LBL_T1:mp_clear (&t1);
1862 LBL_Q:mp_clear (&q);
1863   return res;
1864 }
1865 
1866 #endif
1867 
1868 
1869 #ifdef MP_LOW_MEM
1870    #define TAB_SIZE 32
1871 #else
1872    #define TAB_SIZE 256
1873 #endif
1874 
s_mp_exptmod(mp_int * G,mp_int * X,mp_int * P,mp_int * Y,int redmode)1875 static int s_mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
1876 {
1877   mp_int  M[TAB_SIZE], res, mu;
1878   mp_digit buf;
1879   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1880   int (*redux)(mp_int*,mp_int*,mp_int*);
1881 
1882   /* find window size */
1883   x = mp_count_bits (X);
1884   if (x <= 7) {
1885     winsize = 2;
1886   } else if (x <= 36) {
1887     winsize = 3;
1888   } else if (x <= 140) {
1889     winsize = 4;
1890   } else if (x <= 450) {
1891     winsize = 5;
1892   } else if (x <= 1303) {
1893     winsize = 6;
1894   } else if (x <= 3529) {
1895     winsize = 7;
1896   } else {
1897     winsize = 8;
1898   }
1899 
1900 #ifdef MP_LOW_MEM
1901     if (winsize > 5) {
1902        winsize = 5;
1903     }
1904 #endif
1905 
1906   /* init M array */
1907   /* init first cell */
1908   if ((err = mp_init(&M[1])) != MP_OKAY) {
1909      return err;
1910   }
1911 
1912   /* now init the second half of the array */
1913   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1914     if ((err = mp_init(&M[x])) != MP_OKAY) {
1915       for (y = 1<<(winsize-1); y < x; y++) {
1916         mp_clear (&M[y]);
1917       }
1918       mp_clear(&M[1]);
1919       return err;
1920     }
1921   }
1922 
1923   /* create mu, used for Barrett reduction */
1924   if ((err = mp_init (&mu)) != MP_OKAY) {
1925     goto LBL_M;
1926   }
1927 
1928   if (redmode == 0) {
1929      if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
1930         goto LBL_MU;
1931      }
1932      redux = mp_reduce;
1933   } else {
1934      if ((err = mp_reduce_2k_setup_l (P, &mu)) != MP_OKAY) {
1935         goto LBL_MU;
1936      }
1937      redux = mp_reduce_2k_l;
1938   }
1939 
1940   /* create M table
1941    *
1942    * The M table contains powers of the base,
1943    * e.g. M[x] = G**x mod P
1944    *
1945    * The first half of the table is not
1946    * computed though accept for M[0] and M[1]
1947    */
1948   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
1949     goto LBL_MU;
1950   }
1951 
1952   /* compute the value at M[1<<(winsize-1)] by squaring
1953    * M[1] (winsize-1) times
1954    */
1955   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
1956     goto LBL_MU;
1957   }
1958 
1959   for (x = 0; x < (winsize - 1); x++) {
1960     /* square it */
1961     if ((err = mp_sqr (&M[1 << (winsize - 1)],
1962                        &M[1 << (winsize - 1)])) != MP_OKAY) {
1963       goto LBL_MU;
1964     }
1965 
1966     /* reduce modulo P */
1967     if ((err = redux (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
1968       goto LBL_MU;
1969     }
1970   }
1971 
1972   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
1973    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
1974    */
1975   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
1976     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
1977       goto LBL_MU;
1978     }
1979     if ((err = redux (&M[x], P, &mu)) != MP_OKAY) {
1980       goto LBL_MU;
1981     }
1982   }
1983 
1984   /* setup result */
1985   if ((err = mp_init (&res)) != MP_OKAY) {
1986     goto LBL_MU;
1987   }
1988   mp_set (&res, 1);
1989 
1990   /* set initial mode and bit cnt */
1991   mode   = 0;
1992   bitcnt = 1;
1993   buf    = 0;
1994   digidx = X->used - 1;
1995   bitcpy = 0;
1996   bitbuf = 0;
1997 
1998   for (;;) {
1999     /* grab next digit as required */
2000     if (--bitcnt == 0) {
2001       /* if digidx == -1 we are out of digits */
2002       if (digidx == -1) {
2003         break;
2004       }
2005       /* read next digit and reset the bitcnt */
2006       buf    = X->dp[digidx--];
2007       bitcnt = (int) DIGIT_BIT;
2008     }
2009 
2010     /* grab the next msb from the exponent */
2011     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
2012     buf <<= (mp_digit)1;
2013 
2014     /* if the bit is zero and mode == 0 then we ignore it
2015      * These represent the leading zero bits before the first 1 bit
2016      * in the exponent.  Technically this opt is not required but it
2017      * does lower the # of trivial squaring/reductions used
2018      */
2019     if (mode == 0 && y == 0) {
2020       continue;
2021     }
2022 
2023     /* if the bit is zero and mode == 1 then we square */
2024     if (mode == 1 && y == 0) {
2025       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2026         goto LBL_RES;
2027       }
2028       if ((err = redux (&res, P, &mu)) != MP_OKAY) {
2029         goto LBL_RES;
2030       }
2031       continue;
2032     }
2033 
2034     /* else we add it to the window */
2035     bitbuf |= (y << (winsize - ++bitcpy));
2036     mode    = 2;
2037 
2038     if (bitcpy == winsize) {
2039       /* ok window is filled so square as required and multiply  */
2040       /* square first */
2041       for (x = 0; x < winsize; x++) {
2042         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2043           goto LBL_RES;
2044         }
2045         if ((err = redux (&res, P, &mu)) != MP_OKAY) {
2046           goto LBL_RES;
2047         }
2048       }
2049 
2050       /* then multiply */
2051       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2052         goto LBL_RES;
2053       }
2054       if ((err = redux (&res, P, &mu)) != MP_OKAY) {
2055         goto LBL_RES;
2056       }
2057 
2058       /* empty window and reset */
2059       bitcpy = 0;
2060       bitbuf = 0;
2061       mode   = 1;
2062     }
2063   }
2064 
2065   /* if bits remain then square/multiply */
2066   if (mode == 2 && bitcpy > 0) {
2067     /* square then multiply if the bit is set */
2068     for (x = 0; x < bitcpy; x++) {
2069       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2070         goto LBL_RES;
2071       }
2072       if ((err = redux (&res, P, &mu)) != MP_OKAY) {
2073         goto LBL_RES;
2074       }
2075 
2076       bitbuf <<= 1;
2077       if ((bitbuf & (1 << winsize)) != 0) {
2078         /* then multiply */
2079         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2080           goto LBL_RES;
2081         }
2082         if ((err = redux (&res, P, &mu)) != MP_OKAY) {
2083           goto LBL_RES;
2084         }
2085       }
2086     }
2087   }
2088 
2089   mp_exch (&res, Y);
2090   err = MP_OKAY;
2091 LBL_RES:mp_clear (&res);
2092 LBL_MU:mp_clear (&mu);
2093 LBL_M:
2094   mp_clear(&M[1]);
2095   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2096     mp_clear (&M[x]);
2097   }
2098   return err;
2099 }
2100 
2101 
2102 /* computes b = a*a */
mp_sqr(mp_int * a,mp_int * b)2103 static int mp_sqr (mp_int * a, mp_int * b)
2104 {
2105   int     res;
2106 
2107 #ifdef BN_MP_TOOM_SQR_C
2108   /* use Toom-Cook? */
2109   if (a->used >= TOOM_SQR_CUTOFF) {
2110     res = mp_toom_sqr(a, b);
2111   /* Karatsuba? */
2112   } else
2113 #endif
2114 #ifdef BN_MP_KARATSUBA_SQR_C
2115 if (a->used >= KARATSUBA_SQR_CUTOFF) {
2116     res = mp_karatsuba_sqr (a, b);
2117   } else
2118 #endif
2119   {
2120 #ifdef BN_FAST_S_MP_SQR_C
2121     /* can we use the fast comba multiplier? */
2122     if ((a->used * 2 + 1) < MP_WARRAY &&
2123          a->used <
2124          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
2125       res = fast_s_mp_sqr (a, b);
2126     } else
2127 #endif
2128 #ifdef BN_S_MP_SQR_C
2129       res = s_mp_sqr (a, b);
2130 #else
2131 #error mp_sqr could fail
2132       res = MP_VAL;
2133 #endif
2134   }
2135   b->sign = MP_ZPOS;
2136   return res;
2137 }
2138 
2139 
2140 /* reduces a modulo n where n is of the form 2**p - d
2141    This differs from reduce_2k since "d" can be larger
2142    than a single digit.
2143 */
mp_reduce_2k_l(mp_int * a,mp_int * n,mp_int * d)2144 static int mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d)
2145 {
2146    mp_int q;
2147    int    p, res;
2148 
2149    if ((res = mp_init(&q)) != MP_OKAY) {
2150       return res;
2151    }
2152 
2153    p = mp_count_bits(n);
2154 top:
2155    /* q = a/2**p, a = a mod 2**p */
2156    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
2157       goto ERR;
2158    }
2159 
2160    /* q = q * d */
2161    if ((res = mp_mul(&q, d, &q)) != MP_OKAY) {
2162       goto ERR;
2163    }
2164 
2165    /* a = a + q */
2166    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
2167       goto ERR;
2168    }
2169 
2170    if (mp_cmp_mag(a, n) != MP_LT) {
2171       s_mp_sub(a, n, a);
2172       goto top;
2173    }
2174 
2175 ERR:
2176    mp_clear(&q);
2177    return res;
2178 }
2179 
2180 
2181 /* determines the setup value */
mp_reduce_2k_setup_l(mp_int * a,mp_int * d)2182 static int mp_reduce_2k_setup_l(mp_int *a, mp_int *d)
2183 {
2184    int    res;
2185    mp_int tmp;
2186 
2187    if ((res = mp_init(&tmp)) != MP_OKAY) {
2188       return res;
2189    }
2190 
2191    if ((res = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
2192       goto ERR;
2193    }
2194 
2195    if ((res = s_mp_sub(&tmp, a, d)) != MP_OKAY) {
2196       goto ERR;
2197    }
2198 
2199 ERR:
2200    mp_clear(&tmp);
2201    return res;
2202 }
2203 
2204 
2205 /* computes a = 2**b
2206  *
2207  * Simple algorithm which zeroes the int, grows it then just sets one bit
2208  * as required.
2209  */
mp_2expt(mp_int * a,int b)2210 static int mp_2expt (mp_int * a, int b)
2211 {
2212   int     res;
2213 
2214   /* zero a as per default */
2215   mp_zero (a);
2216 
2217   /* grow a to accommodate the single bit */
2218   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
2219     return res;
2220   }
2221 
2222   /* set the used count of where the bit will go */
2223   a->used = b / DIGIT_BIT + 1;
2224 
2225   /* put the single bit in its place */
2226   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
2227 
2228   return MP_OKAY;
2229 }
2230 
2231 
2232 /* pre-calculate the value required for Barrett reduction
2233  * For a given modulus "b" it calulates the value required in "a"
2234  */
mp_reduce_setup(mp_int * a,mp_int * b)2235 static int mp_reduce_setup (mp_int * a, mp_int * b)
2236 {
2237   int     res;
2238 
2239   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
2240     return res;
2241   }
2242   return mp_div (a, b, a, NULL);
2243 }
2244 
2245 
2246 /* reduces x mod m, assumes 0 < x < m**2, mu is
2247  * precomputed via mp_reduce_setup.
2248  * From HAC pp.604 Algorithm 14.42
2249  */
mp_reduce(mp_int * x,mp_int * m,mp_int * mu)2250 static int mp_reduce (mp_int * x, mp_int * m, mp_int * mu)
2251 {
2252   mp_int  q;
2253   int     res, um = m->used;
2254 
2255   /* q = x */
2256   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
2257     return res;
2258   }
2259 
2260   /* q1 = x / b**(k-1)  */
2261   mp_rshd (&q, um - 1);
2262 
2263   /* according to HAC this optimization is ok */
2264   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
2265     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
2266       goto CLEANUP;
2267     }
2268   } else {
2269 #ifdef BN_S_MP_MUL_HIGH_DIGS_C
2270     if ((res = s_mp_mul_high_digs (&q, mu, &q, um)) != MP_OKAY) {
2271       goto CLEANUP;
2272     }
2273 #elif defined(BN_FAST_S_MP_MUL_HIGH_DIGS_C)
2274     if ((res = fast_s_mp_mul_high_digs (&q, mu, &q, um)) != MP_OKAY) {
2275       goto CLEANUP;
2276     }
2277 #else
2278     {
2279 #error mp_reduce would always fail
2280       res = MP_VAL;
2281       goto CLEANUP;
2282     }
2283 #endif
2284   }
2285 
2286   /* q3 = q2 / b**(k+1) */
2287   mp_rshd (&q, um + 1);
2288 
2289   /* x = x mod b**(k+1), quick (no division) */
2290   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
2291     goto CLEANUP;
2292   }
2293 
2294   /* q = q * m mod b**(k+1), quick (no division) */
2295   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
2296     goto CLEANUP;
2297   }
2298 
2299   /* x = x - q */
2300   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
2301     goto CLEANUP;
2302   }
2303 
2304   /* If x < 0, add b**(k+1) to it */
2305   if (mp_cmp_d (x, 0) == MP_LT) {
2306     mp_set (&q, 1);
2307     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY) {
2308       goto CLEANUP;
2309     }
2310     if ((res = mp_add (x, &q, x)) != MP_OKAY) {
2311       goto CLEANUP;
2312     }
2313   }
2314 
2315   /* Back off if it's too big */
2316   while (mp_cmp (x, m) != MP_LT) {
2317     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
2318       goto CLEANUP;
2319     }
2320   }
2321 
2322 CLEANUP:
2323   mp_clear (&q);
2324 
2325   return res;
2326 }
2327 
2328 
2329 /* multiplies |a| * |b| and only computes up to digs digits of result
2330  * HAC pp. 595, Algorithm 14.12  Modified so you can control how
2331  * many digits of output are created.
2332  */
s_mp_mul_digs(mp_int * a,mp_int * b,mp_int * c,int digs)2333 static int s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
2334 {
2335   mp_int  t;
2336   int     res, pa, pb, ix, iy;
2337   mp_digit u;
2338   mp_word r;
2339   mp_digit tmpx, *tmpt, *tmpy;
2340 
2341 #ifdef BN_FAST_S_MP_MUL_DIGS_C
2342   /* can we use the fast multiplier? */
2343   if (((digs) < MP_WARRAY) &&
2344       MIN (a->used, b->used) <
2345           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2346     return fast_s_mp_mul_digs (a, b, c, digs);
2347   }
2348 #endif
2349 
2350   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
2351     return res;
2352   }
2353   t.used = digs;
2354 
2355   /* compute the digits of the product directly */
2356   pa = a->used;
2357   for (ix = 0; ix < pa; ix++) {
2358     /* set the carry to zero */
2359     u = 0;
2360 
2361     /* limit ourselves to making digs digits of output */
2362     pb = MIN (b->used, digs - ix);
2363 
2364     /* setup some aliases */
2365     /* copy of the digit from a used within the nested loop */
2366     tmpx = a->dp[ix];
2367 
2368     /* an alias for the destination shifted ix places */
2369     tmpt = t.dp + ix;
2370 
2371     /* an alias for the digits of b */
2372     tmpy = b->dp;
2373 
2374     /* compute the columns of the output and propagate the carry */
2375     for (iy = 0; iy < pb; iy++) {
2376       /* compute the column as a mp_word */
2377       r       = ((mp_word)*tmpt) +
2378                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
2379                 ((mp_word) u);
2380 
2381       /* the new column is the lower part of the result */
2382       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2383 
2384       /* get the carry word from the result */
2385       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2386     }
2387     /* set carry if it is placed below digs */
2388     if (ix + iy < digs) {
2389       *tmpt = u;
2390     }
2391   }
2392 
2393   mp_clamp (&t);
2394   mp_exch (&t, c);
2395 
2396   mp_clear (&t);
2397   return MP_OKAY;
2398 }
2399 
2400 
#ifdef BN_FAST_S_MP_MUL_DIGS_C
/* Fast (comba) multiplier
 *
 * Computes c = |a| * |b| truncated to its low `digs` digits.  This is the
 * column-array [comba] multiplier: every output column is summed in a
 * double-precision accumulator first and the digit/carry are split off
 * afterwards, which keeps the nested loops simple and easy to schedule.
 * Limiting the output to `digs` digits lets callers skip the upper part of
 * the product when they do not need it.
 *
 * Based on Algorithm 14.12 on pp.595 of HAC.
 *
 * W[] has MP_WARRAY digits and holds min(digs, a->used + b->used) columns.
 * s_mp_mul_digs only calls this routine when digs < MP_WARRAY and
 * MIN(a->used, b->used) is small enough that a column sum fits an mp_word.
 * All columns are built in W[] before c is written, so c may alias a or b.
 *
 * Returns MP_OKAY, or the error from mp_grow.
 */
static int fast_s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
{
  int     olduse, res, pa, ix, iz;
  mp_digit W[MP_WARRAY];
  register mp_word  _W;

  /* grow the destination as required */
  if (c->alloc < digs) {
    if ((res = mp_grow (c, digs)) != MP_OKAY) {
      return res;
    }
  }

  /* number of output digits to produce */
  pa = MIN(digs, a->used + b->used);

  /* clear the carry; W[] is also zeroed defensively */
  _W = 0;
  os_memset(W, 0, sizeof(W));
  for (ix = 0; ix < pa; ix++) {
      int      tx, ty;
      int      iy;
      mp_digit *tmpx, *tmpy;

      /* get offsets into the two bignums */
      ty = MIN(b->used-1, ix);
      tx = ix - ty;

      /* setup temp aliases */
      tmpx = a->dp + tx;
      tmpy = b->dp + ty;

      /* this is the number of times the loop will iterate, essentially
         while (tx++ < a->used && ty-- >= 0) { ... }
       */
      iy = MIN(a->used-tx, ty+1);

      /* execute loop: sum column ix */
      for (iz = 0; iz < iy; ++iz) {
         _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
      }

      /* store term */
      W[ix] = ((mp_digit)_W) & MP_MASK;

      /* make next carry */
      _W = _W >> ((mp_word)DIGIT_BIT);
  }

  /* setup dest */
  olduse  = c->used;
  c->used = pa;

  {
    register mp_digit *tmpc;
    tmpc = c->dp;

    /* Copy exactly pa digits.  The code above only guarantees
     * c->alloc >= digs >= pa, so copying pa + 1 digits could write one
     * digit past the end of c->dp when pa == c->alloc; W[pa] is never
     * computed in any case. */
    for (ix = 0; ix < pa; ix++) {
      *tmpc++ = W[ix];
    }

    /* clear unused digits [that existed in the old copy of c] */
    for (; ix < olduse; ix++) {
      *tmpc++ = 0;
    }
  }
  mp_clamp (c);
  return MP_OKAY;
}
#endif /* BN_FAST_S_MP_MUL_DIGS_C */
2489 
2490 
2491 /* init an mp_init for a given size */
mp_init_size(mp_int * a,int size)2492 static int mp_init_size (mp_int * a, int size)
2493 {
2494   int x;
2495 
2496   /* pad size so there are always extra digits */
2497   size += (MP_PREC * 2) - (size % MP_PREC);
2498 
2499   /* alloc mem */
2500   a->dp = OPT_CAST(mp_digit) XMALLOC (sizeof (mp_digit) * size);
2501   if (a->dp == NULL) {
2502     return MP_MEM;
2503   }
2504 
2505   /* set the members */
2506   a->used  = 0;
2507   a->alloc = size;
2508   a->sign  = MP_ZPOS;
2509 
2510   /* zero the digits */
2511   for (x = 0; x < size; x++) {
2512       a->dp[x] = 0;
2513   }
2514 
2515   return MP_OKAY;
2516 }
2517 
2518 
2519 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
s_mp_sqr(mp_int * a,mp_int * b)2520 static int s_mp_sqr (mp_int * a, mp_int * b)
2521 {
2522   mp_int  t;
2523   int     res, ix, iy, pa;
2524   mp_word r;
2525   mp_digit u, tmpx, *tmpt;
2526 
2527   pa = a->used;
2528   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
2529     return res;
2530   }
2531 
2532   /* default used is maximum possible size */
2533   t.used = 2*pa + 1;
2534 
2535   for (ix = 0; ix < pa; ix++) {
2536     /* first calculate the digit at 2*ix */
2537     /* calculate double precision result */
2538     r = ((mp_word) t.dp[2*ix]) +
2539         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
2540 
2541     /* store lower part in result */
2542     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
2543 
2544     /* get the carry */
2545     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2546 
2547     /* left hand side of A[ix] * A[iy] */
2548     tmpx        = a->dp[ix];
2549 
2550     /* alias for where to store the results */
2551     tmpt        = t.dp + (2*ix + 1);
2552 
2553     for (iy = ix + 1; iy < pa; iy++) {
2554       /* first calculate the product */
2555       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
2556 
2557       /* now calculate the double precision result, note we use
2558        * addition instead of *2 since it's easier to optimize
2559        */
2560       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
2561 
2562       /* store lower part */
2563       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2564 
2565       /* get carry */
2566       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2567     }
2568     /* propagate upwards */
2569     while (u != ((mp_digit) 0)) {
2570       r       = ((mp_word) *tmpt) + ((mp_word) u);
2571       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2572       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2573     }
2574   }
2575 
2576   mp_clamp (&t);
2577   mp_exch (&t, b);
2578   mp_clear (&t);
2579   return MP_OKAY;
2580 }
2581 
2582 
2583 /* multiplies |a| * |b| and does not compute the lower digs digits
2584  * [meant to get the higher part of the product]
2585  */
s_mp_mul_high_digs(mp_int * a,mp_int * b,mp_int * c,int digs)2586 static int s_mp_mul_high_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
2587 {
2588   mp_int  t;
2589   int     res, pa, pb, ix, iy;
2590   mp_digit u;
2591   mp_word r;
2592   mp_digit tmpx, *tmpt, *tmpy;
2593 
2594   /* can we use the fast multiplier? */
2595 #ifdef BN_FAST_S_MP_MUL_HIGH_DIGS_C
2596   if (((a->used + b->used + 1) < MP_WARRAY)
2597       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2598     return fast_s_mp_mul_high_digs (a, b, c, digs);
2599   }
2600 #endif
2601 
2602   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
2603     return res;
2604   }
2605   t.used = a->used + b->used + 1;
2606 
2607   pa = a->used;
2608   pb = b->used;
2609   for (ix = 0; ix < pa; ix++) {
2610     /* clear the carry */
2611     u = 0;
2612 
2613     /* left hand side of A[ix] * B[iy] */
2614     tmpx = a->dp[ix];
2615 
2616     /* alias to the address of where the digits will be stored */
2617     tmpt = &(t.dp[digs]);
2618 
2619     /* alias for where to read the right hand side from */
2620     tmpy = b->dp + (digs - ix);
2621 
2622     for (iy = digs - ix; iy < pb; iy++) {
2623       /* calculate the double precision result */
2624       r       = ((mp_word)*tmpt) +
2625                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
2626                 ((mp_word) u);
2627 
2628       /* get the lower part */
2629       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2630 
2631       /* carry the carry */
2632       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2633     }
2634     *tmpt = u;
2635   }
2636   mp_clamp (&t);
2637   mp_exch (&t, c);
2638   mp_clear (&t);
2639   return MP_OKAY;
2640 }
2641 
2642 
#ifdef BN_MP_MONTGOMERY_SETUP_C
/* Compute the Montgomery constant rho = -1/n[0] mod b, with b = 2**DIGIT_BIT.
 *
 * The inverse of the (odd) low digit is found by Newton iteration modulo a
 * power of two:
 *
 *   XA == 1 (mod 2**k)  =>  (X(2 - XA)) A == 1 (mod 2**2k)
 *
 * so every step doubles the number of correct low bits.  The number of
 * steps depends on the configured digit size.
 *
 * Returns MP_VAL when n is even (no inverse exists), MP_OKAY otherwise.
 */
static int
mp_montgomery_setup (mp_int * n, mp_digit * rho)
{
  mp_digit d0 = n->dp[0];
  mp_digit inv;

  if (!(d0 & 1)) {
    return MP_VAL;
  }

  inv  = (((d0 + 2) & 4) << 1) + d0;   /* inv*d0 == 1 mod 2**4  */
  inv *= 2 - d0 * inv;                 /* inv*d0 == 1 mod 2**8  */
#if !defined(MP_8BIT)
  inv *= 2 - d0 * inv;                 /* inv*d0 == 1 mod 2**16 */
#endif
#if defined(MP_64BIT) || !(defined(MP_8BIT) || defined(MP_16BIT))
  inv *= 2 - d0 * inv;                 /* inv*d0 == 1 mod 2**32 */
#endif
#ifdef MP_64BIT
  inv *= 2 - d0 * inv;                 /* inv*d0 == 1 mod 2**64 */
#endif

  /* rho = b - inv, i.e. -1/n[0] mod b */
  *rho = (unsigned long)(((mp_word)1 << ((mp_word) DIGIT_BIT)) - inv) & MP_MASK;
  return MP_OKAY;
}
#endif
2682 
2683 
#ifdef BN_FAST_MP_MONTGOMERY_REDUCE_C
/* computes xR**-1 == x (mod N) via Montgomery Reduction
 *
 * Comba-style implementation of Montgomery reduction.  The digits of x are
 * unpacked into the double-precision array W[].  Each of the n->used low
 * columns is then cleared by adding mu * n shifted to that column, fixing
 * one carry per step so the next mu can be computed from a single digit.
 * The remaining carries are propagated, the upper n->used + 1 words
 * (A / b**n) are copied back into x, and a final conditional subtraction
 * brings the result below n.
 *
 * Based on Algorithm 14.32 on pp.601 of HAC.
 *
 * rho must be -1/n mod b, as produced by mp_montgomery_setup.
 * W[] has MP_WARRAY words.  Indices up to 2 * n->used are used, plus up to
 * x->used - 1 while unpacking.  mp_exptmod_fast only selects this routine
 * when 2 * P->used + 1 < MP_WARRAY.
 * NOTE(review): x->used <= 2 * n->used + 1 (x < n**2) is assumed by callers
 * and not checked here.
 *
 * Returns MP_OKAY, or the error from mp_grow / s_mp_sub.
 */
static int fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
{
  int     ix, res, olduse;
  mp_word W[MP_WARRAY];

  /* get old used count */
  olduse = x->used;

  /* grow x as required: the result has n->used + 1 digits */
  if (x->alloc < n->used + 1) {
    if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
      return res;
    }
  }

  /* first we have to get the digits of the input into
   * an array of double precision words W[...]
   */
  {
    register mp_word *_W;
    register mp_digit *tmpx;

    /* alias for the W[] array */
    _W   = W;

    /* alias for the digits of x */
    tmpx = x->dp;

    /* copy the digits of x into W[0..x->used-1] */
    for (ix = 0; ix < x->used; ix++) {
      *_W++ = *tmpx++;
    }

    /* zero the high words W[x->used..n->used*2] */
    for (; ix < n->used * 2 + 1; ix++) {
      *_W++ = 0;
    }
  }

  /* now we proceed to zero successive digits
   * from the least significant upwards
   */
  for (ix = 0; ix < n->used; ix++) {
    /* mu = ai * m' mod b
     *
     * We avoid a double precision multiplication (which isn't required)
     * by casting the value down to a mp_digit.  Note this requires
     * that W[ix-1] have the carry cleared (see after the inner loop)
     */
    register mp_digit mu;
    mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);

    /* a = a + mu * m * b**i
     *
     * This is computed in place and on the fly.  The multiplication
     * by b**i is handled by offsetting which columns the results
     * are added to.
     *
     * Note the comba method normally doesn't handle carries in the
     * inner loop.  In this case we fix the carry from the previous
     * column since the Montgomery reduction requires digits of the
     * result (so far) [see above] to work.  This is
     * handled by fixing up one carry after the inner loop.  The
     * carry fixups are done in order so after these loops the
     * first n->used words of W[] have the carries fixed
     */
    {
      register int iy;
      register mp_digit *tmpn;
      register mp_word *_W;

      /* alias for the digits of the modulus */
      tmpn = n->dp;

      /* Alias for the columns set by an offset of ix */
      _W = W + ix;

      /* inner loop */
      for (iy = 0; iy < n->used; iy++) {
          *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
      }
    }

    /* now fix carry for next digit, W[ix+1] */
    W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
  }

  /* now we have to propagate the carries and
   * shift the words downward [all those least
   * significant digits we zeroed].
   */
  {
    register mp_digit *tmpx;
    register mp_word *_W, *_W1;

    /* now fix the rest of the carries.  Stop at W[n->used*2]: only
     * W[0..n->used*2] are initialised above (when x->used <= n->used*2+1)
     * and only W[n->used..n->used*2] are copied out below, so touching
     * W[n->used*2+1] would read an uninitialised word for no benefit. */

    /* alias for current word */
    _W1 = W + ix;

    /* alias for next word, where the carry goes */
    _W = W + ++ix;

    for (; ix < n->used * 2 + 1; ix++) {
      *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
    }

    /* copy out, A = A/b**n
     *
     * The result is A/b**n but instead of converting from an
     * array of mp_word to mp_digit then calling mp_rshd
     * we just copy them in the right order
     */

    /* alias for destination word */
    tmpx = x->dp;

    /* alias for shifted double precision result */
    _W = W + n->used;

    for (ix = 0; ix < n->used + 1; ix++) {
      *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
    }

    /* zero olduse digits: if the input x was larger than
     * n->used+1 digits we have to clear the rest
     */
    for (; ix < olduse; ix++) {
      *tmpx++ = 0;
    }
  }

  /* set the max used and clamp */
  x->used = n->used + 1;
  mp_clamp (x);

  /* if A >= m then A = A - m */
  if (mp_cmp_mag (x, n) != MP_LT) {
    return s_mp_sub (x, n, x);
  }
  return MP_OKAY;
}
#endif
2836 
2837 
#ifdef BN_MP_MUL_2_C
/* b = a * 2, computed as a one-bit left shift across the digits.
 *
 * Works in place (a == b): each source digit is read before the same slot
 * is written.  b receives a's sign.  Digits of b above the new length that
 * were in use before are cleared.
 *
 * Returns MP_OKAY, or the error from mp_grow.
 */
static int mp_mul_2(mp_int * a, mp_int * b)
{
  int      i, err, prev_used;
  mp_digit carry, next_carry, *src, *dst;

  /* the result needs at most one more digit than a */
  if (b->alloc < a->used + 1) {
    err = mp_grow (b, a->used + 1);
    if (err != MP_OKAY) {
      return err;
    }
  }

  prev_used = b->used;
  b->used   = a->used;

  src   = a->dp;
  dst   = b->dp;
  carry = 0;
  for (i = 0; i < a->used; i++) {
    /* the bit shifted out of this digit becomes the next digit's low bit */
    next_carry = *src >> ((mp_digit)(DIGIT_BIT - 1));
    *dst++     = ((*src++ << ((mp_digit)1)) | carry) & MP_MASK;
    carry      = next_carry;
  }

  /* a bit fell off the top: it becomes a new leading digit equal to 1 */
  if (carry != 0) {
    *dst = 1;
    ++(b->used);
  }

  /* clear stale digits left over from b's previous value */
  for (i = b->used, dst = b->dp + b->used; i < prev_used; i++) {
    *dst++ = 0;
  }

  b->sign = a->sign;
  return MP_OKAY;
}
#endif
2900 
2901 
#ifdef BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
/*
 * Compute a = R mod b, with R = 2**(b->used * DIGIT_BIT).
 *
 * Instead of reducing 2**k from scratch, a starts at the power of two just
 * below b's leading bit.  It is then doubled (subtracting b whenever a >= b)
 * until the exponent reaches b->used * DIGIT_BIT, which saves most of the
 * multiple precision shifting.
 *
 * Returns MP_OKAY, or the error from mp_2expt / mp_mul_2 / s_mp_sub.
 */
static int mp_montgomery_calc_normalization (mp_int * a, mp_int * b)
{
  int     i, top_bits, err;

  /* bits used in b's most significant digit (0 when it uses all of them) */
  top_bits = mp_count_bits (b) % DIGIT_BIT;

  if (b->used <= 1) {
     mp_set (a, 1);
     top_bits = 1;
  } else {
     err = mp_2expt (a, (b->used - 1) * DIGIT_BIT + top_bits - 1);
     if (err != MP_OKAY) {
        return err;
     }
  }

  /* double a modulo b for the remaining exponent bits */
  for (i = top_bits - 1; i < (int)DIGIT_BIT; i++) {
    err = mp_mul_2 (a, a);
    if (err != MP_OKAY) {
      return err;
    }
    if (mp_cmp_mag (a, b) != MP_LT) {
      err = s_mp_sub (a, b, a);
      if (err != MP_OKAY) {
        return err;
      }
    }
  }

  return MP_OKAY;
}
#endif
2941 
2942 
#ifdef BN_MP_EXPTMOD_FAST_C
/* Compute Y = G**X mod P (HAC pp.616, Algorithm 14.85).
 *
 * Left-to-right k-ary sliding-window exponentiation:
 *   - leading zero bits of X are skipped;
 *   - a zero bit between windows costs one squaring;
 *   - a full window of k bits costs k squarings plus one multiplication by
 *     a precomputed table entry.
 * The window width k grows with the bit length of X (capped at 5 under
 * MP_LOW_MEM).
 *
 * Only M[1] and the upper half of the table, M[2**(k-1)] .. M[2**k - 1], are
 * built: a full window always starts with a 1 bit, so the lower half is
 * never indexed.
 *
 * redmode selects the reduction applied after every product:
 *   0          Montgomery (values carry a factor R mod P; one final
 *              reduction removes it),
 *   1          diminished radix, for moduli of the form B**k - b,
 *   otherwise  2**k - b reduction.
 * If the code for the selected reduction is not compiled in, MP_VAL is
 * returned.
 *
 * Returns MP_OKAY on success, otherwise the first error from a helper.
 */

static int mp_exptmod_fast (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
{
  /* largest exponent bit length handled by window sizes 2, 3, ..., 7 */
  static const int win_limit[] = { 7, 36, 140, 450, 1303, 3529 };
  mp_int   M[TAB_SIZE], acc;
  mp_digit cur, rho;
  int      status, wsize, half, nbits, i, j;
  int      state, bits_left, didx, wlen, window, bit;

  /* reduction routine, chosen once so the main loop stays uniform */
  int      (*redux)(mp_int*,mp_int*,mp_digit);

  /* pick the window size from the exponent length */
  nbits = mp_count_bits (X);
  wsize = 2;
  while (wsize < 8 && nbits > win_limit[wsize - 2]) {
    wsize++;
  }

#ifdef MP_LOW_MEM
  if (wsize > 5) {
     wsize = 5;
  }
#endif

  half = 1 << (wsize - 1);

  /* initialise M[1] and the upper half of the table */
  if ((status = mp_init (&M[1])) != MP_OKAY) {
     return status;
  }
  for (i = half; i < (1 << wsize); i++) {
    if ((status = mp_init (&M[i])) != MP_OKAY) {
      for (j = half; j < i; j++) {
        mp_clear (&M[j]);
      }
      mp_clear (&M[1]);
      return status;
    }
  }

  /* select and set up the reduction */
  switch (redmode) {
  case 0:
#ifdef BN_MP_MONTGOMERY_SETUP_C
    if ((status = mp_montgomery_setup (P, &rho)) != MP_OKAY) {
      goto free_table;
    }
#else
    status = MP_VAL;
    goto free_table;
#endif
#ifdef BN_FAST_MP_MONTGOMERY_REDUCE_C
    /* prefer the comba reducer when its column array is large enough */
    if (((P->used * 2 + 1) < MP_WARRAY) &&
         P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
      redux = fast_mp_montgomery_reduce;
      break;
    }
#endif
#ifdef BN_MP_MONTGOMERY_REDUCE_C
    /* slower baseline Montgomery method */
    redux = mp_montgomery_reduce;
#else
    status = MP_VAL;
    goto free_table;
#endif
    break;

  case 1:
#if defined(BN_MP_DR_SETUP_C) && defined(BN_MP_DR_REDUCE_C)
    /* DR reduction for moduli of the form B**k - b */
    mp_dr_setup (P, &rho);
    redux = mp_dr_reduce;
#else
    status = MP_VAL;
    goto free_table;
#endif
    break;

  default:
#if defined(BN_MP_REDUCE_2K_SETUP_C) && defined(BN_MP_REDUCE_2K_C)
    /* reduction for moduli of the form 2**k - b */
    if ((status = mp_reduce_2k_setup (P, &rho)) != MP_OKAY) {
      goto free_table;
    }
    redux = mp_reduce_2k;
#else
    status = MP_VAL;
    goto free_table;
#endif
    break;
  }

  /* accumulator for the result */
  if ((status = mp_init (&acc)) != MP_OKAY) {
    goto free_table;
  }

  /* M[1] = G and acc = 1, both in the working representation */
  if (redmode == 0) {
#ifdef BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
    /* acc = R mod P */
    if ((status = mp_montgomery_calc_normalization (&acc, P)) != MP_OKAY) {
      goto free_acc;
    }
#else
    status = MP_VAL;
    goto free_acc;
#endif
    /* M[1] = G * R mod P */
    if ((status = mp_mulmod (G, &acc, P, &M[1])) != MP_OKAY) {
      goto free_acc;
    }
  } else {
    mp_set (&acc, 1);
    if ((status = mp_mod (G, P, &M[1])) != MP_OKAY) {
      goto free_acc;
    }
  }

  /* M[half] = M[1] ** half, by squaring (wsize - 1) times */
  if ((status = mp_copy (&M[1], &M[half])) != MP_OKAY) {
    goto free_acc;
  }
  for (i = 0; i < (wsize - 1); i++) {
    status = mp_sqr (&M[half], &M[half]);
    if (status == MP_OKAY) {
      status = redux (&M[half], P, rho);
    }
    if (status != MP_OKAY) {
      goto free_acc;
    }
  }

  /* rest of the upper half: M[i] = M[i-1] * M[1] */
  for (i = half + 1; i < (1 << wsize); i++) {
    status = mp_mul (&M[i - 1], &M[1], &M[i]);
    if (status == MP_OKAY) {
      status = redux (&M[i], P, rho);
    }
    if (status != MP_OKAY) {
      goto free_acc;
    }
  }

  /* state: 0 = still in leading zeros, 1 = between windows,
   *        2 = collecting bits into the current window */
  state     = 0;
  bits_left = 1;
  cur       = 0;
  didx      = X->used - 1;
  wlen      = 0;
  window    = 0;

  for (;;) {
    /* fetch the next digit of X (most significant first) when needed */
    if (--bits_left == 0) {
      if (didx == -1) {
        break;
      }
      cur       = X->dp[didx--];
      bits_left = (int)DIGIT_BIT;
    }

    /* take the next most significant bit of the exponent */
    bit   = (mp_digit)(cur >> (DIGIT_BIT - 1)) & 1;
    cur <<= (mp_digit)1;

    if (bit == 0 && state != 2) {
      if (state == 0) {
        /* leading zero bit: nothing to do yet */
        continue;
      }
      /* zero bit between windows: just square */
      status = mp_sqr (&acc, &acc);
      if (status == MP_OKAY) {
        status = redux (&acc, P, rho);
      }
      if (status != MP_OKAY) {
        goto free_acc;
      }
      continue;
    }

    /* append the bit to the window */
    window |= (bit << (wsize - ++wlen));
    state   = 2;

    if (wlen == wsize) {
      /* window full: acc = acc ** (2**wsize) * M[window] */
      for (i = 0; i < wsize; i++) {
        status = mp_sqr (&acc, &acc);
        if (status == MP_OKAY) {
          status = redux (&acc, P, rho);
        }
        if (status != MP_OKAY) {
          goto free_acc;
        }
      }

      status = mp_mul (&acc, &M[window], &acc);
      if (status == MP_OKAY) {
        status = redux (&acc, P, rho);
      }
      if (status != MP_OKAY) {
        goto free_acc;
      }

      window = 0;
      wlen   = 0;
      state  = 1;
    }
  }

  /* a partially filled window remains: handle its bits one at a time */
  if (state == 2 && wlen > 0) {
    for (i = 0; i < wlen; i++) {
      status = mp_sqr (&acc, &acc);
      if (status == MP_OKAY) {
        status = redux (&acc, P, rho);
      }
      if (status != MP_OKAY) {
        goto free_acc;
      }

      window <<= 1;
      if ((window & (1 << wsize)) != 0) {
        status = mp_mul (&acc, &M[1], &acc);
        if (status == MP_OKAY) {
          status = redux (&acc, P, rho);
        }
        if (status != MP_OKAY) {
          goto free_acc;
        }
      }
    }
  }

  if (redmode == 0) {
    /* leave Montgomery form: one more reduction divides out R */
    if ((status = redux (&acc, P, rho)) != MP_OKAY) {
      goto free_acc;
    }
  }

  mp_exch (&acc, Y);
  status = MP_OKAY;

free_acc:
  mp_clear (&acc);
free_table:
  mp_clear (&M[1]);
  for (i = half; i < (1 << wsize); i++) {
    mp_clear (&M[i]);
  }
  return status;
}
#endif
3237 
3238 
#ifdef BN_FAST_S_MP_SQR_C
/* Comba squaring: b = a * a.
 *
 * As in the comba multiplier, each output column is summed in a
 * double-width word before its digit and carry are split off.  Because
 * a[i]*a[j] == a[j]*a[i], only the products with i < j are summed (the
 * two indices approach each other, so the count is halved, rounding for
 * odd cases), and that sum is doubled.  For even columns the square term
 * a[col/2]**2 is then added.
 *
 * W[] holds MP_WARRAY digits and is indexed up to 2 * a->used - 1; this
 * routine does not check that bound itself.  All columns are produced in
 * W[] before b is written, so a and b may alias.
 *
 * Returns MP_OKAY, or the error from mp_grow.
 */

static int fast_s_mp_sqr (mp_int * a, mp_int * b)
{
  int       prev_used, err, ndig, col, k;
  mp_digit  W[MP_WARRAY];
  mp_word   carry;

  /* the square has at most 2 * a->used digits */
  ndig = a->used + a->used;
  if (b->alloc < ndig) {
    err = mp_grow (b, ndig);
    if (err != MP_OKAY) {
      return err;
    }
  }

  carry = 0;
  for (col = 0; col < ndig; col++) {
    int       hi, lo, cnt;
    mp_word   sum = 0;
    mp_digit *lp, *hp;

    /* indices of the first product pair contributing to this column */
    hi = MIN (a->used - 1, col);
    lo = col - hi;
    lp = a->dp + lo;
    hp = a->dp + hi;

    /* pairs available in this column, limited to those with lo < hi */
    cnt = MIN (a->used - lo, hi + 1);
    cnt = MIN (cnt, (hi - lo + 1) >> 1);

    for (k = 0; k < cnt; k++) {
      sum += ((mp_word)*lp++) * ((mp_word)*hp--);
    }

    /* double the cross products and add the incoming carry */
    sum = sum + sum + carry;

    /* even columns also receive a square term */
    if ((col & 1) == 0) {
      sum += ((mp_word)a->dp[col >> 1]) * ((mp_word)a->dp[col >> 1]);
    }

    W[col] = (mp_digit)(sum & MP_MASK);
    carry  = sum >> ((mp_word)DIGIT_BIT);
  }

  prev_used = b->used;
  b->used   = a->used + a->used;

  for (col = 0; col < ndig; col++) {
    b->dp[col] = W[col] & MP_MASK;
  }

  /* clear digits that belonged to b's previous, longer value */
  for (; col < prev_used; col++) {
    b->dp[col] = 0;
  }

  mp_clamp (b);
  return MP_OKAY;
}
#endif
3333 
3334 
#ifdef BN_MP_MUL_D_C
/* c = a * b, where b is a single digit.
 *
 * c receives a's sign and may alias a: each digit of a is read before the
 * same slot of c is written.  Digits of c's previous value above the new
 * length are cleared.
 *
 * Returns MP_OKAY, or the error from mp_grow.
 */
static int
mp_mul_d (mp_int * a, mp_digit b, mp_int * c)
{
  mp_word  prod;
  mp_digit carry;
  int      i, err, prev_used;

  /* room for a->used digits plus one carry digit */
  if (c->alloc < a->used + 1) {
    err = mp_grow (c, a->used + 1);
    if (err != MP_OKAY) {
      return err;
    }
  }

  prev_used = c->used;
  c->sign   = a->sign;

  carry = 0;
  for (i = 0; i < a->used; i++) {
    prod     = ((mp_word) carry) + ((mp_word) a->dp[i]) * ((mp_word) b);
    c->dp[i] = (mp_digit) (prod & ((mp_word) MP_MASK));
    carry    = (mp_digit) (prod >> ((mp_word) DIGIT_BIT));
  }

  /* final carry becomes the top digit (possibly zero) */
  c->dp[i++] = carry;

  /* clear stale digits above the new top */
  for (; i < prev_used; i++) {
    c->dp[i] = 0;
  }

  c->used = a->used + 1;
  mp_clamp (c);
  return MP_OKAY;
}
#endif
3394