xref: /freebsd/crypto/openssl/crypto/ec/curve448/scalar.c (revision f25b8c9fb4f58cf61adb47d7570abe7caa6d385d)
1 /*
2  * Copyright 2017-2021 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright 2015-2016 Cryptography Research, Inc.
4  *
5  * Licensed under the Apache License 2.0 (the "License").  You may not use
6  * this file except in compliance with the License.  You can obtain a copy
7  * in the file LICENSE in the source distribution or at
8  * https://www.openssl.org/source/license.html
9  *
10  * Originally written by Mike Hamburg
11  */
12 #include <openssl/crypto.h>
13 
14 #include "word.h"
15 #include "point_448.h"
16 
17 static const c448_word_t MONTGOMERY_FACTOR = (c448_word_t)0x3bd440fae918bc5ULL;
18 static const curve448_scalar_t sc_p = {
19     { { SC_LIMB(0x2378c292ab5844f3ULL), SC_LIMB(0x216cc2728dc58f55ULL),
20         SC_LIMB(0xc44edb49aed63690ULL), SC_LIMB(0xffffffff7cca23e9ULL),
21         SC_LIMB(0xffffffffffffffffULL), SC_LIMB(0xffffffffffffffffULL),
22         SC_LIMB(0x3fffffffffffffffULL) } }
23 },
24                                sc_r2 = { { {
25 
26                                    SC_LIMB(0xe3539257049b9b60ULL), SC_LIMB(0x7af32c4bc1b195d9ULL), SC_LIMB(0x0d66de2388ea1859ULL), SC_LIMB(0xae17cf725ee4d838ULL), SC_LIMB(0x1a9cc14ba3c47c44ULL), SC_LIMB(0x2052bcb7e4d070afULL), SC_LIMB(0x3402a939f823b729ULL) } } };
27 
28 #define WBITS C448_WORD_BITS /* NB this may be different from ARCH_WORD_BITS */
29 
30 const curve448_scalar_t ossl_curve448_scalar_one = { { { 1 } } };
31 const curve448_scalar_t ossl_curve448_scalar_zero = { { { 0 } } };
32 
33 /*
34  * {extra,accum} - sub +? p
35  * Must have extra <= 1
36  */
sc_subx(curve448_scalar_t out,const c448_word_t accum[C448_SCALAR_LIMBS],const curve448_scalar_t sub,const curve448_scalar_t p,c448_word_t extra)37 static void sc_subx(curve448_scalar_t out,
38     const c448_word_t accum[C448_SCALAR_LIMBS],
39     const curve448_scalar_t sub,
40     const curve448_scalar_t p, c448_word_t extra)
41 {
42     c448_dsword_t chain = 0;
43     unsigned int i;
44     c448_word_t borrow;
45 
46     for (i = 0; i < C448_SCALAR_LIMBS; i++) {
47         chain = (chain + accum[i]) - sub->limb[i];
48         out->limb[i] = (c448_word_t)chain;
49         chain >>= WBITS;
50     }
51     borrow = (c448_word_t)chain + extra; /* = 0 or -1 */
52 
53     chain = 0;
54     for (i = 0; i < C448_SCALAR_LIMBS; i++) {
55         chain = (chain + out->limb[i]) + (p->limb[i] & borrow);
56         out->limb[i] = (c448_word_t)chain;
57         chain >>= WBITS;
58     }
59 }
60 
/*
 * Word-by-word Montgomery multiplication modulo sc_p.
 *
 * For every limb of |a| the accumulator gains a->limb[i] * b, then a
 * multiple of sc_p chosen (via MONTGOMERY_FACTOR) so that the lowest limb
 * becomes zero, after which the accumulator is shifted down by one limb.
 * After C448_SCALAR_LIMBS rounds the product has therefore been divided by
 * 2^(WBITS * C448_SCALAR_LIMBS).  A final sc_subx() subtracts sc_p and
 * conditionally adds it back.
 *
 * |out| is only written by that final step, so it may alias |a| or |b|.
 */
