xref: /freebsd/crypto/openssh/sntrup761.c (revision 7899f917b1c0ea178f1d2be0cfb452086d079d23)
1 /*  $OpenBSD: sntrup761.c,v 1.6 2023/01/11 02:13:52 djm Exp $ */
2 
3 /*
4  * Public Domain, Authors:
5  * - Daniel J. Bernstein
6  * - Chitchanok Chuengsatiansup
7  * - Tanja Lange
8  * - Christine van Vredendaal
9  */
10 
11 #include "includes.h"
12 
13 #ifdef USE_SNTRUP761X25519
14 
15 #include <string.h>
16 #include "crypto_api.h"
17 
/*
 * Short names for the fixed-width integer types used by the SUPERCOP-derived
 * code below.  The crypto_* types are presumably declared in crypto_api.h
 * (included above).
 */
#define int8 crypto_int8
#define uint8 crypto_uint8
#define int16 crypto_int16
#define uint16 crypto_uint16
#define int32 crypto_int32
#define uint32 crypto_uint32
#define int64 crypto_int64
#define uint64 crypto_uint64
26 
27 /* from supercop-20201130/crypto_sort/int32/portable4/int32_minmax.inc */
/*
 * Branch-free compare-and-swap of two int32 lvalues: afterwards a holds
 * min(a,b) and b holds max(a,b).
 *
 * The difference is formed in 64 bits, so b - a cannot overflow.  When a and
 * b have different signs, the "c ^= ab & (c ^ b)" step replaces the high bits
 * of c with those of b, so the sign of c then reflects b < a.  "c >>= 31"
 * relies on arithmetic right shift of a negative int64_t.  That is
 * implementation-defined in ISO C but universal in practice.
 *
 * Both arguments are evaluated several times; pass side-effect-free lvalues.
 */
#define int32_MINMAX(a,b) \
do { \
  int64_t ab = (int64_t)b ^ (int64_t)a; \
  int64_t c = (int64_t)b - (int64_t)a; \
  c ^= ab & (c ^ b); \
  c >>= 31; \
  c &= ab; \
  a ^= c; \
  b ^= c; \
} while(0)
38 
39 /* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
40 
41 
/*
 * Sort the n int32 values at array into ascending order, in place.
 *
 * This is a sorting network.  Which positions get compared, and in what
 * order, depends only on n (through top, p, q, r, i and j), never on the
 * values.  Every comparison is the branch-free int32_MINMAX, so the running
 * time does not depend on the data being sorted.  Does nothing for n < 2.
 *
 * p, q and r here are local variables.  This function precedes the
 * parameter macros "#define p" / "#define q" further down the file.
 */
static void crypto_sort_int32(void *array,long long n)
{
  long long top,p,q,r,i,j;
  int32 *x = array;

  if (n < 2) return;
  /* top = largest power of two strictly below n */
  top = 1;
  while (top < n - top) top += top;

  for (p = top;p >= 1;p >>= 1) {
    /* compare x[j] with x[j+p] within each block of 2p, then the tail */
    i = 0;
    while (i + 2 * p <= n) {
      for (j = i;j < i + p;++j)
        int32_MINMAX(x[j],x[j+p]);
      i += 2 * p;
    }
    for (j = i;j < n - p;++j)
      int32_MINMAX(x[j],x[j+p]);

    /*
     * Passes for the larger strides q = top, top/2, ..., 2p.  i and j
     * persist across q iterations.  When j != i, the for(;;) first finishes
     * the block left partially processed by the previous q.
     */
    i = 0;
    j = 0;
    for (q = top;q > p;q >>= 1) {
      if (j != i) for (;;) {
        if (j == n - q) goto done;
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j + r]);
        x[j + p] = a;
        ++j;
        if (j == i + p) {
          i += 2 * p;
          break;
        }
      }
      while (i + p <= n - q) {
        for (j = i;j < i + p;++j) {
          int32 a = x[j + p];
          for (r = q;r > p;r >>= 1)
            int32_MINMAX(a,x[j+r]);
          x[j + p] = a;
        }
        i += 2 * p;
      }
      /* now i + p > n - q */
      j = i;
      while (j < n - q) {
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j+r]);
        x[j + p] = a;
        ++j;
      }

      done: ;
    }
  }
}
99 
100 /* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
101 
102 /* can save time by vectorizing xor loops */
103 /* can save time by integrating xor loops with int32_sort */
104 
105 static void crypto_sort_uint32(void *array,long long n)
106 {
107   crypto_uint32 *x = array;
108   long long j;
109   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
110   crypto_sort_int32(array,n);
111   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
112 }
113 
114 /* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
115 
116 /*
117 CPU division instruction typically takes time depending on x.
118 This software is designed to take time independent of x.
119 Time still varies depending on m; user must ensure that m is constant.
120 Time also varies on CPUs where multiplication is variable-time.
121 There could be more CPU issues.
122 There could also be compiler issues.
123 */
124 
/*
 * Compute *q = x / m and *r = x % m for a divisor 0 < m < 16384.
 *
 * The only hardware division is v = 2^31 / m, which depends on m alone.
 * The quotient is built from two multiply-by-v estimates followed by one
 * masked correction.  Assuming constant-time multiplication, the running
 * time therefore does not depend on x.
 */
