xref: /freebsd/crypto/openssl/crypto/ml_kem/ml_kem.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <openssl/byteorder.h>
11 #include <openssl/rand.h>
12 #include <openssl/proverr.h>
13 #include "crypto/ml_kem.h"
14 #include "internal/common.h"
15 #include "internal/constant_time.h"
16 #include "internal/sha3.h"
17 
18 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19 #include <valgrind/memcheck.h>
20 #endif
21 
/*
 * Compile-time sanity checks: the keygen seed must split exactly into a
 * shared-secret-sized half and a random-input-sized half, those two halves
 * must be the same length, and unsigned int must hold any uint32_t value.
 */
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
# error "ML-KEM keygen seed length != shared secret + random bytes length"
#elif ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
# error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#elif UINT_MAX < UINT32_MAX
# error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif
32 
/*
 * Handy function-like bit-extraction macros: bit0(b) is the lowest bit of
 * |b|, bitn(n, b) is bit |n| of |b|.  Both arguments are parenthesised so
 * that expressions such as bitn(k & 3, b) select the intended bit.
 */
#define bit0(b)     ((b) & 1)
#define bitn(n, b)  (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q, used to scale the inverse NTT output;
 * for n = 256 and q = 3329 that is q - 26 = 3303, since
 * 128 * 3303 = 127 * 3329 + 1.
 */
#define DEGREE          ML_KEM_DEGREE
#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME       12
#define BARRETT_SHIFT   (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
49 
/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise: it relies on the modular wrap-around of unsigned
 * arithmetic, 0u - 1 being all ones.
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled alternative uses the generic |constant_time_is_zero| helper.
 * Like the active definition it expands to a plain expression (no trailing
 * semicolon), so it can be switched on without breaking expression uses.
 */
#if 0
# define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
# define constish_time_non_zero(b) (0u - (b))
#endif
72 
/*
 * Size of the rejection-sampling buffer used by |sample_scalar|.
 *
 * Any multiple of 12 works.  The preferred value is the SHAKE128 block size,
 * (1600 - 256) / 8 = 168 bytes, so that squeezing lands directly in our
 * buffer without internal buffering or copying inside SHAKE128.  168 is
 * itself a multiple of 12.  It is used as the fallback whenever the block
 * size is unknown or not a multiple of 12.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
# define SCALAR_SAMPLING_BUFSIZE 168
#else
# define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87 
88 /*
89  * Structure of keys
90  */
91 typedef struct ossl_ml_kem_scalar_st {
92     /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93     uint16_t c[ML_KEM_DEGREE];
94 } scalar;
95 
96 /* Key material allocation layout */
97 #define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
98     struct name##_alloc { \
99         /* Public vector |t| */ \
100         scalar tbuf[(rank)]; \
101         /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102         scalar mbuf[(rank)*(rank)] \
103         /* optional private key data */ \
104         private_sz \
105     }
106 
107 /* Declare variant-specific public and private storage */
108 #define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
109     DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
110     DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
111         scalar sbuf[ML_KEM_##bits##_RANK]; \
112         uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
113 DECLARE_ML_KEM_VARIANT_KEYDATA(512);
114 DECLARE_ML_KEM_VARIANT_KEYDATA(768);
115 DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
116 #undef DECLARE_ML_KEM_VARIANT_KEYDATA
117 #undef DECLARE_ML_KEM_KEYDATA
118 
119 typedef __owur
120 int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
121                 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
122 static void scalar_encode(uint8_t *out, const scalar *s, int bits);
123 
/*
 * Wire-format sizes, in bytes, for parameter set |b| (512, 768 or 1024).
 *
 * Vectors are encoded losslessly with 12 bits per coefficient, i.e. 3 bytes
 * for every 2 coefficients.
 *
 * A public key is the encoded public vector |t| followed by the public seed
 * |rho|.
 *
 * Our serialised private key is the encoded private vector |s|, then the
 * public key, then the public key hash, and finally the implicit-rejection
 * secret |z|.  Since |s| is the same size as |t| and |z| the same size as
 * |rho|, that totals two public keys plus one hash.
 */
#define VECTOR_BYTES(b)     ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
#define PUBKEY_BYTES(b)     (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
#define PRVKEY_BYTES(b)     (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)

/*
 * A ciphertext (the input to decapsulation) is the vector "u" followed by
 * the scalar "v".  Their coefficients, numbers modulo the ML-KEM prime "q",
 * are lossily compressed to "du" and "dv" bits each, respectively.
 */
#define U_VECTOR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
#define V_SCALAR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DV)
#define CTEXT_BYTES(b)      (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
146 
#if !defined(OPENSSL_CONSTANT_TIME_VALIDATION)

/* Without constant-time validation these annotations compile to nothing. */
# define CONSTTIME_SECRET(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len)

#else

/*
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret, by telling
 * valgrind's memcheck that they are undefined.  Secret-ness then propagates
 * as the data flows into registers and other memory, and any branch
 * condition or memory index derived from it triggers a valgrind warning.
 */
# define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)

/*
 * CONSTTIME_DECLASSIFY(ptr, len) marks |len| bytes at |ptr| as public again
 * (defined to memcheck), exempting them from the constant-time rules.
 */
# define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#endif
170 
171 /*
172  * Indices of slots in the vinfo tables below
173  */
174 #define ML_KEM_512_VINFO    0
175 #define ML_KEM_768_VINFO    1
176 #define ML_KEM_1024_VINFO   2
177 
178 /*
179  * Per-variant fixed parameters
180  */
181 static const ML_KEM_VINFO vinfo_map[3] = {
182     {
183         "ML-KEM-512",
184         PRVKEY_BYTES(512),
185         sizeof(struct prvkey_512_alloc),
186         PUBKEY_BYTES(512),
187         sizeof(struct pubkey_512_alloc),
188         CTEXT_BYTES(512),
189         VECTOR_BYTES(512),
190         U_VECTOR_BYTES(512),
191         EVP_PKEY_ML_KEM_512,
192         ML_KEM_512_BITS,
193         ML_KEM_512_RANK,
194         ML_KEM_512_DU,
195         ML_KEM_512_DV,
196         ML_KEM_512_SECBITS
197     },
198     {
199         "ML-KEM-768",
200         PRVKEY_BYTES(768),
201         sizeof(struct prvkey_768_alloc),
202         PUBKEY_BYTES(768),
203         sizeof(struct pubkey_768_alloc),
204         CTEXT_BYTES(768),
205         VECTOR_BYTES(768),
206         U_VECTOR_BYTES(768),
207         EVP_PKEY_ML_KEM_768,
208         ML_KEM_768_BITS,
209         ML_KEM_768_RANK,
210         ML_KEM_768_DU,
211         ML_KEM_768_DV,
212         ML_KEM_768_SECBITS
213     },
214     {
215         "ML-KEM-1024",
216         PRVKEY_BYTES(1024),
217         sizeof(struct prvkey_1024_alloc),
218         PUBKEY_BYTES(1024),
219         sizeof(struct pubkey_1024_alloc),
220         CTEXT_BYTES(1024),
221         VECTOR_BYTES(1024),
222         U_VECTOR_BYTES(1024),
223         EVP_PKEY_ML_KEM_1024,
224         ML_KEM_1024_BITS,
225         ML_KEM_1024_RANK,
226         ML_KEM_1024_DU,
227         ML_KEM_1024_DV,
228         ML_KEM_1024_SECBITS
229     }
230 };
231 
232 /*
233  * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
234  * constant time via Barrett reduction, and a final call to reduce_once(),
235  * which reduces inputs that are at most 2*kPrime and is also constant-time.
236  */
237 static const int kPrime = ML_KEM_PRIME;
238 static const unsigned int kBarrettShift = BARRETT_SHIFT;
239 static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
240 static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
241 static const uint16_t kInverseDegree = INVERSE_DEGREE;
242 
/*
 * The tables below are powers of 17 modulo p = 3329, indexed via this Python
 * helper, which reverses the low 7 bits of |i|:
 *
 * p = 3329
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */

/*-
 * Forward NTT twiddle factors (first table of FIPS 203, Appendix A):
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 * |scalar_ntt| pre-increments its pointer before each read, so entry 0 is
 * never used.
 */
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,
    1062, 1919, 193,  797,  2786, 3260, 569,  1746,
    296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
    289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
    650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
    17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
    939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
    1063, 319,  2773, 757,  2099, 561,  2466, 2594,
    2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};