static void sc_montmul(curve448_scalar_t out, const curve448_scalar_t a,
                       const curve448_scalar_t b)
{
    c448_word_t acc[C448_SCALAR_LIMBS + 1] = { 0 };
    c448_word_t top_carry = 0;
    unsigned int row, col;

    for (row = 0; row < C448_SCALAR_LIMBS; row++) {
        const c448_word_t *factor = b->limb;
        c448_word_t scale = a->limb[row];
        c448_dword_t carry = 0;

        /* acc += a->limb[row] * b */
        for (col = 0; col < C448_SCALAR_LIMBS; col++) {
            carry += ((c448_dword_t)scale) * factor[col] + acc[col];
            acc[col] = (c448_word_t)carry;
            carry >>= WBITS;
        }
        acc[C448_SCALAR_LIMBS] = (c448_word_t)carry;

        /*
         * acc = (acc + scale * sc_p) >> WBITS.  The lowest limb of the sum
         * is never stored; only its carry feeds the next limb.
         */
        scale = acc[0] * MONTGOMERY_FACTOR;
        factor = sc_p->limb;
        carry = ((c448_dword_t)scale * factor[0] + acc[0]) >> WBITS;
        for (col = 1; col < C448_SCALAR_LIMBS; col++) {
            carry += (c448_dword_t)scale * factor[col] + acc[col];
            acc[col - 1] = (c448_word_t)carry;
            carry >>= WBITS;
        }

        /* Fold in the extra top limb and the carry left by the last row. */
        carry += acc[C448_SCALAR_LIMBS];
        carry += top_carry;
        acc[C448_SCALAR_LIMBS - 1] = (c448_word_t)carry;
        top_carry = (c448_word_t)(carry >> WBITS);
    }

    sc_subx(out, acc, sc_p, sc_p, top_carry);
}
97 
/*
 * Sets |out| to the product of |a| and |b| modulo sc_p.
 *
 * Two Montgomery multiplications are needed: the first leaves the product
 * divided by the Montgomery radix, and the second, against sc_r2, cancels
 * that division again.  |out| may alias either input.
 */
void ossl_curve448_scalar_mul(curve448_scalar_t out, const curve448_scalar_t a,
                              const curve448_scalar_t b)
{
    /* Raw Montgomery product of the two operands. */
    sc_montmul(out, a, b);

    /* Correct the scaling introduced by the step above. */
    sc_montmul(out, out, sc_r2);
}
104 
/*
 * Sets |out| to |a| - |b|, adding sc_p back once if the limb-wise
 * subtraction borrows.  No extra top word is supplied to sc_subx().
 * NOTE(review): a single add-back presumably relies on both inputs already
 * being below sc_p -- confirm against callers.
 */
void ossl_curve448_scalar_sub(curve448_scalar_t out, const curve448_scalar_t a,
                              const curve448_scalar_t b)
{
    const c448_word_t no_extra = 0;

    sc_subx(out, a->limb, b, sc_p, no_extra);
}
110 
/*
 * Sets |out| to |a| + |b| with one conditional reduction by sc_p.
 *
 * The limbs are summed with a carry chain; the carry out of the top limb is
 * handed to sc_subx() as the extra word, which then subtracts sc_p and adds
 * it back if that went negative.  |out| may alias |a| or |b|.
 */
void ossl_curve448_scalar_add(curve448_scalar_t out, const curve448_scalar_t a,
                              const curve448_scalar_t b)
{
    c448_dword_t carry = 0;
    unsigned int limb;

    for (limb = 0; limb < C448_SCALAR_LIMBS; limb++) {
        carry += a->limb[limb];
        carry += b->limb[limb];
        out->limb[limb] = (c448_word_t)carry;
        carry >>= WBITS;
    }

    sc_subx(out, out->limb, sc_p, sc_p, (c448_word_t)carry);
}
124 
scalar_decode_short(curve448_scalar_t s,const unsigned char * ser,size_t nbytes)125 static ossl_inline void scalar_decode_short(curve448_scalar_t s,
126     const unsigned char *ser,
127     size_t nbytes)
128 {
129     size_t i, j, k = 0;
130 
131     for (i = 0; i < C448_SCALAR_LIMBS; i++) {
132         c448_word_t out = 0;
133 
134         for (j = 0; j < sizeof(c448_word_t) && k < nbytes; j++, k++)
135             out |= ((c448_word_t)ser[k]) << (8 * j);
136         s->limb[i] = out;
137     }
138 }
139 
140 c448_error_t
ossl_curve448_scalar_decode(curve448_scalar_t s,const unsigned char ser[C448_SCALAR_BYTES])141 ossl_curve448_scalar_decode(curve448_scalar_t s,
142     const unsigned char ser[C448_SCALAR_BYTES])
143 {
144     unsigned int i;
145     c448_dsword_t accum = 0;
146 
147     scalar_decode_short(s, ser, C448_SCALAR_BYTES);
148     for (i = 0; i < C448_SCALAR_LIMBS; i++)
149         accum = (accum + s->limb[i] - sc_p->limb[i]) >> WBITS;
150     /* Here accum == 0 or -1 */
151 
152     ossl_curve448_scalar_mul(s, s, ossl_curve448_scalar_one); /* ham-handed reduce */
153 
154     return c448_succeed_if(~word_is_zero((uint32_t)accum));
155 }
156 
/*
 * Erases the contents of |scalar| with OPENSSL_cleanse() so that the value
 * does not linger in memory after use.
 */