static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
{
  uint32 v = 0x80000000;
  uint32 qpart;
  uint32 mask;

  v /= m;

  /* caller guarantees m > 0 */
  /* caller guarantees m < 16384 */
  /* vm <= 2^31 <= vm+m-1 */
  /* xvm <= 2^31 x <= xvm+x(m-1) */

  *q = 0;

  qpart = (x*(uint64)v)>>31;
  /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
  /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
  /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
  /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */

  x -= qpart*m; *q += qpart;
  /* x <= 49146 */

  qpart = (x*(uint64)v)>>31;
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
  /* 0 <= newx <= m + 0.4 */
  /* 0 <= newx <= m */

  x -= qpart*m; *q += qpart;
  /* x <= m */

  /* Subtract m once more.  If x wrapped below zero (bit 31 set), mask is */
  /* all ones: it adds m back and takes the extra 1 off *q again.        */
  x -= m; *q += 1;
  mask = -(x>>31);
  x += mask&(uint32)m; *q += mask;
  /* x < m */

  *r = x;
}
167 
168 
169 static uint16 uint32_mod_uint14(uint32 x,uint16 m)
170 {
171   uint32 q;
172   uint16 r;
173   uint32_divmod_uint14(&q,&r,x,m);
174   return r;
175 }
176 
177 /* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
178 
/*
 * Signed counterpart of uint32_divmod_uint14 for 0 < m < 16384.  Sets
 * *q = floor(x / m) and *r = x - m*(*q), so 0 <= *r < m even when x < 0.
 *
 * Method: x is biased by 2^31 into unsigned range and divided, then the
 * quotient/remainder of the bias itself are subtracted.  Both remainders
 * are below 16384.  A negative difference therefore shows up as bit 15 set
 * after uint16 wraparound, in which case m is added back and the quotient
 * decremented, via a mask rather than a branch.
 *
 * NOTE(review): "*q = uq" stores a uint32 into an int32.  For negative
 * quotients the value is out of int32 range, so the conversion is
 * implementation-defined in ISO C (two's-complement compilers wrap as
 * intended).
 */
static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
{
  uint32 uq,uq2;
  uint16 ur,ur2;
  uint32 mask;

  uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
  uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
  ur -= ur2; uq -= uq2;
  mask = -(uint32)(ur>>15);
  ur += mask&m; uq += mask;
  *r = ur; *q = uq;
}
192 
193 
194 static uint16 int32_mod_uint14(int32 x,uint16 m)
195 {
196   int32 q;
197   uint16 r;
198   int32_divmod_uint14(&q,&r,x,m);
199   return r;
200 }
201 
202 /* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
/*
 * Parameter size: pick one of these three.  This file implements sntrup761,
 * so SIZE761 is selected.
 */
#define SIZE761
#undef SIZE653
#undef SIZE857

/*
 * Scheme: pick one of these two.  Streamlined NTRU Prime is selected, so
 * everything guarded by "#ifdef LPR" in this file is compiled out.
 */
#define SNTRUP /* Streamlined NTRU Prime */
#undef LPR /* NTRU LPRime */
211 
212 /* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
/*
 * Parameter values for the SIZE* / LPR selection above:
 *   p              number of coefficients per polynomial (arrays are [p])
 *   q              modulus of Fq arithmetic (see Fq_freeze)
 *   Rounded_bytes  encoded size of a rounded polynomial (Rounded_encode)
 *   Rq_bytes       encoded size of a polynomial mod q (Rq_encode; SNTRUP)
 *   w              number of nonzero coefficients in a short polynomial
 *   tau0..tau3     constants for Top()/Right() (LPR only)
 *   I              number of input bits (LPR only)
 *
 * NOTE: p, q, w and I are single-letter macros.  They rewrite every later
 * use of those identifiers, so code below must not use them as variable
 * names.  The params_H guard is retained from the upstream params.h.
 */
#ifndef params_H
#define params_H

#if defined(SIZE761)
#define p 761
#define q 4591
#define Rounded_bytes 1007
#ifndef LPR
#define Rq_bytes 1158
#define w 286
#else
#define w 250
#define tau0 2156
#define tau1 114
#define tau2 2007
#define tau3 287
#endif

#elif defined(SIZE653)
#define p 653
#define q 4621
#define Rounded_bytes 865
#ifndef LPR
#define Rq_bytes 994
#define w 288
#else
#define w 252
#define tau0 2175
#define tau1 113
#define tau2 2031
#define tau3 290
#endif

#elif defined(SIZE857)
#define p 857
#define q 5167
#define Rounded_bytes 1152
#ifndef LPR
#define Rq_bytes 1322
#define w 322
#else
#define w 281
#define tau0 2433
#define tau1 101
#define tau2 2265
#define tau3 324
#endif

#else
#error "no parameter set defined"
#endif

#ifdef LPR
#define I 256
#endif

#endif /* params_H */
275 
276 /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
#ifndef Decode_H
#define Decode_H