/*-
 * Inverse NTT twiddle factors, pow(17, -bitreverse(i), p), stored in the
 * order |scalar_inverse_ntt| consumes them: first the (unused) entry for
 * i = 0, then i = 64..127, 32..63, 16..31, 8..15, 4..7, 2..3 and finally 1.
 */
static const uint16_t kInverseNTTRoots[128] = {
    1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
    1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
    525,  735,  863,  2768, 1230, 2572, 556,  3010,
    2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
    1745, 2688, 3061, 992,  2596, 941,  892,  1021,
    2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
    1626, 279,  314,  1173, 2573, 3096, 48,   667,
    1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
    3312, 2419, 2102, 219,  855,  2681, 1848, 712,
    682,  927,  1795, 461,  1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
    3040, 2508, 1355, 450,  936,  447,  2794, 1235,
    1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
    3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687,  40,   749,  1600,
};

/*-
 * Moduli for pairwise multiplication in |scalar_mult| and |scalar_mult_add|
 * (second table of FIPS 203, Appendix A, normalised to positive values):
 * kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 */
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,
    1637, 1692, 723,  2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,
    1703, 1626, 1651, 1678, 2789, 540,  1789, 1540,
    1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,
    733,  2596, 2337, 992,  268,  3061, 641,  2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684,
    1063, 2266, 319,  3010, 2773, 556,  757,  2572,
    2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554,  886,  2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