void ossl_curve448_scalar_destroy(curve448_scalar_t scalar)
{
    const size_t scalar_size = sizeof(curve448_scalar_t);

    OPENSSL_cleanse(scalar, scalar_size);
}
161 
/*
 * Decodes a little-endian byte string of arbitrary length into |s|, reduced
 * modulo sc_p.
 *
 * The input is processed in C448_SCALAR_BYTES-sized chunks starting from the
 * most significant one (the highest offset).  For each further chunk the
 * running value is multiplied up via sc_montmul() with sc_r2 and the decoded
 * chunk is added.  An empty input yields zero.
 *
 * Temporaries are wiped with ossl_curve448_scalar_destroy() before returning.
 */
void ossl_curve448_scalar_decode_long(curve448_scalar_t s,
                                      const unsigned char *ser, size_t ser_len)
{
    size_t i;
    curve448_scalar_t t1, t2;

    if (ser_len == 0) {
        curve448_scalar_copy(s, ossl_curve448_scalar_zero);
        return;
    }

    /*
     * i is the offset of the top chunk: the last multiple of
     * C448_SCALAR_BYTES strictly below ser_len, so the top chunk holds
     * between 1 and C448_SCALAR_BYTES bytes.
     */
    i = ser_len - (ser_len % C448_SCALAR_BYTES);
    if (i == ser_len)
        i -= C448_SCALAR_BYTES;

    scalar_decode_short(t1, &ser[i], ser_len - i);

    /*
     * Exactly one full chunk: the loop below would not run, so the value
     * must be reduced here.  The comparison is against the serialized
     * length C448_SCALAR_BYTES (the unit used for the chunking above), not
     * the in-memory size of the scalar type, so that i == 0 is guaranteed.
     */
    if (ser_len == C448_SCALAR_BYTES) {
        assert(i == 0);
        /* ham-handed reduce */
        ossl_curve448_scalar_mul(s, t1, ossl_curve448_scalar_one);
        ossl_curve448_scalar_destroy(t1);
        return;
    }

    while (i) {
        i -= C448_SCALAR_BYTES;
        /* Shift the running value up by one chunk, modulo sc_p. */
        sc_montmul(t1, t1, sc_r2);
        /* The range-check result is irrelevant: t2 is reduced either way. */
        (void)ossl_curve448_scalar_decode(t2, ser + i);
        ossl_curve448_scalar_add(t1, t1, t2);
    }

    curve448_scalar_copy(s, t1);
    ossl_curve448_scalar_destroy(t1);
    ossl_curve448_scalar_destroy(t2);
}
198 
ossl_curve448_scalar_encode(unsigned char ser[C448_SCALAR_BYTES],const curve448_scalar_t s)199 void ossl_curve448_scalar_encode(unsigned char ser[C448_SCALAR_BYTES],
200     const curve448_scalar_t s)
201 {
202     unsigned int i, j, k = 0;
203 
204     for (i = 0; i < C448_SCALAR_LIMBS; i++) {
205         for (j = 0; j < sizeof(c448_word_t); j++, k++)
206             ser[k] = s->limb[i] >> (8 * j);
207     }
208 }
209 
/*
 * Sets |out| to |a| / 2 modulo sc_p.
 *
 * If |a| is odd, sc_p is added first (selected by a mask, not a branch) so
 * that the sum is even; the sum, including its carry out of the top limb,
 * is then shifted right by one bit.  |out| may alias |a|.
 */
void ossl_curve448_scalar_halve(curve448_scalar_t out, const curve448_scalar_t a)
{
    /* All ones when the low bit of a is set, zero otherwise. */
    const c448_word_t odd_mask = 0 - (a->limb[0] & 1);
    const unsigned int last = C448_SCALAR_LIMBS - 1;
    c448_dword_t carry = 0;
    unsigned int k;

    /* out = a + (sc_p & odd_mask); carry keeps the bit above the top limb. */
    for (k = 0; k < C448_SCALAR_LIMBS; k++) {
        carry += a->limb[k];
        carry += sc_p->limb[k] & odd_mask;
        out->limb[k] = (c448_word_t)carry;
        carry >>= WBITS;
    }

    /* Shift right by one bit, pulling each limb's new top bit from above. */
    for (k = 0; k < last; k++)
        out->limb[k] = out->limb[k] >> 1 | out->limb[k + 1] << (WBITS - 1);
    out->limb[last] = out->limb[last] >> 1
                      | (c448_word_t)(carry << (WBITS - 1));
}
225