/*
 * Decode(R,s,M,len): recover len values R[i] from the byte string s
 * produced by Encode with the same moduli M.
 * Assumes 0 < M[i] < 16384; produces 0 <= R[i] < M[i].
 * (Header guard retained from the upstream Decode.h.)
 */

#endif
286 
287 /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
288 
/*
 * Decode(out,S,M,len): inverse of Encode.  Reads the byte string S and
 * writes len values with 0 <= out[i] < M[i], assuming 0 < M[i] < 16384.
 *
 * This mirrors Encode's recursion:
 *  1. Adjacent moduli are multiplied pairwise.  The low bytes Encode emitted
 *     for each pair are read first: 2, 1 or 0 bytes, depending on the size
 *     of the product.
 *  2. The half-length list of combined values is decoded recursively.
 *  3. Each combined value is split back into two with a division by M[i].
 */
static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
{
  if (len == 1) {
    /* lone value: M[0] == 1 uses no bytes, M[0] <= 256 one byte, else two */
    if (M[0] == 1)
      *out = 0;
    else if (M[0] <= 256)
      *out = uint32_mod_uint14(S[0],M[0]);
    else
      *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];
    uint16 M2[(len+1)/2];
    uint16 bottomr[len/2];
    uint32 bottomt[len/2];
    long long i;
    /* pass 1: read each pair's low bytes; M2 is the modulus left above them */
    for (i = 0;i < len-1;i += 2) {
      uint32 m = M[i]*(uint32) M[i+1];
      if (m > 256*16383) {
        bottomt[i/2] = 256*256;
        bottomr[i/2] = S[0]+256*S[1];
        S += 2;
        M2[i/2] = (((m+255)>>8)+255)>>8;
      } else if (m >= 16384) {
        bottomt[i/2] = 256;
        bottomr[i/2] = S[0];
        S += 1;
        M2[i/2] = (m+255)>>8;
      } else {
        bottomt[i/2] = 1;
        bottomr[i/2] = 0;
        M2[i/2] = m;
      }
    }
    if (i < len)
      M2[i/2] = M[i];
    Decode(R2,S,M2,(len+1)/2);
    /* pass 2: combined = bottom + scale*top, split as r0 + M[i]*r1 */
    for (i = 0;i < len-1;i += 2) {
      uint32 r = bottomr[i/2];
      uint32 r1;
      uint16 r0;
      r += bottomt[i/2]*R2[i/2];
      uint32_divmod_uint14(&r1,&r0,r,M[i]);
      r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
      *out++ = r0;
      *out++ = r1;
    }
    if (i < len)
      *out++ = R2[i/2];
  }
}
340 
341 /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
#ifndef Encode_H
#define Encode_H

/*
 * Encode(s,R,M,len): serialize len values R[i] with moduli M[i] into the
 * byte string s (read back by Decode with the same M).
 * Assumes 0 <= R[i] < M[i] < 16384.
 * (Header guard retained from the upstream Encode.h.)
 */

#endif
350 
351 /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
352 
353 /* 0 <= R[i] < M[i] < 16384 */
354 static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
355 {
356   if (len == 1) {
357     uint16 r = R[0];
358     uint16 m = M[0];
359     while (m > 1) {
360       *out++ = r;
361       r >>= 8;
362       m = (m+255)>>8;
363     }
364   }
365   if (len > 1) {
366     uint16 R2[(len+1)/2];
367     uint16 M2[(len+1)/2];
368     long long i;
369     for (i = 0;i < len-1;i += 2) {
370       uint32 m0 = M[i];
371       uint32 r = R[i]+R[i+1]*m0;
372       uint32 m = M[i+1]*m0;
373       while (m >= 16384) {
374         *out++ = r;
375         r >>= 8;
376         m = (m+255)>>8;
377       }
378       R2[i/2] = r;
379       M2[i/2] = m;
380     }
381     if (i < len) {
382       R2[i/2] = R[i];
383       M2[i/2] = M[i];
384     }
385     Encode(out,R2,M2,(len+1)/2);
386   }
387 }
388 
389 /* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
390 
/* NOTE(review): empty conditional carried over from the upstream kem.c; */
/* it expands to nothing and could be removed. */
#ifdef LPR
#endif
393 
394 
395 /* ----- masks */
396 
397 #ifndef LPR
398 
399 /* return -1 if x!=0; else return 0 */
400 static int int16_nonzero_mask(int16 x)
401 {
402   uint16 u = x; /* 0, else 1...65535 */
403   uint32 v = u; /* 0, else 1...65535 */
404   v = -v; /* 0, else 2^32-65535...2^32-1 */
405   v >>= 31; /* 0, else 1 */
406   return -v; /* 0, else -1 */
407 }
408 
409 #endif
410 
411 /* return -1 if x<0; otherwise return 0 */
412 static int int16_negative_mask(int16 x)
413 {
414   uint16 u = x;
415   u >>= 15;
416   return -(int) u;
417   /* alternative with gcc -fwrapv: */
418   /* x>>15 compiles to CPU's arithmetic right shift */
419 }
420 
421 /* ----- arithmetic mod 3 */
422 
/* Coefficient type of small polynomials: signed 8-bit storage. */
typedef int8 small;