323 
324 /*
325  * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
326  * output to |out|. If the |md| specifies a fixed-output function, like
327  * SHA3-256, then |outlen| must be the correct length for that function.
328  */
329 static __owur
single_keccak(uint8_t * out,size_t outlen,const uint8_t * in,size_t inlen,EVP_MD_CTX * mdctx)330 int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
331                   EVP_MD_CTX *mdctx)
332 {
333     unsigned int sz = (unsigned int) outlen;
334 
335     if (!EVP_DigestUpdate(mdctx, in, inlen))
336         return 0;
337     if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
338         return EVP_DigestFinalXOF(mdctx, out, outlen);
339     return EVP_DigestFinal_ex(mdctx, out, &sz)
340         && ossl_assert((size_t) sz == outlen);
341 }
342 
343 /*
344  * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
345  * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
346  */
347 static __owur
prf(uint8_t * out,size_t len,const uint8_t in[ML_KEM_RANDOM_BYTES+1],EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)348 int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
349         EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
350 {
351     return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
352         && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
353 }
354 
355 /*
356  * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
357  * length input, producing 32 bytes of output.
358  */
359 static __owur
hash_h(uint8_t out[ML_KEM_PKHASH_BYTES],const uint8_t * in,size_t len,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)360 int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
361            EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
362 {
363     return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
364         && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
365 }
366 
367 /* Incremental hash_h of expanded public key */
368 static int
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],EVP_MD_CTX * mdctx,ML_KEM_KEY * key)369 hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
370               EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
371 {
372     const ML_KEM_VINFO *vinfo = key->vinfo;
373     const scalar *t = key->t, *end = t + vinfo->rank;
374     unsigned int sz;
375 
376     if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
377         return 0;
378 
379     do {
380         uint8_t buf[3 * DEGREE / 2];
381 
382         scalar_encode(buf, t++, 12);
383         if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
384             return 0;
385     } while (t < end);
386 
387     if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
388         return 0;
389     return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
390         && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
391 }
392 
393 /*
394  * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
395  * length input, producing 64 bytes of output, in particular the seeds
396  * (d,z) for key generation.
397  */
398 static __owur
hash_g(uint8_t out[ML_KEM_SEED_BYTES],const uint8_t * in,size_t len,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)399 int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
400            EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
401 {
402     return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
403         && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
404 }
405 
406 /*
407  * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
408  * input to compute a 32-byte implicit rejection shared secret, of the same
409  * length as the expected shared secret.  (Computed even on success to avoid
410  * side-channel leaks).
411  */
412 static __owur
kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],const uint8_t z[ML_KEM_RANDOM_BYTES],const uint8_t * ctext,size_t len,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)413 int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
414         const uint8_t z[ML_KEM_RANDOM_BYTES],
415         const uint8_t *ctext, size_t len,
416         EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
417 {
418     return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
419         && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
420         && EVP_DigestUpdate(mdctx, ctext, len)
421         && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
422 }
423 
424 /*
425  * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
426  * are performed by the caller). Rejection-samples a Keccak stream to get
427  * uniformly distributed elements in the range [0,q). This is used for matrix
428  * expansion and only operates on public inputs.
429  */
430 static __owur
sample_scalar(scalar * out,EVP_MD_CTX * mdctx)431 int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
432 {
433     uint16_t *curr = out->c, *endout = curr + DEGREE;
434     uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
435     uint8_t *endin = buf + sizeof(buf);
436     uint16_t d;
437     uint8_t b1, b2, b3;
438 
439     do {
440         if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
441             return 0;
442         do {
443             b1 = *in++;
444             b2 = *in++;
445             b3 = *in++;
446 
447             if (curr >= endout)
448                 break;
449             if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
450                 *curr++ = d;
451             if (curr >= endout)
452                 break;
453             if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
454                 *curr++ = d;
455         } while (in < endin);
456     } while (curr < endout);
457     return 1;
458 }
459 
460 /*-
461  * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
462  *
463  * Subtract |q| if the input is larger, without exposing a side-channel,
464  * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
465  * discussion on why the value barrier is by default omitted.
466  */
reduce_once(uint16_t x)467 static __owur uint16_t reduce_once(uint16_t x)
468 {
469     const uint16_t subtracted = x - kPrime;
470     uint16_t mask = constish_time_non_zero(subtracted >> 15);
471 
472     return (mask & x) | (~mask & subtracted);
473 }
474 
475 /*
476  * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
477  * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
478  * two already reduced u_int16 values, in fact it is sufficient for each
479  * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
480  */
reduce(uint32_t x)481 static __owur uint16_t reduce(uint32_t x)
482 {
483     uint64_t product = (uint64_t)x * kBarrettMultiplier;
484     uint32_t quotient = (uint32_t)(product >> kBarrettShift);
485     uint32_t remainder = x - quotient * kPrime;
486 
487     return reduce_once(remainder);
488 }
489 
490 /* Multiply a scalar by a constant. */
scalar_mult_const(scalar * s,uint16_t a)491 static void scalar_mult_const(scalar *s, uint16_t a)
492 {
493     uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
494 
495     do {
496         tmp = reduce(*curr * a);
497         *curr++ = tmp;
498     } while (curr < end);
499 }
500 
501 /*-
502  * FIPS 203, Section 4.3, Algoritm 9: "NTT".
503  * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
504  * kPrime 3329 does not have a 512th root of unity, so this transform leaves
505  * off the last iteration of the usual FFT code, with the 128 relevant roots of
506  * unity being stored in NTTRoots.  This means the output should be seen as 128
507  * elements in GF(3329^2), with the coefficients of the elements being
508  * consecutive entries in |s->c|.
509  */
scalar_ntt(scalar * s)510 static void scalar_ntt(scalar *s)
511 {
512     const uint16_t *roots = kNTTRoots;
513     uint16_t *end = s->c + DEGREE;
514     int offset = DEGREE / 2;
515 
516     do {
517         uint16_t *curr = s->c, *peer;
518 
519         do {
520             uint16_t *pause = curr + offset, even, odd;
521             uint32_t zeta = *++roots;
522 
523             peer = pause;
524             do {
525                 even = *curr;
526                 odd = reduce(*peer * zeta);
527                 *peer++ = reduce_once(even - odd + kPrime);
528                 *curr++ = reduce_once(odd + even);
529             } while (curr < pause);
530         } while ((curr = peer) < end);
531     } while ((offset >>= 1) >= 2);
532 }
533 
534 /*-
535  * FIPS 203, Section 4.3, Algoritm 10: "NTT^(-1)".
536  * In-place inverse number theoretic transform of a given scalar, with pairs of
537  * entries of s->v being interpreted as elements of GF(3329^2). Just as with
538  * the number theoretic transform, this leaves off the first step of the normal
539  * iFFT to account for the fact that 3329 does not have a 512th root of unity,
540  * using the precomputed 128 roots of unity stored in InverseNTTRoots.
541  */
scalar_inverse_ntt(scalar * s)542 static void scalar_inverse_ntt(scalar *s)
543 {
544     const uint16_t *roots = kInverseNTTRoots;
545     uint16_t *end = s->c + DEGREE;
546     int offset = 2;
547 
548     do {
549         uint16_t *curr = s->c, *peer;
550 
551         do {
552             uint16_t *pause = curr + offset, even, odd;
553             uint32_t zeta = *++roots;
554 
555             peer = pause;
556             do {
557                 even = *curr;
558                 odd = *peer;
559                 *peer++ = reduce(zeta * (even - odd + kPrime));
560                 *curr++ = reduce_once(odd + even);
561             } while (curr < pause);
562         } while ((curr = peer) < end);
563     } while ((offset <<= 1) < DEGREE);
564     scalar_mult_const(s, kInverseDegree);
565 }
566 
567 /* Addition updating the LHS scalar in-place. */
scalar_add(scalar * lhs,const scalar * rhs)568 static void scalar_add(scalar *lhs, const scalar *rhs)
569 {
570     int i;
571 
572     for (i = 0; i < DEGREE; i++)
573         lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
574 }
575 
576 /* Subtraction updating the LHS scalar in-place. */
scalar_sub(scalar * lhs,const scalar * rhs)577 static void scalar_sub(scalar *lhs, const scalar *rhs)
578 {
579     int i;
580 
581     for (i = 0; i < DEGREE; i++)
582         lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
583 }
584 
585 /*
586  * Multiplying two scalars in the number theoretically transformed state. Since
587  * 3329 does not have a 512th root of unity, this means we have to interpret
588  * the 2*ith and (2*i+1)th entries of the scalar as elements of
589  * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
590  *
591  * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
592  * ModRoots table. Note that our Barrett transform only allows us to multipy
593  * two reduced numbers together, so we need some intermediate reduction steps,
594  * even if an uint64_t could hold 3 multiplied numbers.
595  */
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)596 static void scalar_mult(scalar *out, const scalar *lhs,
597                         const scalar *rhs)
598 {
599     uint16_t *curr = out->c, *end = curr + DEGREE;
600     const uint16_t *lc = lhs->c, *rc = rhs->c;
601     const uint16_t *roots = kModRoots;
602 
603     do {
604         uint32_t l0 = *lc++, r0 = *rc++;
605         uint32_t l1 = *lc++, r1 = *rc++;
606         uint32_t zetapow = *roots++;
607 
608         *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
609         *curr++ = reduce(l0 * r1 + l1 * r0);
610     } while (curr < end);
611 }
612 
613 /* Above, but add the result to an existing scalar */
614 static ossl_inline
scalar_mult_add(scalar * out,const scalar * lhs,const scalar * rhs)615 void scalar_mult_add(scalar *out, const scalar *lhs,
616                      const scalar *rhs)
617 {
618     uint16_t *curr = out->c, *end = curr + DEGREE;
619     const uint16_t *lc = lhs->c, *rc = rhs->c;
620     const uint16_t *roots = kModRoots;
621 
622     do {
623         uint32_t l0 = *lc++, r0 = *rc++;
624         uint32_t l1 = *lc++, r1 = *rc++;
625         uint16_t *c0 = curr++;
626         uint16_t *c1 = curr++;
627         uint32_t zetapow = *roots++;
628 
629         *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
630         *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
631     } while (curr < end);
632 }
633 
634 /*-
635  * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
636  * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
637  */
scalar_encode(uint8_t * out,const scalar * s,int bits)638 static void scalar_encode(uint8_t *out, const scalar *s, int bits)
639 {
640     const uint16_t *curr = s->c, *end = curr + DEGREE;
641     uint64_t accum = 0, element;
642     int used = 0;
643 
644     do {
645         element = *curr++;
646         if (used + bits < 64) {
647             accum |= element << used;
648             used += bits;
649         } else if (used + bits > 64) {
650             out = OPENSSL_store_u64_le(out, accum | (element << used));
651             accum = element >> (64 - used);
652             used = (used + bits) - 64;
653         } else {
654             out = OPENSSL_store_u64_le(out, accum | (element << used));
655             accum = 0;
656             used = 0;
657         }
658     } while (curr < end);
659 }
660 
661 /*
662  * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
663  */
scalar_encode_1(uint8_t out[DEGREE/8],const scalar * s)664 static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
665 {
666     int i, j;
667     uint8_t out_byte;
668 
669     for (i = 0; i < DEGREE; i += 8) {
670         out_byte = 0;
671         for (j = 0; j < 8; j++)
672             out_byte |= bit0(s->c[i + j]) << j;
673         *out = out_byte;
674         out++;
675     }
676 }
677 
678 /*-
679  * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
680  * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
681  * separately.
682  *
683  * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
684  * |out|.
685  */
scalar_decode(scalar * out,const uint8_t * in,int bits)686 static void scalar_decode(scalar *out, const uint8_t *in, int bits)
687 {
688     uint16_t *curr = out->c, *end = curr + DEGREE;
689     uint64_t accum = 0;
690     int accum_bits = 0, todo = bits;
691     uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
692     uint16_t element = 0;
693 
694     do {
695         if (accum_bits == 0) {
696             in = OPENSSL_load_u64_le(&accum, in);
697             accum_bits = 64;
698         }
699         if (todo == bits && accum_bits >= bits) {
700             /* No partial "element", and all the required bits available */
701             *curr++ = ((uint16_t) accum) & mask;
702             accum >>= bits;
703             accum_bits -= bits;
704         } else if (accum_bits >= todo) {
705             /* A partial "element", and all the required bits available */
706             *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
707             accum >>= todo;
708             accum_bits -= todo;
709             element = 0;
710             todo = bits;
711             mask = bitmask;
712         } else {
713             /*
714              * Only some of the requisite bits accumulated, store |accum_bits|
715              * of these in |element|.  The accumulated bitcount becomes 0, but
716              * as soon as we have more bits we'll want to merge accum_bits
717              * fewer of them into the final |element|.
718              *
719              * Note that with a 64-bit accumulator and |bits| always 12 or
720              * less, if we're here, the previous iteration had all the
721              * requisite bits, and so there are no kept bits in |element|.
722              */
723             element = ((uint16_t) accum) & mask;
724             todo -= accum_bits;
725             mask = bitmask >> accum_bits;
726             accum_bits = 0;
727         }
728     } while (curr < end);
729 }
730 
731 static __owur
scalar_decode_12(scalar * out,const uint8_t in[3* DEGREE/2])732 int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
733 {
734     int i;
735     uint16_t *c = out->c;
736 
737     for (i = 0; i < DEGREE / 2; ++i) {
738         uint8_t b1 = *in++;
739         uint8_t b2 = *in++;
740         uint8_t b3 = *in++;
741         int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
742         int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
743 
744         if (outOfRange1 | outOfRange2)
745             return 0;
746     }
747     return 1;
748 }
749 
750 /*-
751  * scalar_decode_decompress_add is a combination of decoding and decompression
752  * both specialised for |bits| == 1, with the result added (and sum reduced) to
753  * the output scalar.
754  *
755  * NOTE: this function MUST not leak an input-data-depedennt timing signal.
756  * A timing leak in a related function in the reference Kyber implementation
757  * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
758  * for ML-KEM-512 in minutes, provided the attacker has access to precise
759  * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
760  * a risk when private keys are reused (perhaps KEMTLS servers).
761  */
762 static void
scalar_decode_decompress_add(scalar * out,const uint8_t in[DEGREE/8])763 scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
764 {
765     static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
766     uint16_t *curr = out->c, *end = curr + DEGREE;
767     uint16_t mask;
768     uint8_t b;
769 
770     /*
771      * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
772      * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
773      * discussion on why the value barrier is by default omitted.
774      */
775 #define decode_decompress_add_bit                               \
776         mask = constish_time_non_zero(bit0(b));                 \
777         *curr = reduce_once(*curr + (mask & half_q_plus_1));    \
778         curr++;                                                 \
779         b >>= 1
780 
781     /* Unrolled to process each byte in one iteration */
782     do {
783         b = *in++;
784         decode_decompress_add_bit;
785         decode_decompress_add_bit;
786         decode_decompress_add_bit;
787         decode_decompress_add_bit;
788 
789         decode_decompress_add_bit;
790         decode_decompress_add_bit;
791         decode_decompress_add_bit;
792         decode_decompress_add_bit;
793     } while (curr < end);
794 #undef decode_decompress_add_bit
795 }
796 
/*
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
 *
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
 * numbers close to each other together. The formula used is
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
 * Uses Barrett reduction to achieve constant time. Since we need both the
 * remainder (for rounding) and the quotient (as the result), we cannot use
 * |reduce| here, but need to do the Barrett reduction directly.
 *
 * |x|:    coefficient to compress, expected in [0, kPrime).
 * |bits|: output width d.
 * Returns a value in [0, 2^|bits|); the final mask implements "mod 2^d", so
 * inputs just below kPrime that round up to 2^d wrap to 0.
 */