/* F3 (integers mod 3) is always represented as -1,0,1 */
/* so ZZ_fromF3 (converting back to an integer) is a no-op */
427 
428 /* x must not be close to top int16 */
429 static small F3_freeze(int16 x)
430 {
431   return int32_mod_uint14(x+1,3)-1;
432 }
433 
434 /* ----- arithmetic mod q */
435 
/* q12 = (q-1)/2: bound of the centered range used for Fq values */
#define q12 ((q-1)/2)
/* Fq: an integer mod q, stored in an int16 */
typedef int16 Fq;
/* always represented as -q12...q12 */
/* so ZZ_fromFq is a no-op */
440 
441 /* x must not be close to top int32 */
442 static Fq Fq_freeze(int32 x)
443 {
444   return int32_mod_uint14(x+q12,q)-q12;
445 }
446 
447 #ifndef LPR
448 
449 static Fq Fq_recip(Fq a1)
450 {
451   int i = 1;
452   Fq ai = a1;
453 
454   while (i < q-2) {
455     ai = Fq_freeze(a1*(int32)ai);
456     i += 1;
457   }
458   return ai;
459 }
460 
461 #endif
462 
463 /* ----- Top and Right */
464 
#ifdef LPR
#define tau 16

/* Compress C to a 4-bit value: tau1*(C+tau0)/2^15, rounded to nearest. */
static int8 Top(Fq C)
{
  int32 scaled = tau1*(int32)(C+tau0);

  return (scaled+16384)>>15;
}

/* Approximate inverse of Top(): map a 4-bit T back into Fq. */
static Fq Right(int8 T)
{
  int32 expanded = tau3*(int32)T;

  return Fq_freeze(expanded-tau2);
}
#endif
478 
479 /* ----- small polynomials */
480 
481 #ifndef LPR
482 
483 /* 0 if Weightw_is(r), else -1 */
484 static int Weightw_mask(small *r)
485 {
486   int weight = 0;
487   int i;
488 
489   for (i = 0;i < p;++i) weight += r[i]&1;
490   return int16_nonzero_mask(weight-w);
491 }
492 
493 /* R3_fromR(R_fromRq(r)) */
494 static void R3_fromRq(small *out,const Fq *r)
495 {
496   int i;
497   for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
498 }
499 
500 /* h = f*g in the ring R3 */
501 static void R3_mult(small *h,const small *f,const small *g)
502 {
503   small fg[p+p-1];
504   small result;
505   int i,j;
506 
507   for (i = 0;i < p;++i) {
508     result = 0;
509     for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
510     fg[i] = result;
511   }
512   for (i = p;i < p+p-1;++i) {
513     result = 0;
514     for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
515     fg[i] = result;
516   }
517 
518   for (i = p+p-2;i >= p;--i) {
519     fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
520     fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
521   }
522 
523   for (i = 0;i < p;++i) h[i] = fg[i];
524 }
525 
/*
 * out = 1/in in R3 = F3[x]/(x^p - x - 1).
 * Returns 0 if the reciprocal exists (out is then valid), else -1.
 * out is written in either case.
 *
 * f starts as x^p - x - 1 and g as in; both are stored with coefficients
 * reversed.  Each of the fixed 2p-1 iterations does three things:
 *   - optionally swaps (f,v) with (g,r);
 *   - cancels g[0] using f[0];
 *   - shifts g down one position.
 * Swaps use masks, so the operation sequence does not depend on in.
 * Assumes in's coefficients are in {-1,0,1}.
 */
static int R3_recip(small *out,const small *in)
{
  small f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int sign,swap,t;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = 1;
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* v = x*v */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* f[0] is always +-1 (f only ever receives a g with g[0] != 0), so it */
    /* is its own inverse mod 3.  Computed before the swap: the product is */
    /* symmetric in f[0] and g[0]. */
    sign = -g[0]*f[0];
    /* swap iff delta > 0 and g[0] != 0; then delta = (swap ? -delta : delta) + 1 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta);
    delta += 1;

    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* g += sign*f zeroes g[0]; r tracks the same combination of v */
    for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
    for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);

    /* drop the now-zero g[0] */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* undo the reversal and divide by f[0] (equal to multiplying, f[0] = +-1) */
  sign = f[0];
  for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];

  return int16_nonzero_mask(delta);
}
569 
570 #endif
571 
572 /* ----- polynomials mod q */
573 
574 /* h = f*g in the ring Rq */
575 static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
576 {
577   Fq fg[p+p-1];
578   Fq result;
579   int i,j;
580 
581   for (i = 0;i < p;++i) {
582     result = 0;
583     for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
584     fg[i] = result;
585   }
586   for (i = p;i < p+p-1;++i) {
587     result = 0;
588     for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
589     fg[i] = result;
590   }
591 
592   for (i = p+p-2;i >= p;--i) {
593     fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
594     fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
595   }
596 
597   for (i = 0;i < p;++i) h[i] = fg[i];
598 }
599 
600 #ifndef LPR
601 
602 /* h = 3f in Rq */
603 static void Rq_mult3(Fq *h,const Fq *f)
604 {
605   int i;
606 
607   for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
608 }
609 
/* out = 1/(3*in) in Rq */
/* returns 0 if recip succeeded; else -1 */
/*
 * Same fixed 2p-1 step structure as R3_recip.  Mod q, f[0] has no cheap
 * inverse inside the loop, so each step cross-multiplies
 * (g = f[0]*g - g[0]*f).  The accumulated scale is removed once at the end
 * with Fq_recip(f[0]).  Starting r at 1/3 builds in the factor 1/3.
 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  /* f = x^p - x - 1 and g = in, both stored with coefficients reversed */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* v = x*v */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* swap iff delta > 0 and g[0] != 0; then delta = (swap ? -delta : delta) + 1 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta);
    delta += 1;

    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* cross-multiply so that g[0] becomes f0*g0 - g0*f0 = 0 */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* drop the now-zero g[0] */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* undo the reversal and remove the accumulated scale 1/f[0] */
  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  return int16_nonzero_mask(delta);
}
657 
658 #endif
659 
660 /* ----- rounded polynomials mod q */
661 
662 static void Round(Fq *out,const Fq *a)
663 {
664   int i;
665   for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
666 }
667 
668 /* ----- sorting to generate short polynomial */
669 
670 static void Short_fromlist(small *out,const uint32 *in)
671 {
672   uint32 L[p];
673   int i;
674 
675   for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
676   for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
677   crypto_sort_uint32(L,p);
678   for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
679 }
680 
681 /* ----- underlying hash function */
682 
#define Hash_bytes 32

/*
 * out = first Hash_bytes bytes of crypto_hash_sha512(b || in[0..inlen-1]).
 * The one-byte prefix b separates the different uses of the hash
 * (e.g. b = 0 means out = Hash0(in)).
 */
static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
{
  unsigned char prefixed[inlen+1];
  unsigned char digest[64];
  int k;

  prefixed[0] = b;
  for (k = 1;k <= inlen;++k)
    prefixed[k] = in[k-1];
  crypto_hash_sha512(digest,prefixed,inlen+1);
  memcpy(out,digest,Hash_bytes);
}
697 
698 /* ----- higher-level randomness */
699 
700 static uint32 urandom32(void)
701 {
702   unsigned char c[4];
703   uint32 out[4];
704 
705   randombytes(c,4);
706   out[0] = (uint32)c[0];
707   out[1] = ((uint32)c[1])<<8;
708   out[2] = ((uint32)c[2])<<16;
709   out[3] = ((uint32)c[3])<<24;
710   return out[0]+out[1]+out[2]+out[3];
711 }
712 
713 static void Short_random(small *out)
714 {
715   uint32 L[p];
716   int i;
717 
718   for (i = 0;i < p;++i) L[i] = urandom32();
719   Short_fromlist(out,L);
720 }
721 
722 #ifndef LPR
723 
724 static void Small_random(small *out)
725 {
726   int i;
727 
728   for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
729 }
730 
731 #endif
732 
733 /* ----- Streamlined NTRU Prime Core */
734 
735 #ifndef LPR
736 
737 /* h,(f,ginv) = KeyGen() */
738 static void KeyGen(Fq *h,small *f,small *ginv)
739 {
740   small g[p];
741   Fq finv[p];
742 
743   for (;;) {
744     Small_random(g);
745     if (R3_recip(ginv,g) == 0) break;
746   }
747   Short_random(f);
748   Rq_recip3(finv,f); /* always works */
749   Rq_mult_small(h,finv,g);
750 }
751 
752 /* c = Encrypt(r,h) */
753 static void Encrypt(Fq *c,const small *r,const Fq *h)
754 {
755   Fq hr[p];
756 
757   Rq_mult_small(hr,h,r);
758   Round(c,hr);
759 }
760 
/*
 * r = Decrypt(c,(f,ginv))
 * Computes 3*f*c in Rq, reduces it mod 3, and multiplies by ginv in R3 to
 * get a candidate ev.  If ev has weight w it is returned as r.  Otherwise
 * r becomes the fixed vector with ones in positions 0..w-1 and zeros
 * elsewhere.  The choice uses the mask from Weightw_mask, not a branch.
 */
static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
{
  Fq cf[p];
  Fq cf3[p];
  small e[p];
  small ev[p];
  int mask;
  int i;

  Rq_mult_small(cf,c,f);
  Rq_mult3(cf3,cf);
  R3_fromRq(e,cf3);
  R3_mult(ev,e,ginv);

  mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
  /* mask == 0: r[i] = ev[i];  mask == -1: r[i] = 1 for i < w, else 0 */
  for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
  for (i = w;i < p;++i) r[i] = ev[i]&~mask;
}
780 
781 #endif
782 
783 /* ----- NTRU LPRime Core */
784 
#ifdef LPR