static __owur uint16_t compress(uint16_t x, int bits)
{
    /* Numerator of x * 2^d / q */
    uint32_t shifted = (uint32_t)x << bits;
    /*
     * Barrett estimate of shifted / q; kBarrettMultiplier and kBarrettShift
     * are presumably the Barrett constants for kPrime defined earlier in this
     * file.  The estimate may be one too small, which is why |remainder| can
     * reach up to 2 * kPrime below.
     */
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    uint32_t remainder = shifted - quotient * kPrime;

    /*
     * Adjust the quotient to round correctly:
     *   0 <= remainder <= kHalfPrime round to 0
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
     * Each comparison yields an all-ones/all-zeros mask, so the adjustment is
     * applied without a secret-dependent branch.
     */
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
    return quotient & ((1 << bits) - 1);
}
824 
825 /*
826  * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
827 
828  * Decompresses |x| by using a close equi-distant representative. The formula
829  * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
830  * to implement this logic using only bit operations.
831  */
decompress(uint16_t x,int bits)832 static __owur uint16_t decompress(uint16_t x, int bits)
833 {
834     uint32_t product = (uint32_t)x * kPrime;
835     uint32_t power = 1 << bits;
836     /* This is |product| % power, since |power| is a power of 2. */
837     uint32_t remainder = product & (power - 1);
838     /* This is |product| / power, since |power| is a power of 2. */
839     uint32_t lower = product >> bits;
840 
841     /*
842      * The rounding logic works since the first half of numbers mod |power|
843      * have a 0 as first bit, and the second half has a 1 as first bit, since
844      * |power| is a power of 2. As a 12 bit number, |remainder| is always
845      * positive, so we will shift in 0s for a right shift.
846      */
847     return lower + (remainder >> (bits - 1));
848 }
849 
850 /*-
851  * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
852  * In-place lossy rounding of scalars to 2^d bits.
853  */
scalar_compress(scalar * s,int bits)854 static void scalar_compress(scalar *s, int bits)
855 {
856     int i;
857 
858     for (i = 0; i < DEGREE; i++)
859         s->c[i] = compress(s->c[i], bits);
860 }
861 
862 /*
863  * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
864  * In-place approximate recovery of scalars from 2^d bit compression.
865  */
scalar_decompress(scalar * s,int bits)866 static void scalar_decompress(scalar *s, int bits)
867 {
868     int i;
869 
870     for (i = 0; i < DEGREE; i++)
871         s->c[i] = decompress(s->c[i], bits);
872 }
873 
874 /* Addition updating the LHS vector in-place. */
vector_add(scalar * lhs,const scalar * rhs,int rank)875 static void vector_add(scalar *lhs, const scalar *rhs, int rank)
876 {
877     do {
878         scalar_add(lhs++, rhs++);
879     } while (--rank > 0);
880 }
881 
882 /*
883  * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
884  * (DEGREE) is divisible by 8, the individual vector entries will always fill a
885  * whole number of bytes, so we do not need to worry about bit packing here.
886  */
vector_encode(uint8_t * out,const scalar * a,int bits,int rank)887 static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
888 {
889     int stride = bits * DEGREE / 8;
890 
891     for (; rank-- > 0; out += stride)
892         scalar_encode(out, a++, bits);
893 }
894 
895 /*
896  * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
897  * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
898  * then decompressed and transformed via the NTT.
899  *
900  * Note: Used only in decrypt_cpa(), which returns void and so does not check
901  * the return value of this function.  Side-channels are fine when the input
902  * ciphertext to decap() is simply syntactically invalid.
903  */
904 static void
vector_decode_decompress_ntt(scalar * out,const uint8_t * in,int bits,int rank)905 vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
906 {
907     int stride = bits * DEGREE / 8;
908 
909     for (; rank-- > 0; in += stride, ++out) {
910         scalar_decode(out, in, bits);
911         scalar_decompress(out, bits);
912         scalar_ntt(out);
913     }
914 }
915 
916 /* vector_decode(), specialised to bits == 12. */
917 static __owur
vector_decode_12(scalar * out,const uint8_t in[3* DEGREE/2],int rank)918 int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
919 {
920     int stride = 3 * DEGREE / 2;
921 
922     for (; rank-- > 0; in += stride)
923         if (!scalar_decode_12(out++, in))
924             return 0;
925     return 1;
926 }
927 
928 /* In-place compression of each scalar component */
vector_compress(scalar * a,int bits,int rank)929 static void vector_compress(scalar *a, int bits, int rank)
930 {
931     do {
932         scalar_compress(a++, bits);
933     } while (--rank > 0);
934 }
935 
936 /* The output scalar must not overlap with the inputs */
inner_product(scalar * out,const scalar * lhs,const scalar * rhs,int rank)937 static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
938                           int rank)
939 {
940     scalar_mult(out, lhs, rhs);
941     while (--rank > 0)
942         scalar_mult_add(out, ++lhs, ++rhs);
943 }
944 
945 /*
946  * Here, the output vector must not overlap with the inputs, the result is
947  * directly subjected to inverse NTT.
948  */
949 static void
matrix_mult_intt(scalar * out,const scalar * m,const scalar * a,int rank)950 matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
951 {
952     const scalar *ar;
953     int i, j;
954 
955     for (i = rank; i-- > 0; ++out) {
956         scalar_mult(out, m++, ar = a);
957         for (j = rank - 1; j > 0; --j)
958             scalar_mult_add(out, m++, ++ar);
959         scalar_inverse_ntt(out);
960     }
961 }
962 
963 /* Here, the output vector must not overlap with the inputs */
964 static void
matrix_mult_transpose_add(scalar * out,const scalar * m,const scalar * a,int rank)965 matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
966 {
967     const scalar *mc = m, *mr, *ar;
968     int i, j;
969 
970     for (i = rank; i-- > 0; ++out) {
971         scalar_mult_add(out, mr = mc++, ar = a);
972         for (j = rank; --j > 0; )
973             scalar_mult_add(out, (mr += rank), ++ar);
974     }
975 }
976 
977 /*-
978  * Expands the matrix from a seed for key generation and for encaps-CPA.
979  * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
980  * by appending the (i,j) indices to the seed in the opposite order!
981  *
982  * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
983  */
984 static __owur
matrix_expand(EVP_MD_CTX * mdctx,ML_KEM_KEY * key)985 int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
986 {
987     scalar *out = key->m;
988     uint8_t input[ML_KEM_RANDOM_BYTES + 2];
989     int rank = key->vinfo->rank;
990     int i, j;
991 
992     memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
993     for (i = 0; i < rank; i++) {
994         for (j = 0; j < rank; j++) {
995             input[ML_KEM_RANDOM_BYTES] = i;
996             input[ML_KEM_RANDOM_BYTES + 1] = j;
997             if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
998                 || !EVP_DigestUpdate(mdctx, input, sizeof(input))
999                 || !sample_scalar(out++, mdctx))
1000                 return 0;
1001         }
1002     }
1003     return 1;
1004 }
1005 
1006 /*
1007  * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1008  * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1009  * and setting the coefficient to the count of the first bits minus the count of
1010  * the second bits, resulting in a centered binomial distribution. Since eta is
1011  * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1012  * and 0 with probability 3/8.
1013  */
1014 static __owur
cbd_2(scalar * out,uint8_t in[ML_KEM_RANDOM_BYTES+1],EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1015 int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1016           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1017 {
1018     uint16_t *curr = out->c, *end = curr + DEGREE;
1019     uint8_t randbuf[4 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1020     uint16_t value, mask;
1021     uint8_t b;
1022 
1023     if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1024         return 0;
1025 
1026     do {
1027         b = *r++;
1028 
1029         /*
1030          * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1031          * for a discussion on why the value barrier is by default omitted.
1032          * While this could have been written reduce_once(value + kPrime), this
1033          * is one extra addition and small range of |value| tempts some
1034          * versions of Clang to emit a branch.
1035          */
1036         value = bit0(b) + bitn(1, b);
1037         value -= bitn(2, b) + bitn(3, b);
1038         mask = constish_time_non_zero(value >> 15);
1039         *curr++ = value + (kPrime & mask);
1040 
1041         value = bitn(4, b) + bitn(5, b);
1042         value -= bitn(6, b) + bitn(7, b);
1043         mask = constish_time_non_zero(value >> 15);
1044         *curr++ = value + (kPrime & mask);
1045     } while (curr < end);
1046     return 1;
1047 }
1048 
1049 /*
1050  * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1051  * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1052  * and setting the coefficient to the count of the first bits minus the count of
1053  * the second bits, resulting in a centered binomial distribution.
1054  */
1055 static __owur
cbd_3(scalar * out,uint8_t in[ML_KEM_RANDOM_BYTES+1],EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1056 int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1057           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1058 {
1059     uint16_t *curr = out->c, *end = curr + DEGREE;
1060     uint8_t randbuf[6 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1061     uint8_t b1, b2, b3;
1062     uint16_t value, mask;
1063 
1064     if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1065         return 0;
1066 
1067     do {
1068         b1 = *r++;
1069         b2 = *r++;
1070         b3 = *r++;
1071 
1072         /*
1073          * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1074          * for a discussion on why the value barrier is by default omitted.
1075          * While this could have been written reduce_once(value + kPrime), this
1076          * is one extra addition and small range of |value| tempts some
1077          * versions of Clang to emit a branch.
1078          */
1079         value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1080         value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
1081         mask = constish_time_non_zero(value >> 15);
1082         *curr++ = value + (kPrime & mask);
1083 
1084         value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1085         value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1086         mask = constish_time_non_zero(value >> 15);
1087         *curr++ = value + (kPrime & mask);
1088 
1089         value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1090         value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1091         mask = constish_time_non_zero(value >> 15);
1092         *curr++ = value + (kPrime & mask);
1093 
1094         value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1095         value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1096         mask = constish_time_non_zero(value >> 15);
1097         *curr++ = value + (kPrime & mask);
1098     } while (curr < end);
1099     return 1;
1100 }
1101 
1102 /*
1103  * Generates a secret vector by using |cbd| with the given seed to generate
1104  * scalar elements and incrementing |counter| for each slot of the vector.
1105  */
1106 static __owur
gencbd_vector(scalar * out,CBD_FUNC cbd,uint8_t * counter,const uint8_t seed[ML_KEM_RANDOM_BYTES],int rank,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1107 int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1108                   const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1109                   EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1110 {
1111     uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1112 
1113     memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1114     do {
1115         input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1116         if (!cbd(out++, input, mdctx, key))
1117             return 0;
1118     } while (--rank > 0);
1119     return 1;
1120 }
1121 
1122 /*
1123  * As above plus NTT transform.
1124  */
1125 static __owur
gencbd_vector_ntt(scalar * out,CBD_FUNC cbd,uint8_t * counter,const uint8_t seed[ML_KEM_RANDOM_BYTES],int rank,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1126 int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1127                       const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1128                       EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1129 {
1130     uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1131 
1132     memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1133     do {
1134         input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1135         if (!cbd(out, input, mdctx, key))
1136             return 0;
1137         scalar_ntt(out++);
1138     } while (--rank > 0);
1139     return 1;
1140 }
1141 
/*
 * Picks the CBD sampler for the |eta1| noise parameter: ML-KEM-512 uses
 * eta1 = 3 (cbd_3), every other variant uses eta1 = 2 (cbd_2).  All eta2
 * values are 2.
 */
#define CBD1(evp_type)  ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1144 
1145 /*
1146  * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1147  *
1148  * Encrypts a message with given randomness to the ciphertext in |out|. Without
1149  * applying the Fujisaki-Okamoto transform this would not result in a CCA
1150  * secure scheme, since lattice schemes are vulnerable to decryption failure
1151  * oracles.
1152  *
1153  * The steps are re-ordered to make more efficient/localised use of storage.
1154  *
1155  * Note also that the input public key is assumed to hold a precomputed matrix
1156  * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1157  * coefficient) key->t vector).
1158  *
1159  * Caller passes storage in |tmp| for for two temporary vectors.
1160  */
1161 static __owur
encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],const uint8_t message[DEGREE/8],const uint8_t r[ML_KEM_RANDOM_BYTES],scalar * tmp,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1162 int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1163                 const uint8_t message[DEGREE / 8],
1164                 const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1165                 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1166 {
1167     const ML_KEM_VINFO *vinfo = key->vinfo;
1168     CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1169     int rank = vinfo->rank;
1170     /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1171     scalar *y = &tmp[0], *e1 = y, *e2 = y;
1172     /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1173     scalar *u = &tmp[rank];
1174     scalar v;
1175     uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1176     uint8_t counter = 0;
1177     int du = vinfo->du;
1178     int dv = vinfo->dv;
1179 
1180     /* FIPS 203 "y" vector */
1181     if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1182         return 0;
1183     /* FIPS 203 "v" scalar */
1184     inner_product(&v, key->t, y, rank);
1185     scalar_inverse_ntt(&v);
1186     /* FIPS 203 "u" vector */
1187     matrix_mult_intt(u, key->m, y, rank);
1188 
1189     /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1190     if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1191         return 0;
1192     vector_add(u, e1, rank);
1193     vector_compress(u, du, rank);
1194     vector_encode(out, u, du, rank);
1195 
1196     /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1197     memcpy(input, r, ML_KEM_RANDOM_BYTES);
1198     input[ML_KEM_RANDOM_BYTES] = counter;
1199     if (!cbd_2(e2, input, mdctx, key))
1200         return 0;
1201     scalar_add(&v, e2);
1202 
1203     /* Combine message with |v| */
1204     scalar_decode_decompress_add(&v, message);
1205     scalar_compress(&v, dv);
1206     scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1207     return 1;
1208 }
1209 
1210 /*
1211  * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1212  */
1213 static void
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],const uint8_t * ctext,scalar * u,const ML_KEM_KEY * key)1214 decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1215             const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1216 {
1217     const ML_KEM_VINFO *vinfo = key->vinfo;
1218     scalar v, mask;
1219     int rank = vinfo->rank;
1220     int du = vinfo->du;
1221     int dv = vinfo->dv;
1222 
1223     vector_decode_decompress_ntt(u, ctext, du, rank);
1224     scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1225     scalar_decompress(&v, dv);
1226     inner_product(&mask, key->s, u, rank);
1227     scalar_inverse_ntt(&mask);
1228     scalar_sub(&v, &mask);
1229     scalar_compress(&v, 1);
1230     scalar_encode_1(out, &v);
1231 }
1232 
1233 /*-
1234  * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1235  * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1236  *
1237  * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1238  * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1239  * wire-format of an ML-KEM public key.
1240  */
encode_pubkey(uint8_t * out,const ML_KEM_KEY * key)1241 static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1242 {
1243     const uint8_t *rho = key->rho;
1244     const ML_KEM_VINFO *vinfo = key->vinfo;
1245 
1246     vector_encode(out, key->t, 12, vinfo->rank);
1247     memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1248 }
1249 
1250 /*-
1251  * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1252  *
1253  * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1254  * This matches the input format of parse_prvkey() below.
1255  */
encode_prvkey(uint8_t * out,const ML_KEM_KEY * key)1256 static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1257 {
1258     const ML_KEM_VINFO *vinfo = key->vinfo;
1259 
1260     vector_encode(out, key->s, 12, vinfo->rank);
1261     out += vinfo->vector_bytes;
1262     encode_pubkey(out, key);
1263     out += vinfo->pubkey_bytes;
1264     memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1265     out += ML_KEM_PKHASH_BYTES;
1266     memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1267 }
1268 
1269 /*-
1270  * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1271  * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1272  *
1273  * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1274  * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1275  * wire-format of the ML-KEM public key.
1276  */
parse_pubkey(const uint8_t * in,EVP_MD_CTX * mdctx,ML_KEM_KEY * key)1277 static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1278 {
1279     const ML_KEM_VINFO *vinfo = key->vinfo;
1280 
1281     /* Decode and check |t| */
1282     if (!vector_decode_12(key->t, in, vinfo->rank)) {
1283         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1284                        "%s invalid public 't' vector",
1285                        vinfo->algorithm_name);
1286         return 0;
1287     }
1288     /* Save the matrix |m| recovery seed |rho| */
1289     memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1290     /*
1291      * Pre-compute the public key hash, needed for both encap and decap.
1292      * Also pre-compute the matrix expansion, stored with the public key.
1293      */
1294     if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1295         || !matrix_expand(mdctx, key)) {
1296         ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1297                        "internal error while parsing %s public key",
1298                        vinfo->algorithm_name);
1299         return 0;
1300     }
1301     return 1;
1302 }
1303 
1304 /*
1305  * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1306  *
1307  * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1308  * This matches the output format of encode_prvkey() above.
1309  */
parse_prvkey(const uint8_t * in,EVP_MD_CTX * mdctx,ML_KEM_KEY * key)1310 static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1311 {
1312     const ML_KEM_VINFO *vinfo = key->vinfo;
1313 
1314     /* Decode and check |s|. */
1315     if (!vector_decode_12(key->s, in, vinfo->rank)) {
1316         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1317                        "%s invalid private 's' vector",
1318                        vinfo->algorithm_name);
1319         return 0;
1320     }
1321     in += vinfo->vector_bytes;
1322 
1323     if (!parse_pubkey(in, mdctx, key))
1324         return 0;
1325     in += vinfo->pubkey_bytes;
1326 
1327     /* Check public key hash. */
1328     if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1329         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1330                        "%s public key hash mismatch",
1331                        vinfo->algorithm_name);
1332         return 0;
1333     }
1334     in += ML_KEM_PKHASH_BYTES;
1335 
1336     memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1337     return 1;
1338 }
1339 
1340 /*
1341  * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1342  *
1343  * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1344  * inlined.
1345  *
1346  * The caller MUST pass a pre-allocated digest context that is not shared with
1347  * any concurrent computation.
1348  *
1349  * This function optionally outputs the serialised wire-form |ek| public key
1350  * into the provided |pubenc| buffer, and generates the content of the |rho|,
1351  * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1352  * have preallocated space for these).
1353  *
1354  * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1355  * domain separation.  These are concatenated and hashed to produce a pair of
1356  * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1357  * used to generate the secret vector |s|.
1358  *
1359  * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1360  * (FO) transform "implicit-rejection" secret (the |z| component of the private
1361  * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1362  * constant time, with no side channel leaks, on all well-formed (valid length,
1363  * and correctly encoded) ciphertext inputs.
1364  */
1365 static __owur
genkey(const uint8_t seed[ML_KEM_SEED_BYTES],EVP_MD_CTX * mdctx,uint8_t * pubenc,ML_KEM_KEY * key)1366 int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1367            EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1368 {
1369     uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1370     const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1371     uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1372     const ML_KEM_VINFO *vinfo = key->vinfo;
1373     CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1374     int rank = vinfo->rank;
1375     uint8_t counter = 0;
1376     int ret = 0;
1377 
1378     /*
1379      * Use the "d" seed salted with the rank to derive the public and private
1380      * seeds rho and sigma.
1381      */
1382     memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1383     augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
1384     if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1385         goto end;
1386     memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1387     /* The |rho| matrix seed is public */
1388     CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1389 
1390     /* FIPS 203 |e| vector is initial value of key->t */
1391     if (!matrix_expand(mdctx, key)
1392         || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1393         || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1394         goto end;
1395 
1396     /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1397     matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1398     /* The |t| vector is public */
1399     CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1400 
1401     if (pubenc == NULL) {
1402         /* Incremental digest of public key without in-full serialisation. */
1403         if (!hash_h_pubkey(key->pkhash, mdctx, key))
1404             goto end;
1405     } else {
1406         encode_pubkey(pubenc, key);
1407         if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1408             goto end;
1409     }
1410 
1411     /* Save |z| portion of seed for "implicit rejection" on failure. */
1412     memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1413 
1414     /* Optionally save the |d| portion of the seed */
1415     key->d = key->z + ML_KEM_RANDOM_BYTES;
1416     if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1417         memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1418     } else {
1419         OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1420         key->d = NULL;
1421     }
1422 
1423     ret = 1;
1424  end:
1425     OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1426     OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1427     if (ret == 0) {
1428         ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1429                        "internal error while generating %s private key",
1430                        vinfo->algorithm_name);
1431     }
1432     return ret;
1433 }
1434 
1435 /*-
1436  * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1437  * This is the deterministic version with randomness supplied externally.
1438  *
1439  * The caller must pass space for two vectors in |tmp|.
1440  * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1441  * of the provided key.
1442  */
1443 static
encap(uint8_t * ctext,uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],const uint8_t entropy[ML_KEM_RANDOM_BYTES],scalar * tmp,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1444 int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1445           const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1446           scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1447 {
1448     uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1449     uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1450     uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1451     int ret;
1452 
1453     memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1454     memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1455     ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1456         && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1457     OPENSSL_cleanse((void *)input, sizeof(input));
1458 
1459     if (ret)
1460         memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1461     else
1462         ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1463                        "internal error while performing %s encapsulation",
1464                        key->vinfo->algorithm_name);
1465     return ret;
1466 }
1467 
1468 /*
1469  * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1470  *
1471  * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1472  * deterministic, the randomness for the FO transform is extracted during
1473  * private key generation.
1474  *
1475  * The caller must pass space for two vectors in |tmp|.
1476  * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1477  * of the key's ML-KEM variant.
1478  */
1479 static
decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],const uint8_t * ctext,uint8_t * tmp_ctext,scalar * tmp,EVP_MD_CTX * mdctx,const ML_KEM_KEY * key)1480 int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1481           const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1482           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1483 {
1484     uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1485     uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1486     uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1487     uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1488     const uint8_t *pkhash = key->pkhash;
1489     const ML_KEM_VINFO *vinfo = key->vinfo;
1490     int i;
1491     uint8_t mask;
1492 
1493     /*
1494      * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1495      * any further errors, returning success, and whatever we got for a shared
1496      * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1497      * so should not be subject to failure that makes its output predictable.
1498      *
1499      * We guard against "should never happen" catastrophic failure of the
1500      * "pure" function |hash_g| by overwriting the shared secret with the
1501      * content of the failure key and returning early, if nevertheless hash_g
1502      * fails.  This is not constant-time, but a failure of |hash_g| already
1503      * implies loss of side-channel resistance.
1504      *
1505      * The same action is taken, if also |encrypt_cpa| should catastrophically
1506      * fail, due to failure of the |PRF| underlying the CBD functions.
1507      */
1508     if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1509         ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1510                        "internal error while performing %s decapsulation",
1511                        vinfo->algorithm_name);
1512         return 0;
1513     }
1514     decrypt_cpa(decrypted, ctext, tmp, key);
1515     memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1516     if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1517         || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1518         memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1519         OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1520         return 1;
1521     }
1522     mask = constant_time_eq_int_8(0,
1523         CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1524     for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1525         secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1526     OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1527     OPENSSL_cleanse(Kr, sizeof(Kr));
1528     return 1;
1529 }
1530 
1531 /*
1532  * After allocating storage for public or private key data, update the key
1533  * component pointers to reference that storage.
1534  */
1535 static __owur
add_storage(scalar * p,int private,ML_KEM_KEY * key)1536 int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1537 {
1538     int rank = key->vinfo->rank;
1539 
1540     if (p == NULL)
1541         return 0;
1542 
1543     /*
1544      * We're adding key material, the seed buffer will now hold |rho| and
1545      * |pkhash|.
1546      */
1547     memset(key->seedbuf, 0, sizeof(key->seedbuf));
1548     key->rho = key->seedbuf;
1549     key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1550     key->d = key->z = NULL;
1551 
1552     /* A public key needs space for |t| and |m| */
1553     key->m = (key->t = p) + rank;
1554 
1555     /*
1556      * A private key also needs space for |s| and |z|.
1557      * The |z| buffer always includes additional space for |d|, but a key's |d|
1558      * pointer is left NULL when parsed from the NIST format, which omits that
1559      * information.  Only keys generated from a (d, z) seed pair will have a
1560      * non-NULL |d| pointer.
1561      */
1562     if (private)
1563         key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1564     return 1;
1565 }
1566 
1567 /*
1568  * After freeing the storage associated with a key that failed to be
1569  * constructed, reset the internal pointers back to NULL.
1570  */
1571 void
ossl_ml_kem_key_reset(ML_KEM_KEY * key)1572 ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1573 {
1574     if (key->t == NULL)
1575         return;
1576     /*-
1577      * Cleanse any sensitive data:
1578      * - The private vector |s| is immediately followed by the FO failure
1579      *   secret |z|, and seed |d|, we can cleanse all three in one call.
1580      *
1581      * - Otherwise, when key->d is set, cleanse the stashed seed.
1582      */
1583     if (ossl_ml_kem_have_prvkey(key))
1584         OPENSSL_cleanse(key->s,
1585                         key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1586     OPENSSL_free(key->t);
1587     key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1588 }
1589 
1590 /*
1591  * ----- API exported to the provider
1592  *
1593  * Parameters with an implicit fixed length in the internal static API of each
1594  * variant have an explicit checked length argument at this layer.
1595  */
1596 
1597 /* Retrieve the parameters of one of the ML-KEM variants */
ossl_ml_kem_get_vinfo(int evp_type)1598 const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1599 {
1600     switch (evp_type) {
1601     case EVP_PKEY_ML_KEM_512:
1602         return &vinfo_map[ML_KEM_512_VINFO];
1603     case EVP_PKEY_ML_KEM_768:
1604         return &vinfo_map[ML_KEM_768_VINFO];
1605     case EVP_PKEY_ML_KEM_1024:
1606         return &vinfo_map[ML_KEM_1024_VINFO];
1607     }
1608     return NULL;
1609 }
1610 
ossl_ml_kem_key_new(OSSL_LIB_CTX * libctx,const char * properties,int evp_type)1611 ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
1612                                 int evp_type)
1613 {
1614     const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
1615     ML_KEM_KEY *key;
1616 
1617     if (vinfo == NULL) {
1618         ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
1619                        "unsupported ML-KEM key type: %d", evp_type);
1620         return NULL;
1621     }
1622 
1623     if ((key = OPENSSL_malloc(sizeof(*key))) == NULL)
1624         return NULL;
1625 
1626     key->vinfo = vinfo;
1627     key->libctx = libctx;
1628     key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;
1629     key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
1630     key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
1631     key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
1632     key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);
1633     key->d = key->z = key->rho = key->pkhash = key->encoded_dk = NULL;
1634     key->s = key->m = key->t = NULL;
1635 
1636     if (key->shake128_md != NULL
1637         && key->shake256_md != NULL
1638         && key->sha3_256_md != NULL
1639         && key->sha3_512_md != NULL)
1640         return key;
1641 
1642     ossl_ml_kem_key_free(key);
1643     ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1644                    "missing SHA3 digest algorithms while creating %s key",
1645                    vinfo->algorithm_name);
1646     return NULL;
1647 }
1648 
ossl_ml_kem_key_dup(const ML_KEM_KEY * key,int selection)1649 ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1650 {
1651     int ok = 0;
1652     ML_KEM_KEY *ret;
1653 
1654     /*
1655      * Partially decoded keys, not yet imported or loaded, should never be
1656      * duplicated.
1657      */
1658     if (ossl_ml_kem_decoded_key(key))
1659         return NULL;
1660 
1661     if (key == NULL
1662         || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1663         return NULL;
1664     ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1665     ret->s = ret->m = ret->t = NULL;
1666 
1667     /* Clear selection bits we can't fulfill */
1668     if (!ossl_ml_kem_have_pubkey(key))
1669         selection = 0;
1670     else if (!ossl_ml_kem_have_prvkey(key))
1671         selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1672 
1673     switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1674     case 0:
1675         ok = 1;
1676         break;
1677     case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1678         ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1679         ret->rho = ret->seedbuf;
1680         ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1681         break;
1682     case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1683         ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1684         /* Duplicated keys retain |d|, if available */
1685         if (key->d != NULL)
1686             ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1687         break;
1688     }
1689 
1690     if (!ok) {
1691         OPENSSL_free(ret);
1692         return NULL;
1693     }
1694 
1695     EVP_MD_up_ref(ret->shake128_md);
1696     EVP_MD_up_ref(ret->shake256_md);
1697     EVP_MD_up_ref(ret->sha3_256_md);
1698     EVP_MD_up_ref(ret->sha3_512_md);
1699 
1700     return ret;
1701 }
1702 
/*
 * Free |key| and everything it owns.  The process is:
 * - drop the pre-fetched digest references;
 * - for a partially decoded key, cleanse the seed buffer and cleanse and free
 *   any pending private key encoding;
 * - release the key storage and the key itself.
 * A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        if (ossl_ml_kem_decoded_key(key)) {
            /* The seed buffer may hold a stashed (z, d) seed */
            OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
            if (ossl_ml_kem_have_dkenc(key)) {
                OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
                OPENSSL_free(key->encoded_dk);
            }
        }

        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