/* (G,A),a = KeyGen(G): draw short a and publish A = Round(a*G); G is only read. */
static void KeyGen(Fq *A,small *a,const Fq *G)
{
  Fq product[p];

  Short_random(a);
  Rq_mult_small(product,G,a);
  Round(A,product);
}

/*
 * B,T = Encrypt(r,(G,A),b): B = Round(b*G), and for each of the I input
 * bits T[i] = Top((b*A)[i] + r[i]*q12).
 */
static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
{
  Fq bG[p];
  Fq bA[p];
  int k;

  Rq_mult_small(bG,G,b);
  Round(B,bG);
  Rq_mult_small(bA,A,b);
  for (k = 0;k < I;++k) {
    Fq shifted = Fq_freeze(bA[k]+r[k]*q12);

    T[k] = Top(shifted);
  }
}

/*
 * r = Decrypt((B,T),a): r[i] = 1 exactly when
 * Fq_freeze(Right(T[i]) - (a*B)[i] + 4w + 1) is negative, else 0.
 */
static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
{
  Fq aB[p];
  int k;

  Rq_mult_small(aB,B,a);
  for (k = 0;k < I;++k) {
    Fq diff = Fq_freeze(Right(T[k])-aB[k]+4*w+1);

    r[k] = -int16_negative_mask(diff);
  }
}

#endif
822 
823 /* ----- encoding I-bit inputs */
824 
#ifdef LPR

#define Inputs_bytes (I/8)
typedef int8 Inputs[I]; /* passed by reference */

/* Pack the I input bits (each 0 or 1) into Inputs_bytes bytes, LSB first. */
static void Inputs_encode(unsigned char *s,const Inputs r)
{
  int k;

  memset(s,0,Inputs_bytes);
  for (k = 0;k < I;++k)
    s[k/8] |= r[k]<<(k%8);
}

#endif
838 
839 /* ----- Expand */
840 
#ifdef LPR

static const unsigned char aes_nonce[16] = {0};

/*
 * Fill L[0..p-1] from key k.  Generates 4p bytes of AES-256-CTR keystream
 * (zero nonce) directly into L, then reinterprets each 4-byte group as a
 * little-endian uint32, in place.
 */
static void Expand(uint32 *L,const unsigned char *k)
{
  unsigned char *bytes = (unsigned char *) L;
  int n;

  crypto_stream_aes256ctr(bytes,4*p,aes_nonce,k);
  for (n = 0;n < p;++n) {
    const unsigned char *b = bytes+4*n;

    L[n] = (uint32)b[0]|((uint32)b[1]<<8)|((uint32)b[2]<<16)|((uint32)b[3]<<24);
  }
}

#endif
859 
860 /* ----- Seeds */
861 
#ifdef LPR

#define Seeds_bytes 32

/* Fill s with Seeds_bytes fresh random bytes (the seed later fed to Generator). */
static void Seeds_random(unsigned char *s)
{
  randombytes(s,Seeds_bytes);
}

#endif
872 
873 /* ----- Generator, HashShort */
874 
#ifdef LPR

/* G = Generator(k): expand seed k into p words and map each into -q12...q12. */
static void Generator(Fq *G,const unsigned char *k)
{
  uint32 words[p];
  int n;

  Expand(words,k);
  for (n = 0;n < p;++n) {
    uint16 residue = uint32_mod_uint14(words[n],q);

    G[n] = residue-q12;
  }
}

/* out = HashShort(r): short polynomial derived from Hash5 of the encoding of r. */
static void HashShort(small *out,const Inputs r)
{
  unsigned char encoded[Inputs_bytes];
  unsigned char digest[Hash_bytes];
  uint32 words[p];

  Inputs_encode(encoded,r);
  Hash_prefix(digest,5,encoded,sizeof encoded);
  Expand(words,digest);
  Short_fromlist(out,words);
}

#endif
901 
902 /* ----- NTRU LPRime Expand */
903 
#ifdef LPR

/* (S,A),a = XKeyGen(): random seed S, G = Generator(S), then KeyGen against G. */
static void XKeyGen(unsigned char *S,Fq *A,small *a)
{
  Fq Gpoly[p];

  Seeds_random(S);
  Generator(Gpoly,S);
  KeyGen(A,a,Gpoly);
}

/* B,T = XEncrypt(r,(S,A)): regenerate G from S, derive b = HashShort(r), encrypt. */
static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
{
  Fq Gpoly[p];
  small bshort[p];

  Generator(Gpoly,S);
  HashShort(bshort,r);
  Encrypt(B,T,r,Gpoly,A,bshort);
}

/* Decryption needs no seed expansion, so XDecrypt is plain Decrypt. */
#define XDecrypt Decrypt

#endif
930 
931 /* ----- encoding small polynomials (including short polynomials) */
932 
933 #define Small_bytes ((p+3)/4)
934 
935 /* these are the only functions that rely on p mod 4 = 1 */
936 
937 static void Small_encode(unsigned char *s,const small *f)
938 {
939   small x;
940   int i;
941 
942   for (i = 0;i < p/4;++i) {
943     x = *f++ + 1;
944     x += (*f++ + 1)<<2;
945     x += (*f++ + 1)<<4;
946     x += (*f++ + 1)<<6;
947     *s++ = x;
948   }
949   x = *f++ + 1;
950   *s++ = x;
951 }
952 
953 static void Small_decode(small *f,const unsigned char *s)
954 {
955   unsigned char x;
956   int i;
957 
958   for (i = 0;i < p/4;++i) {
959     x = *s++;
960     *f++ = ((small)(x&3))-1; x >>= 2;
961     *f++ = ((small)(x&3))-1; x >>= 2;
962     *f++ = ((small)(x&3))-1; x >>= 2;
963     *f++ = ((small)(x&3))-1;
964   }
965   x = *s++;
966   *f++ = ((small)(x&3))-1;
967 }
968 
969 /* ----- encoding general polynomials */
970 
971 #ifndef LPR
972 
973 static void Rq_encode(unsigned char *s,const Fq *r)
974 {
975   uint16 R[p],M[p];
976   int i;
977 
978   for (i = 0;i < p;++i) R[i] = r[i]+q12;
979   for (i = 0;i < p;++i) M[i] = q;
980   Encode(s,R,M,p);
981 }
982 
983 static void Rq_decode(Fq *r,const unsigned char *s)
984 {
985   uint16 R[p],M[p];
986   int i;
987 
988   for (i = 0;i < p;++i) M[i] = q;
989   Decode(R,s,M,p);
990   for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
991 }
992 
993 #endif
994 
995 /* ----- encoding rounded polynomials */
996 
997 static void Rounded_encode(unsigned char *s,const Fq *r)
998 {
999   uint16 R[p],M[p];
1000   int i;
1001 
1002   for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
1003   for (i = 0;i < p;++i) M[i] = (q+2)/3;
1004   Encode(s,R,M,p);
1005 }
1006 
1007 static void Rounded_decode(Fq *r,const unsigned char *s)
1008 {
1009   uint16 R[p],M[p];
1010   int i;
1011 
1012   for (i = 0;i < p;++i) M[i] = (q+2)/3;
1013   Decode(R,s,M,p);
1014   for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
1015 }
1016 
1017 /* ----- encoding top polynomials */
1018 
#ifdef LPR

#define Top_bytes (I/2)

/* Pack the I 4-bit Top values two per byte, even index in the low nibble. */
static void Top_encode(unsigned char *s,const int8 *T)
{
  int k;

  for (k = 0;k < Top_bytes;++k) {
    int lo = T[2*k];
    int hi = T[2*k+1];

    s[k] = lo+(hi<<4);
  }
}

/* Inverse of Top_encode: split each byte into its low and high nibble. */
static void Top_decode(int8 *T,const unsigned char *s)
{
  int k;

  for (k = 0;k < Top_bytes;++k) {
    unsigned char byte = s[k];

    T[2*k] = byte&15;
    T[2*k+1] = byte>>4;
  }
}

#endif
1040 
1041 /* ----- Streamlined NTRU Prime Core plus encoding */
1042 
1043 #ifndef LPR
1044 
/* Streamlined NTRU Prime: KEM inputs are short polynomials of p coefficients. */
typedef small Inputs[p]; /* passed by reference */
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* Byte sizes of the ZEncrypt ciphertext, ZKeyGen secret key and public key */
/* (the KEM layer below appends further data to the ciphertext and sk). */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes)
#define PublicKeys_bytes Rq_bytes
1053 
1054 /* pk,sk = ZKeyGen() */
1055 static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1056 {
1057   Fq h[p];
1058   small f[p],v[p];
1059 
1060   KeyGen(h,f,v);
1061   Rq_encode(pk,h);
1062   Small_encode(sk,f); sk += Small_bytes;
1063   Small_encode(sk,v);
1064 }
1065 
1066 /* C = ZEncrypt(r,pk) */
1067 static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
1068 {
1069   Fq h[p];
1070   Fq c[p];
1071   Rq_decode(h,pk);
1072   Encrypt(c,r,h);
1073   Rounded_encode(C,c);
1074 }
1075 
1076 /* r = ZDecrypt(C,sk) */
1077 static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
1078 {
1079   small f[p],v[p];
1080   Fq c[p];
1081 
1082   Small_decode(f,sk); sk += Small_bytes;
1083   Small_decode(v,sk);
1084   Rounded_decode(c,C);
1085   Decrypt(r,c,f,v);
1086 }
1087 
1088 #endif
1089 
1090 /* ----- NTRU LPRime Expand plus encoding */
1091 
#ifdef LPR

#define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
#define SecretKeys_bytes Small_bytes
#define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)

/* Fill r with I random bits (one per entry), taken LSB-first from random bytes. */
static void Inputs_random(Inputs r)
{
  unsigned char s[Inputs_bytes];
  int k;

  randombytes(s,sizeof s);
  for (k = 0;k < I;++k)
    r[k] = (s[k/8]>>(k%8))&1;
}

/* pk,sk = ZKeyGen(): pk = seed || Rounded(A); sk = encoding of a. */
static void ZKeyGen(unsigned char *pk,unsigned char *sk)
{
  Fq A[p];
  small a[p];

  XKeyGen(pk,A,a);
  Rounded_encode(pk+Seeds_bytes,A);
  Small_encode(sk,a);
}