1723 
1724 /* Serialise the public component of an ML-KEM key */
ossl_ml_kem_encode_public_key(uint8_t * out,size_t len,const ML_KEM_KEY * key)1725 int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1726                                   const ML_KEM_KEY *key)
1727 {
1728     if (!ossl_ml_kem_have_pubkey(key)
1729         || len != key->vinfo->pubkey_bytes)
1730         return 0;
1731     encode_pubkey(out, key);
1732     return 1;
1733 }
1734 
1735 /* Serialise an ML-KEM private key */
ossl_ml_kem_encode_private_key(uint8_t * out,size_t len,const ML_KEM_KEY * key)1736 int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1737                                    const ML_KEM_KEY *key)
1738 {
1739     if (!ossl_ml_kem_have_prvkey(key)
1740         || len != key->vinfo->prvkey_bytes)
1741         return 0;
1742     encode_prvkey(out, key);
1743     return 1;
1744 }
1745 
ossl_ml_kem_encode_seed(uint8_t * out,size_t len,const ML_KEM_KEY * key)1746 int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1747                             const ML_KEM_KEY *key)
1748 {
1749     if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1750         return 0;
1751     /*
1752      * Both in the seed buffer, and in the allocated storage, the |d| component
1753      * of the seed is stored last, so we must copy each separately.
1754      */
1755     memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1756     out += ML_KEM_RANDOM_BYTES;
1757     memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1758     return 1;
1759 }
1760 
1761 /*
1762  * Stash the seed without (yet) performing a keygen, used during decoding, to
1763  * avoid an extra keygen if we're only going to export the key again to load
1764  * into another provider.
1765  */
ossl_ml_kem_set_seed(const uint8_t * seed,size_t seedlen,ML_KEM_KEY * key)1766 ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1767 {
1768     if (key == NULL
1769         || ossl_ml_kem_have_pubkey(key)
1770         || ossl_ml_kem_have_seed(key)
1771         || seedlen != ML_KEM_SEED_BYTES)
1772         return NULL;
1773     /*
1774      * With no public or private key material on hand, we can use the seed
1775      * buffer for |z| and |d|, in that order.
1776      */
1777     key->z = key->seedbuf;
1778     key->d = key->z + ML_KEM_RANDOM_BYTES;
1779     memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1780     seed += ML_KEM_RANDOM_BYTES;
1781     memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1782     return key;
1783 }
1784 
1785 /* Parse input as a public key */
ossl_ml_kem_parse_public_key(const uint8_t * in,size_t len,ML_KEM_KEY * key)1786 int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1787 {
1788     EVP_MD_CTX *mdctx = NULL;
1789     const ML_KEM_VINFO *vinfo;
1790     int ret = 0;
1791 
1792     /* Keys with key material are immutable */
1793     if (key == NULL
1794         || ossl_ml_kem_have_pubkey(key)
1795         || ossl_ml_kem_have_dkenc(key))
1796         return 0;
1797     vinfo = key->vinfo;
1798 
1799     if (len != vinfo->pubkey_bytes
1800         || (mdctx = EVP_MD_CTX_new()) == NULL)
1801         return 0;
1802 
1803     if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
1804         ret = parse_pubkey(in, mdctx, key);
1805 
1806     if (!ret)
1807         ossl_ml_kem_key_reset(key);
1808     EVP_MD_CTX_free(mdctx);
1809     return ret;
1810 }
1811 
1812 /* Parse input as a new private key */
ossl_ml_kem_parse_private_key(const uint8_t * in,size_t len,ML_KEM_KEY * key)1813 int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1814                                   ML_KEM_KEY *key)
1815 {
1816     EVP_MD_CTX *mdctx = NULL;
1817     const ML_KEM_VINFO *vinfo;
1818     int ret = 0;
1819 
1820     /* Keys with key material are immutable */
1821     if (key == NULL
1822         || ossl_ml_kem_have_pubkey(key)
1823         || ossl_ml_kem_have_dkenc(key))
1824         return 0;
1825     vinfo = key->vinfo;
1826 
1827     if (len != vinfo->prvkey_bytes
1828         || (mdctx = EVP_MD_CTX_new()) == NULL)
1829         return 0;
1830 
1831     if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1832         ret = parse_prvkey(in, mdctx, key);
1833 
1834     if (!ret)
1835         ossl_ml_kem_key_reset(key);
1836     EVP_MD_CTX_free(mdctx);
1837     return ret;
1838 }
1839 
1840 /*
1841  * Generate a new keypair, either from the saved seed (when non-null), or from
1842  * the RNG.
1843  */
ossl_ml_kem_genkey(uint8_t * pubenc,size_t publen,ML_KEM_KEY * key)1844 int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1845 {
1846     uint8_t seed[ML_KEM_SEED_BYTES];
1847     EVP_MD_CTX *mdctx = NULL;
1848     const ML_KEM_VINFO *vinfo;
1849     int ret = 0;
1850 
1851     if (key == NULL
1852         || ossl_ml_kem_have_pubkey(key)
1853         || ossl_ml_kem_have_dkenc(key))
1854         return 0;
1855     vinfo = key->vinfo;
1856 
1857     if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1858         return 0;
1859 
1860     if (ossl_ml_kem_have_seed(key)) {
1861         if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1862             return 0;
1863         key->d = key->z = NULL;
1864     } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1865                                   key->vinfo->secbits) <= 0) {
1866         return 0;
1867     }
1868 
1869     if ((mdctx = EVP_MD_CTX_new()) == NULL)
1870         return 0;
1871 
1872     /*
1873      * Data derived from (d, z) defaults secret, and to avoid side-channel
1874      * leaks should not influence control flow.
1875      */
1876     CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1877 
1878     if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1879         ret = genkey(seed, mdctx, pubenc, key);
1880     OPENSSL_cleanse(seed, sizeof(seed));
1881 
1882     /* Declassify secret inputs and derived outputs before returning control */
1883     CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1884 
1885     EVP_MD_CTX_free(mdctx);
1886     if (!ret) {
1887         ossl_ml_kem_key_reset(key);
1888         return 0;
1889     }
1890 
1891     /* The public components are already declassified */
1892     CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1893     CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1894     return 1;
1895 }
1896 
1897 /*
1898  * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1899  * This is the deterministic version with randomness supplied externally.
1900  */
ossl_ml_kem_encap_seed(uint8_t * ctext,size_t clen,uint8_t * shared_secret,size_t slen,const uint8_t * entropy,size_t elen,const ML_KEM_KEY * key)1901 int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1902                            uint8_t *shared_secret, size_t slen,
1903                            const uint8_t *entropy, size_t elen,
1904                            const ML_KEM_KEY *key)
1905 {
1906     const ML_KEM_VINFO *vinfo;
1907     EVP_MD_CTX *mdctx;
1908     int ret = 0;
1909 
1910     if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1911         return 0;
1912     vinfo = key->vinfo;
1913 
1914     if (ctext == NULL || clen != vinfo->ctext_bytes
1915         || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1916         || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1917         || (mdctx = EVP_MD_CTX_new()) == NULL)
1918         return 0;
1919     /*
1920      * Data derived from the encap entropy defaults secret, and to avoid
1921      * side-channel leaks should not influence control flow.
1922      */
1923     CONSTTIME_SECRET(entropy, elen);
1924 
1925     /*-
1926      * This avoids the need to handle allocation failures for two (max 2KB
1927      * each) vectors, that are never retained on return from this function.
1928      * We stack-allocate these.
1929      */
1930 #   define case_encap_seed(bits)                                            \
1931     case EVP_PKEY_ML_KEM_##bits:                                            \
1932         {                                                                   \
1933             scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
1934                                                                             \
1935             ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key);    \
1936             OPENSSL_cleanse((void *)tmp, sizeof(tmp));                      \
1937             break;                                                          \
1938         }
1939     switch (vinfo->evp_type) {
1940     case_encap_seed(512);
1941     case_encap_seed(768);
1942     case_encap_seed(1024);
1943     }
1944 #   undef case_encap_seed
1945 
1946     /* Declassify secret inputs and derived outputs before returning control */
1947     CONSTTIME_DECLASSIFY(entropy, elen);
1948     CONSTTIME_DECLASSIFY(ctext, clen);
1949     CONSTTIME_DECLASSIFY(shared_secret, slen);
1950 
1951     EVP_MD_CTX_free(mdctx);
1952     return ret;
1953 }
1954 
ossl_ml_kem_encap_rand(uint8_t * ctext,size_t clen,uint8_t * shared_secret,size_t slen,const ML_KEM_KEY * key)1955 int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1956                            uint8_t *shared_secret, size_t slen,
1957                            const ML_KEM_KEY *key)
1958 {
1959     uint8_t r[ML_KEM_RANDOM_BYTES];
1960 
1961     if (key == NULL)
1962         return 0;
1963 
1964     if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1965                       key->vinfo->secbits) < 1)
1966         return 0;
1967 
1968     return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
1969                                   r, sizeof(r), key);
1970 }
1971 
ossl_ml_kem_decap(uint8_t * shared_secret,size_t slen,const uint8_t * ctext,size_t clen,const ML_KEM_KEY * key)1972 int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
1973                       const uint8_t *ctext, size_t clen,
1974                       const ML_KEM_KEY *key)
1975 {
1976     const ML_KEM_VINFO *vinfo;
1977     EVP_MD_CTX *mdctx;
1978     int ret = 0;
1979 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1980     int classify_bytes;
1981 #endif
1982 
1983     /* Need a private key here */
1984     if (!ossl_ml_kem_have_prvkey(key))
1985         return 0;
1986     vinfo = key->vinfo;
1987 
1988     if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1989         || ctext == NULL || clen != vinfo->ctext_bytes
1990         || (mdctx = EVP_MD_CTX_new()) == NULL) {
1991         (void)RAND_bytes_ex(key->libctx, shared_secret,
1992                             ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
1993         return 0;
1994     }
1995 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1996     /*
1997      * Data derived from |s| and |z| defaults secret, and to avoid side-channel
1998      * leaks should not influence control flow.
1999      */
2000     classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2001     CONSTTIME_SECRET(key->s, classify_bytes);
2002 #endif
2003 
2004     /*-
2005      * This avoids the need to handle allocation failures for two (max 2KB
2006      * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2007      * retained on return from this function.
2008      * We stack-allocate these.
2009      */
2010 #   define case_decap(bits)                                             \
2011     case EVP_PKEY_ML_KEM_##bits:                                        \
2012         {                                                               \
2013             uint8_t cbuf[CTEXT_BYTES(bits)];                            \
2014             scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
2015                                                                         \
2016             ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
2017             OPENSSL_cleanse((void *)tmp, sizeof(tmp));                  \
2018             break;                                                      \
2019         }
2020     switch (vinfo->evp_type) {
2021     case_decap(512);
2022     case_decap(768);
2023     case_decap(1024);
2024     }
2025 
2026     /* Declassify secret inputs and derived outputs before returning control */
2027     CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2028     CONSTTIME_DECLASSIFY(shared_secret, slen);
2029     EVP_MD_CTX_free(mdctx);
2030 
2031     return ret;
2032 #   undef case_decap
2033 }
2034 
ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY * key1,const ML_KEM_KEY * key2)2035 int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2036 {
2037     /*
2038      * This handles any unexpected differences in the ML-KEM variant rank,
2039      * giving different key component structures, barring SHA3-256 hash
2040      * collisions, the keys are the same size.
2041      */
2042     if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2043         return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2044 
2045     /*
2046      * No match if just one of the public keys is not available, otherwise both
2047      * are unavailable, and for now such keys are considered equal.
2048      */
2049     return (ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2));
2050 }
2051