/* c = ZEncrypt(r,pk): c = Rounded(B) || Top-encoded T. */
static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
{
  Fq A[p];
  Fq B[p];
  int8 T[I];

  Rounded_decode(A,pk+Seeds_bytes);
  XEncrypt(B,T,r,pk,A);
  Rounded_encode(c,B);
  Top_encode(c+Rounded_bytes,T);
}

/* r = ZDecrypt(C,sk): split c into Rounded(B) and T, then decrypt with a. */
static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
{
  small a[p];
  Fq B[p];
  int8 T[I];

  Small_decode(a,sk);
  Rounded_decode(B,c);
  Top_decode(T,c+Rounded_bytes);
  XDecrypt(r,B,T,a);
}

#endif
1145 
1146 /* ----- confirmation hash */
1147 
1148 #define Confirm_bytes 32
1149 
1150 /* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
1151 static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
1152 {
1153 #ifndef LPR
1154   unsigned char x[Hash_bytes*2];
1155   int i;
1156 
1157   Hash_prefix(x,3,r,Inputs_bytes);
1158   for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
1159 #else
1160   unsigned char x[Inputs_bytes+Hash_bytes];
1161   int i;
1162 
1163   for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
1164   for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
1165 #endif
1166   Hash_prefix(h,2,x,sizeof x);
1167 }
1168 
1169 /* ----- session-key hash */
1170 
1171 /* k = HashSession(b,y,z) */
1172 static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
1173 {
1174 #ifndef LPR
1175   unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
1176   int i;
1177 
1178   Hash_prefix(x,3,y,Inputs_bytes);
1179   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
1180 #else
1181   unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
1182   int i;
1183 
1184   for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
1185   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
1186 #endif
1187   Hash_prefix(k,b,x,sizeof x);
1188 }
1189 
1190 /* ----- Streamlined NTRU Prime and NTRU LPRime */
1191 
1192 /* pk,sk = KEM_KeyGen() */
1193 static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
1194 {
1195   int i;
1196 
1197   ZKeyGen(pk,sk); sk += SecretKeys_bytes;
1198   for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
1199   randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
1200   Hash_prefix(sk,4,pk,PublicKeys_bytes);
1201 }
1202 
1203 /* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1204 static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
1205 {
1206   Inputs_encode(r_enc,r);
1207   ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
1208   HashConfirm(c,r_enc,pk,cache);
1209 }
1210 
1211 /* c,k = Encap(pk) */
1212 static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
1213 {
1214   Inputs r;
1215   unsigned char r_enc[Inputs_bytes];
1216   unsigned char cache[Hash_bytes];
1217 
1218   Hash_prefix(cache,4,pk,PublicKeys_bytes);
1219   Inputs_random(r);
1220   Hide(c,r_enc,r,pk,cache);
1221   HashSession(k,1,r_enc,c);
1222 }
1223 
1224 /* 0 if matching ciphertext+confirm, else -1 */
1225 static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
1226 {
1227   uint16 differentbits = 0;
1228   int len = Ciphertexts_bytes+Confirm_bytes;
1229 
1230   while (len-- > 0) differentbits |= (*c++)^(*c2++);
1231   return (1&((differentbits-1)>>8))-1;
1232 }
1233 
1234 /* k = Decap(c,sk) */
1235 static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
1236 {
1237   const unsigned char *pk = sk + SecretKeys_bytes;
1238   const unsigned char *rho = pk + PublicKeys_bytes;
1239   const unsigned char *cache = rho + Inputs_bytes;
1240   Inputs r;
1241   unsigned char r_enc[Inputs_bytes];
1242   unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
1243   int mask;
1244   int i;
1245 
1246   ZDecrypt(r,c,sk);
1247   Hide(cnew,r_enc,r,pk,cache);
1248   mask = Ciphertexts_diff_mask(c,cnew);
1249   for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
1250   HashSession(k,1+mask,r_enc,c);
1251 }
1252 
1253 /* ----- crypto_kem API */
1254 
1255 
/* Public API: generate a keypair into pk and sk (layout per KEM_KeyGen).  Always returns 0. */
int
crypto_kem_sntrup761_keypair(unsigned char *pk, unsigned char *sk)
{
  KEM_KeyGen(pk, sk);
  return 0;
}
1261 
/* Public API: encapsulate to pk, writing ciphertext c and shared key k.  Always returns 0. */
int
crypto_kem_sntrup761_enc(unsigned char *c, unsigned char *k, const unsigned char *pk)
{
  Encap(c, k, pk);
  return 0;
}
1267 
/*
 * Public API: decapsulate c with sk, writing shared key k.  Always returns 0;
 * an invalid ciphertext yields a rho-derived key rather than an error.
 */
int
crypto_kem_sntrup761_dec(unsigned char *k, const unsigned char *c, const unsigned char *sk)
{
  Decap(k, c, sk);
  return 0;
}
1273 #endif /* USE_SNTRUP761X25519 */
1274