1 /*
2 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #include <openssl/byteorder.h>
11 #include <openssl/rand.h>
12 #include <openssl/proverr.h>
13 #include "crypto/ml_kem.h"
14 #include "internal/common.h"
15 #include "internal/constant_time.h"
16 #include "internal/sha3.h"
17
18 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19 #include <valgrind/memcheck.h>
20 #endif
21
22 #if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23 # error "ML-KEM keygen seed length != shared secret + random bytes length"
24 #endif
25 #if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26 # error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27 #endif
28
29 #if UINT_MAX < UINT32_MAX
30 # error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31 #endif
32
33 /* Handy function-like bit-extraction macros */
34 #define bit0(b) ((b) & 1)
35 #define bitn(n, b) (((b) >> (n)) & 1)
36
37 /*
38 * 12 bits are sufficient to losslessly represent values in [0, q-1].
39 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
40 */
41 #define DEGREE ML_KEM_DEGREE
42 #define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
43 #define LOG2PRIME 12
44 #define BARRETT_SHIFT (2 * LOG2PRIME)
45
46 #ifdef SHA3_BLOCKSIZE
47 # define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
48 #endif
49
50 /*
51 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
52 * in practice! The return value is a mask that is all ones if true, and all
53 * zeros otherwise (two's-complement arithmetic assumed for unsigned values).
54 *
55 * Although this is used in constant-time selects, we omit a value barrier
56 * here. Value barriers impede auto-vectorization (likely because it forces
57 * the value to transit through a general-purpose register). On AArch64, this
58 * is a difference of 2x.
59 *
60 * We usually add value barriers to selects because Clang turns consecutive
61 * selects with the same condition into a branch instead of CMOV/CSEL. This
62 * condition does not occur in Kyber, so omitting it seems to be safe so far,
63 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
64 * sign of the input, rather than adding |q| in advance, and using the generic
65 * |reduce_once|. (David Benjamin, Chromium)
66 */
67 #if 0
68 # define constish_time_non_zero(b) (~constant_time_is_zero(b));
69 #else
70 # define constish_time_non_zero(b) (0u - (b))
71 #endif
72
73 /*
74 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
75 * is otherwise arbitrary. The preferred block size matches the internal buffer
76 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
77 * block size of (1600 - 256)/8 bytes, or 168 = 14 * 12, is conveniently a
78 * multiple of 12!
79 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
80 * fallback.
81 */
82 #if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
83 # define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
84 #else
85 # define SCALAR_SAMPLING_BUFSIZE 168
86 #endif
87
88 /*
89 * Structure of keys
90 */
91 typedef struct ossl_ml_kem_scalar_st {
92 /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93 uint16_t c[ML_KEM_DEGREE];
94 } scalar;
95
96 /* Key material allocation layout */
97 #define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
98 struct name##_alloc { \
99 /* Public vector |t| */ \
100 scalar tbuf[(rank)]; \
101 /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102 scalar mbuf[(rank)*(rank)] \
103 /* optional private key data */ \
104 private_sz \
105 }
106
107 /* Declare variant-specific public and private storage */
108 #define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
109 DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
110 DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
111 scalar sbuf[ML_KEM_##bits##_RANK]; \
112 uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
113 DECLARE_ML_KEM_VARIANT_KEYDATA(512);
114 DECLARE_ML_KEM_VARIANT_KEYDATA(768);
115 DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
116 #undef DECLARE_ML_KEM_VARIANT_KEYDATA
117 #undef DECLARE_ML_KEM_KEYDATA
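
/*-
 * Illustrative only: with ML_KEM_768_RANK == 3 the private-key declaration
 * above expands (whitespace and comments aside) to roughly:
 *
 * struct prvkey_768_alloc {
 *     scalar tbuf[3];
 *     scalar mbuf[9];
 *     scalar sbuf[3];
 *     uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];
 * };
 *
 * i.e. 15 scalars of DEGREE 16-bit coefficients each, followed by 64 bytes of
 * seed material; sizeof(struct prvkey_768_alloc) below measures exactly this.
 */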
118
119 typedef __owur
120 int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
121 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
122 static void scalar_encode(uint8_t *out, const scalar *s, int bits);
123
124 /*
125 * The wire-form of a losslessly encoded vector uses 12-bits per element.
126 *
127 * The wire-form public key consists of the lossless encoding of the public
128 * vector |t|, followed by the public seed |rho|.
129 *
130 * Our serialised private key concatenates serialisations of the private vector
131 * |s|, the public key, the public key hash, and the failure secret |z|.
132 */
133 #define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
134 #define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
135 #define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
136
137 /*
138 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
139 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
140 * "dv" bits, respectively. This encoding is the ciphertext input for
141 * decapsulation.
142 */
143 #define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
144 #define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
145 #define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
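
/*-
 * Worked example, ML-KEM-768 (rank 3, du 10, dv 4 per FIPS 203):
 *
 * VECTOR_BYTES(768)   = 384 * 3       = 1152
 * PUBKEY_BYTES(768)   = 1152 + 32     = 1184
 * PRVKEY_BYTES(768)   = 2 * 1184 + 32 = 2400
 * U_VECTOR_BYTES(768) = 32 * 10 * 3   = 960
 * V_SCALAR_BYTES(768) = 32 * 4        = 128
 * CTEXT_BYTES(768)    = 960 + 128     = 1088
 *
 * matching the well-known ML-KEM-768 encapsulation key, decapsulation key and
 * ciphertext sizes.
 */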
146
147 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
148
149 /*
150 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
151 * of memory as secret. Secret data is tracked as it flows to registers and
152 * other parts of a memory. If secret data is used as a condition for a branch,
153 * other parts of memory. If secret data is used as a condition for a branch,
154 */
155 # define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
156
157 /*
158 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
159 * region of memory as public. Public data is not subject to constant-time
160 * rules.
161 */
162 # define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
163
164 #else
165
166 # define CONSTTIME_SECRET(ptr, len)
167 # define CONSTTIME_DECLASSIFY(ptr, len)
168
169 #endif
170
171 /*
172 * Indices of slots in the vinfo tables below
173 */
174 #define ML_KEM_512_VINFO 0
175 #define ML_KEM_768_VINFO 1
176 #define ML_KEM_1024_VINFO 2
177
178 /*
179 * Per-variant fixed parameters
180 */
181 static const ML_KEM_VINFO vinfo_map[3] = {
182 {
183 "ML-KEM-512",
184 PRVKEY_BYTES(512),
185 sizeof(struct prvkey_512_alloc),
186 PUBKEY_BYTES(512),
187 sizeof(struct pubkey_512_alloc),
188 CTEXT_BYTES(512),
189 VECTOR_BYTES(512),
190 U_VECTOR_BYTES(512),
191 EVP_PKEY_ML_KEM_512,
192 ML_KEM_512_BITS,
193 ML_KEM_512_RANK,
194 ML_KEM_512_DU,
195 ML_KEM_512_DV,
196 ML_KEM_512_SECBITS
197 },
198 {
199 "ML-KEM-768",
200 PRVKEY_BYTES(768),
201 sizeof(struct prvkey_768_alloc),
202 PUBKEY_BYTES(768),
203 sizeof(struct pubkey_768_alloc),
204 CTEXT_BYTES(768),
205 VECTOR_BYTES(768),
206 U_VECTOR_BYTES(768),
207 EVP_PKEY_ML_KEM_768,
208 ML_KEM_768_BITS,
209 ML_KEM_768_RANK,
210 ML_KEM_768_DU,
211 ML_KEM_768_DV,
212 ML_KEM_768_SECBITS
213 },
214 {
215 "ML-KEM-1024",
216 PRVKEY_BYTES(1024),
217 sizeof(struct prvkey_1024_alloc),
218 PUBKEY_BYTES(1024),
219 sizeof(struct pubkey_1024_alloc),
220 CTEXT_BYTES(1024),
221 VECTOR_BYTES(1024),
222 U_VECTOR_BYTES(1024),
223 EVP_PKEY_ML_KEM_1024,
224 ML_KEM_1024_BITS,
225 ML_KEM_1024_RANK,
226 ML_KEM_1024_DU,
227 ML_KEM_1024_DV,
228 ML_KEM_1024_SECBITS
229 }
230 };
231
232 /*
233 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
234 * constant time via Barrett reduction, and a final call to reduce_once(),
235 * which reduces inputs that are less than 2*kPrime and is also constant-time.
236 */
237 static const int kPrime = ML_KEM_PRIME;
238 static const unsigned int kBarrettShift = BARRETT_SHIFT;
239 static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
240 static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
241 static const uint16_t kInverseDegree = INVERSE_DEGREE;
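
/*
 * Worked example (illustrative only): kBarrettMultiplier is
 * floor(2^24 / 3329) = 5039, so reducing x = 3328 * 3328 = 11075584 computes
 * quotient = (x * 5039) >> 24 = 3326 and remainder = x - 3326 * 3329 = 3330;
 * reduce_once() then yields 3330 - 3329 = 1, as expected for (-1)^2 mod 3329.
 */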
242
243 /*
244 * Python helper:
245 *
246 * p = 3329
247 * def bitreverse(i):
248 * ret = 0
249 * for n in range(7):
250 * bit = i & 1
251 * ret <<= 1
252 * ret |= bit
253 * i >>= 1
254 * return ret
255 */
256
257 /*-
258 * First precomputed array from Appendix A of FIPS 203, or else Python:
259 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
260 */
261 static const uint16_t kNTTRoots[128] = {
262 1, 1729, 2580, 3289, 2642, 630, 1897, 848,
263 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
264 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
265 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
266 289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
267 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
268 2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
269 2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
270 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
271 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
272 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
273 939, 2308, 2437, 2388, 733, 2337, 268, 641,
274 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
275 1063, 319, 2773, 757, 2099, 561, 2466, 2594,
276 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
277 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
278 };
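
/*
 * Spot check (illustrative): bitreverse(1) == 64 and pow(17, 64, p) == 1729 ==
 * kNTTRoots[1]; likewise bitreverse(2) == 32 and pow(17, 32, p) == 2580 ==
 * kNTTRoots[2].
 */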
279
280 /*
281 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
282 * Listed in order of use in the inverse NTT loop (index 0 is skipped):
283 *
284 * 0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
285 */
286 static const uint16_t kInverseNTTRoots[128] = {
287 1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
288 1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
289 525, 735, 863, 2768, 1230, 2572, 556, 3010,
290 2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
291 1745, 2688, 3061, 992, 2596, 941, 892, 1021,
292 2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
293 1626, 279, 314, 1173, 2573, 3096, 48, 667,
294 1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
295 3312, 2419, 2102, 219, 855, 2681, 1848, 712,
296 682, 927, 1795, 461, 1891, 2877, 2522, 1894,
297 1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
298 2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
299 3040, 2508, 1355, 450, 936, 447, 2794, 1235,
300 1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
301 3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
302 2267, 2481, 1432, 2699, 687, 40, 749, 1600,
303 };
304
305 /*
306 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
307 * or else Python:
308 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
309 */
310 static const uint16_t kModRoots[128] = {
311 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
312 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
313 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
314 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
315 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
316 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
317 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
318 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
319 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
320 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
321 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
322 };
323
324 /*
325 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
326 * output to |out|. If the |md| specifies a fixed-output function, like
327 * SHA3-256, then |outlen| must be the correct length for that function.
328 */
329 static __owur
330 int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
331 EVP_MD_CTX *mdctx)
332 {
333 unsigned int sz = (unsigned int) outlen;
334
335 if (!EVP_DigestUpdate(mdctx, in, inlen))
336 return 0;
337 if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
338 return EVP_DigestFinalXOF(mdctx, out, outlen);
339 return EVP_DigestFinal_ex(mdctx, out, &sz)
340 && ossl_assert((size_t) sz == outlen);
341 }
342
343 /*
344 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
345 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
346 */
347 static __owur
348 int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
349 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
350 {
351 return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
352 && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
353 }
354
355 /*
356 * FIPS 203, Section 4.1, equation (4.4): H. SHA3-256 hash of a variable
357 * length input, producing 32 bytes of output.
358 */
359 static __owur
360 int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
361 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
362 {
363 return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
364 && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
365 }
366
367 /* Incremental hash_h of expanded public key */
368 static int
369 hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
370 EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
371 {
372 const ML_KEM_VINFO *vinfo = key->vinfo;
373 const scalar *t = key->t, *end = t + vinfo->rank;
374 unsigned int sz;
375
376 if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
377 return 0;
378
379 do {
380 uint8_t buf[3 * DEGREE / 2];
381
382 scalar_encode(buf, t++, 12);
383 if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
384 return 0;
385 } while (t < end);
386
387 if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
388 return 0;
389 return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
390 && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
391 }
392
393 /*
394 * FIPS 203, Section 4.1, equation (4.5): G. SHA3-512 hash of a variable
395 * length input, producing 64 bytes of output, in particular the seeds
396 * (d,z) for key generation.
397 */
398 static __owur
399 int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
400 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
401 {
402 return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
403 && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
404 }
405
406 /*
407 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
408 * input to compute a 32-byte implicit rejection shared secret, of the same
409 * length as the expected shared secret. (Computed even on success to avoid
410 * side-channel leaks).
411 */
412 static __owur
413 int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
414 const uint8_t z[ML_KEM_RANDOM_BYTES],
415 const uint8_t *ctext, size_t len,
416 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
417 {
418 return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
419 && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
420 && EVP_DigestUpdate(mdctx, ctext, len)
421 && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
422 }
423
424 /*
425 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
426 * are performed by the caller). Rejection-samples a Keccak stream to get
427 * uniformly distributed elements in the range [0,q). This is used for matrix
428 * expansion and only operates on public inputs.
429 */
430 static __owur
431 int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
432 {
433 uint16_t *curr = out->c, *endout = curr + DEGREE;
434 uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
435 uint8_t *endin = buf + sizeof(buf);
436 uint16_t d;
437 uint8_t b1, b2, b3;
438
439 do {
440 if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
441 return 0;
442 do {
443 b1 = *in++;
444 b2 = *in++;
445 b3 = *in++;
446
447 if (curr >= endout)
448 break;
449 if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
450 *curr++ = d;
451 if (curr >= endout)
452 break;
453 if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
454 *curr++ = d;
455 } while (in < endin);
456 } while (curr < endout);
457 return 1;
458 }
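
/*
 * Worked example (illustrative only): if a squeezed block starts with the
 * bytes b1 = 0xff, b2 = 0xab, b3 = 0x12, the two candidates are
 * d = ((0xab & 0x0f) << 8) + 0xff = 0xbff = 3071 (< kPrime, kept) and
 * d = (0x12 << 4) + (0xab >> 4) = 0x12a = 298 (kept); a candidate such as
 * 0xfff = 4095 would simply be discarded and more of the stream consumed.
 */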
459
460 /*-
461 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
462 *
463 * Subtract |q| if the input is larger, without exposing a side-channel,
464 * avoiding the "clangover" attack. See |constish_time_non_zero| for a
465 * discussion on why the value barrier is by default omitted.
466 */
467 static __owur uint16_t reduce_once(uint16_t x)
468 {
469 const uint16_t subtracted = x - kPrime;
470 uint16_t mask = constish_time_non_zero(subtracted >> 15);
471
472 return (mask & x) | (~mask & subtracted);
473 }
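
/*
 * For example (illustrative only): reduce_once(3400) computes subtracted = 71,
 * whose top bit is clear, so the mask is zero and 71 is returned; for
 * reduce_once(100), subtracted wraps to 62307, the top bit is set, the mask is
 * all ones, and the original 100 is returned.
 */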
474
475 /*
476 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
477 * than kPrime + 2 * kPrime^2. This is sufficient to reduce a product of
478 * two already reduced uint16_t values; in fact it is sufficient for each
479 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
480 */
481 static __owur uint16_t reduce(uint32_t x)
482 {
483 uint64_t product = (uint64_t)x * kBarrettMultiplier;
484 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
485 uint32_t remainder = x - quotient * kPrime;
486
487 return reduce_once(remainder);
488 }
489
490 /* Multiply a scalar by a constant. */
491 static void scalar_mult_const(scalar *s, uint16_t a)
492 {
493 uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
494
495 do {
496 tmp = reduce(*curr * a);
497 *curr++ = tmp;
498 } while (curr < end);
499 }
500
501 /*-
502 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
503 * In-place number theoretic transform of a given scalar. Note that ML-KEM's
504 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
505 * off the last iteration of the usual FFT code, with the 128 relevant roots of
506 * unity being stored in NTTRoots. This means the output should be seen as 128
507 * elements in GF(3329^2), with the coefficients of the elements being
508 * consecutive entries in |s->c|.
509 */
510 static void scalar_ntt(scalar *s)
511 {
512 const uint16_t *roots = kNTTRoots;
513 uint16_t *end = s->c + DEGREE;
514 int offset = DEGREE / 2;
515
516 do {
517 uint16_t *curr = s->c, *peer;
518
519 do {
520 uint16_t *pause = curr + offset, even, odd;
521 uint32_t zeta = *++roots;
522
523 peer = pause;
524 do {
525 even = *curr;
526 odd = reduce(*peer * zeta);
527 *peer++ = reduce_once(even - odd + kPrime);
528 *curr++ = reduce_once(odd + even);
529 } while (curr < pause);
530 } while ((curr = peer) < end);
531 } while ((offset >>= 1) >= 2);
532 }
533
534 /*-
535 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
536 * In-place inverse number theoretic transform of a given scalar, with pairs of
537 * entries of s->c being interpreted as elements of GF(3329^2). Just as with
538 * the number theoretic transform, this leaves off the first step of the normal
539 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
540 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
541 */
542 static void scalar_inverse_ntt(scalar *s)
543 {
544 const uint16_t *roots = kInverseNTTRoots;
545 uint16_t *end = s->c + DEGREE;
546 int offset = 2;
547
548 do {
549 uint16_t *curr = s->c, *peer;
550
551 do {
552 uint16_t *pause = curr + offset, even, odd;
553 uint32_t zeta = *++roots;
554
555 peer = pause;
556 do {
557 even = *curr;
558 odd = *peer;
559 *peer++ = reduce(zeta * (even - odd + kPrime));
560 *curr++ = reduce_once(odd + even);
561 } while (curr < pause);
562 } while ((curr = peer) < end);
563 } while ((offset <<= 1) < DEGREE);
564 scalar_mult_const(s, kInverseDegree);
565 }
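
/*
 * Sanity check on the final scaling (illustrative): kInverseDegree is
 * 3329 - 26 = 3303, and 128 * 3303 = 422784 = 127 * 3329 + 1, so 3303 is
 * indeed (DEGREE/2)^-1 mod kPrime, compensating for the factor of two picked
 * up in each of the seven inverse butterfly layers.
 */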
566
567 /* Addition updating the LHS scalar in-place. */
568 static void scalar_add(scalar *lhs, const scalar *rhs)
569 {
570 int i;
571
572 for (i = 0; i < DEGREE; i++)
573 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
574 }
575
576 /* Subtraction updating the LHS scalar in-place. */
577 static void scalar_sub(scalar *lhs, const scalar *rhs)
578 {
579 int i;
580
581 for (i = 0; i < DEGREE; i++)
582 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
583 }
584
585 /*
586 * Multiplying two scalars in the number theoretically transformed state. Since
587 * 3329 does not have a 512th root of unity, this means we have to interpret
588 * the 2*ith and (2*i+1)th entries of the scalar as elements of
589 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
590 *
591 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
592 * ModRoots table. Note that our Barrett transform only allows us to multiply
593 * two reduced numbers together, so we need some intermediate reduction steps,
594 * even if a uint64_t could hold 3 multiplied numbers.
595 */
596 static void scalar_mult(scalar *out, const scalar *lhs,
597 const scalar *rhs)
598 {
599 uint16_t *curr = out->c, *end = curr + DEGREE;
600 const uint16_t *lc = lhs->c, *rc = rhs->c;
601 const uint16_t *roots = kModRoots;
602
603 do {
604 uint32_t l0 = *lc++, r0 = *rc++;
605 uint32_t l1 = *lc++, r1 = *rc++;
606 uint32_t zetapow = *roots++;
607
608 *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
609 *curr++ = reduce(l0 * r1 + l1 * r0);
610 } while (curr < end);
611 }
612
613 /* Above, but add the result to an existing scalar */
614 static ossl_inline
615 void scalar_mult_add(scalar *out, const scalar *lhs,
616 const scalar *rhs)
617 {
618 uint16_t *curr = out->c, *end = curr + DEGREE;
619 const uint16_t *lc = lhs->c, *rc = rhs->c;
620 const uint16_t *roots = kModRoots;
621
622 do {
623 uint32_t l0 = *lc++, r0 = *rc++;
624 uint32_t l1 = *lc++, r1 = *rc++;
625 uint16_t *c0 = curr++;
626 uint16_t *c1 = curr++;
627 uint32_t zetapow = *roots++;
628
629 *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
630 *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
631 } while (curr < end);
632 }
633
634 /*-
635 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
636 * Here |bits| is |d|. For efficiency, we handle the d=1 case separately.
637 */
638 static void scalar_encode(uint8_t *out, const scalar *s, int bits)
639 {
640 const uint16_t *curr = s->c, *end = curr + DEGREE;
641 uint64_t accum = 0, element;
642 int used = 0;
643
644 do {
645 element = *curr++;
646 if (used + bits < 64) {
647 accum |= element << used;
648 used += bits;
649 } else if (used + bits > 64) {
650 out = OPENSSL_store_u64_le(out, accum | (element << used));
651 accum = element >> (64 - used);
652 used = (used + bits) - 64;
653 } else {
654 out = OPENSSL_store_u64_le(out, accum | (element << used));
655 accum = 0;
656 used = 0;
657 }
658 } while (curr < end);
659 }
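
/*
 * For instance (illustrative only): with |bits| == 12 the loop above packs
 * DEGREE * 12 = 3072 bits, emitted as 48 little-endian 64-bit words, i.e. the
 * 3 * DEGREE / 2 = 384 bytes expected by the wire format; with |bits| == 10
 * (the ML-KEM-768 |du|) it emits 40 words, or 320 bytes per scalar.
 */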
660
661 /*
662 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
663 */
664 static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
665 {
666 int i, j;
667 uint8_t out_byte;
668
669 for (i = 0; i < DEGREE; i += 8) {
670 out_byte = 0;
671 for (j = 0; j < 8; j++)
672 out_byte |= bit0(s->c[i + j]) << j;
673 *out = out_byte;
674 out++;
675 }
676 }
677
678 /*-
679 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
680 * Here |bits| is |d|. For efficiency, we handle the d=1 and d=12 cases
681 * separately.
682 *
683 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
684 * |out|.
685 */
686 static void scalar_decode(scalar *out, const uint8_t *in, int bits)
687 {
688 uint16_t *curr = out->c, *end = curr + DEGREE;
689 uint64_t accum = 0;
690 int accum_bits = 0, todo = bits;
691 uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
692 uint16_t element = 0;
693
694 do {
695 if (accum_bits == 0) {
696 in = OPENSSL_load_u64_le(&accum, in);
697 accum_bits = 64;
698 }
699 if (todo == bits && accum_bits >= bits) {
700 /* No partial "element", and all the required bits available */
701 *curr++ = ((uint16_t) accum) & mask;
702 accum >>= bits;
703 accum_bits -= bits;
704 } else if (accum_bits >= todo) {
705 /* A partial "element", and all the required bits available */
706 *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
707 accum >>= todo;
708 accum_bits -= todo;
709 element = 0;
710 todo = bits;
711 mask = bitmask;
712 } else {
713 /*
714 * Only some of the requisite bits accumulated, store |accum_bits|
715 * of these in |element|. The accumulated bitcount becomes 0, but
716 * as soon as we have more bits we'll want to merge accum_bits
717 * fewer of them into the final |element|.
718 *
719 * Note that with a 64-bit accumulator and |bits| always 12 or
720 * less, if we're here, the previous iteration had all the
721 * requisite bits, and so there are no kept bits in |element|.
722 */
723 element = ((uint16_t) accum) & mask;
724 todo -= accum_bits;
725 mask = bitmask >> accum_bits;
726 accum_bits = 0;
727 }
728 } while (curr < end);
729 }
730
731 static __owur
732 int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
733 {
734 int i;
735 uint16_t *c = out->c;
736
737 for (i = 0; i < DEGREE / 2; ++i) {
738 uint8_t b1 = *in++;
739 uint8_t b2 = *in++;
740 uint8_t b3 = *in++;
741 int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
742 int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
743
744 if (outOfRange1 | outOfRange2)
745 return 0;
746 }
747 return 1;
748 }
749
750 /*-
751 * scalar_decode_decompress_add is a combination of decoding and decompression
752 * both specialised for |bits| == 1, with the result added (and sum reduced) to
753 * the output scalar.
754 *
755 * NOTE: this function MUST NOT leak an input-data-dependent timing signal.
756 * A timing leak in a related function in the reference Kyber implementation
757 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
758 * for ML-KEM-512 in minutes, provided the attacker has access to precise
759 * timing of a CPU performing chosen-ciphertext decap. Admittedly this is only
760 * a risk when private keys are reused (perhaps KEMTLS servers).
761 */
762 static void
763 scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
764 {
765 static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
766 uint16_t *curr = out->c, *end = curr + DEGREE;
767 uint16_t mask;
768 uint8_t b;
769
770 /*
771 * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
772 * avoiding the "clangover" attack. See |constish_time_non_zero| for a
773 * discussion on why the value barrier is by default omitted.
774 */
775 #define decode_decompress_add_bit \
776 mask = constish_time_non_zero(bit0(b)); \
777 *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
778 curr++; \
779 b >>= 1
780
781 /* Unrolled to process each byte in one iteration */
782 do {
783 b = *in++;
784 decode_decompress_add_bit;
785 decode_decompress_add_bit;
786 decode_decompress_add_bit;
787 decode_decompress_add_bit;
788
789 decode_decompress_add_bit;
790 decode_decompress_add_bit;
791 decode_decompress_add_bit;
792 decode_decompress_add_bit;
793 } while (curr < end);
794 #undef decode_decompress_add_bit
795 }
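
/*
 * For example (illustrative only): a set message bit adds half_q_plus_1 = 1665
 * to the corresponding coefficient, so a coefficient of 2000 becomes
 * reduce_once(2000 + 1665) = 336, while a clear bit leaves it at 2000; the
 * 1-bit decompression is thus just 0 or round(q/2).
 */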
796
797 /*
798 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
799 *
800 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
801 * numbers close to each other together. The formula used is
802 * round(2^|bits|/kPrime*x) mod 2^|bits|.
803 * Uses Barrett reduction to achieve constant time. Since we need both the
804 * remainder (for rounding) and the quotient (as the result), we cannot use
805 * |reduce| here, but need to do the Barrett reduction directly.
806 */
807 static __owur uint16_t compress(uint16_t x, int bits)
808 {
809 uint32_t shifted = (uint32_t)x << bits;
810 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
811 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
812 uint32_t remainder = shifted - quotient * kPrime;
813
814 /*
815 * Adjust the quotient to round correctly:
816 * 0 <= remainder <= kHalfPrime round to 0
817 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
818 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
819 */
820 quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
821 quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
822 return quotient & ((1 << bits) - 1);
823 }
824
825 /*
826 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
827 *
828 * Decompresses |x| by using a close equi-distant representative. The formula
829 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
830 * to implement this logic using only bit operations.
831 */
832 static __owur uint16_t decompress(uint16_t x, int bits)
833 {
834 uint32_t product = (uint32_t)x * kPrime;
835 uint32_t power = 1 << bits;
836 /* This is |product| % power, since |power| is a power of 2. */
837 uint32_t remainder = product & (power - 1);
838 /* This is |product| / power, since |power| is a power of 2. */
839 uint32_t lower = product >> bits;
840
841 /*
842 * The rounding logic works since the first half of numbers mod |power|
843 * have a 0 as first bit, and the second half has a 1 as first bit, since
844 * |power| is a power of 2. As a 12 bit number, |remainder| is always
845 * positive, so we will shift in 0s for a right shift.
846 */
847 return lower + (remainder >> (bits - 1));
848 }
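
/*
 * Round-trip example (illustrative only): with |bits| == 4, compress(1000, 4)
 * computes round(16 * 1000 / 3329) = 5, and decompress(5, 4) recovers
 * round(3329 * 5 / 16) = 1040, i.e. the original value up to an error of at
 * most roughly kPrime / 2^(bits + 1).
 */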
849
850 /*-
851 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
852 * In-place lossy rounding of scalars to 2^d bits.
853 */
854 static void scalar_compress(scalar *s, int bits)
855 {
856 int i;
857
858 for (i = 0; i < DEGREE; i++)
859 s->c[i] = compress(s->c[i], bits);
860 }
861
862 /*
863 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
864 * In-place approximate recovery of scalars from 2^d bit compression.
865 */
866 static void scalar_decompress(scalar *s, int bits)
867 {
868 int i;
869
870 for (i = 0; i < DEGREE; i++)
871 s->c[i] = decompress(s->c[i], bits);
872 }
873
874 /* Addition updating the LHS vector in-place. */
875 static void vector_add(scalar *lhs, const scalar *rhs, int rank)
876 {
877 do {
878 scalar_add(lhs++, rhs++);
879 } while (--rank > 0);
880 }
881
882 /*
883 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
884 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
885 * whole number of bytes, so we do not need to worry about bit packing here.
886 */
887 static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
888 {
889 int stride = bits * DEGREE / 8;
890
891 for (; rank-- > 0; out += stride)
892 scalar_encode(out, a++, bits);
893 }
894
895 /*
896 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. Since |bits| is less
897 * than 12 here, every parsed value is below |ML_KEM_PRIME|, so no range check
898 * is needed. The resulting scalars are then decompressed and transformed via
899 * the NTT.
900 *
901 * Note: Used only in decrypt_cpa(), which returns void; side-channels are fine
902 * when the input ciphertext to decap() is simply syntactically invalid.
903 */
904 static void
905 vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
906 {
907 int stride = bits * DEGREE / 8;
908
909 for (; rank-- > 0; in += stride, ++out) {
910 scalar_decode(out, in, bits);
911 scalar_decompress(out, bits);
912 scalar_ntt(out);
913 }
914 }
915
916 /* vector_decode(), specialised to bits == 12. */
917 static __owur
918 int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
919 {
920 int stride = 3 * DEGREE / 2;
921
922 for (; rank-- > 0; in += stride)
923 if (!scalar_decode_12(out++, in))
924 return 0;
925 return 1;
926 }
927
928 /* In-place compression of each scalar component */
929 static void vector_compress(scalar *a, int bits, int rank)
930 {
931 do {
932 scalar_compress(a++, bits);
933 } while (--rank > 0);
934 }
935
936 /* The output scalar must not overlap with the inputs */
937 static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
938 int rank)
939 {
940 scalar_mult(out, lhs, rhs);
941 while (--rank > 0)
942 scalar_mult_add(out, ++lhs, ++rhs);
943 }
944
945 /*
946 * Here, the output vector must not overlap with the inputs; the result is
947 * directly subjected to the inverse NTT.
948 */
949 static void
950 matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
951 {
952 const scalar *ar;
953 int i, j;
954
955 for (i = rank; i-- > 0; ++out) {
956 scalar_mult(out, m++, ar = a);
957 for (j = rank - 1; j > 0; --j)
958 scalar_mult_add(out, m++, ++ar);
959 scalar_inverse_ntt(out);
960 }
961 }
962
963 /* Here, the output vector must not overlap with the inputs */
964 static void
965 matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
966 {
967 const scalar *mc = m, *mr, *ar;
968 int i, j;
969
970 for (i = rank; i-- > 0; ++out) {
971 scalar_mult_add(out, mr = mc++, ar = a);
972 for (j = rank; --j > 0; )
973 scalar_mult_add(out, (mr += rank), ++ar);
974 }
975 }
976
977 /*-
978 * Expands the matrix from a seed for key generation and for encaps-CPA.
979 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
980 * by appending the (i,j) indices to the seed in the opposite order!
981 *
982 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
983 */
984 static __owur
985 int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
986 {
987 scalar *out = key->m;
988 uint8_t input[ML_KEM_RANDOM_BYTES + 2];
989 int rank = key->vinfo->rank;
990 int i, j;
991
992 memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
993 for (i = 0; i < rank; i++) {
994 for (j = 0; j < rank; j++) {
995 input[ML_KEM_RANDOM_BYTES] = i;
996 input[ML_KEM_RANDOM_BYTES + 1] = j;
997 if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
998 || !EVP_DigestUpdate(mdctx, input, sizeof(input))
999 || !sample_scalar(out++, mdctx))
1000 return 0;
1001 }
1002 }
1003 return 1;
1004 }
1005
1006 /*
1007 * FIPS 203, Algorithm 8: "SamplePolyCBD", with eta fixed to two and the PRF
1008 * call included. Creates binomially distributed elements by sampling 2*|eta|
1009 * bits, and setting the coefficient to the count of the first |eta| bits minus
1010 * the count of the second |eta| bits, resulting in a centered binomial
1011 * distribution. Since eta is two this gives -2/+2 each with probability 1/16,
1012 * -1/+1 each with probability 1/4, and 0 with probability 3/8.
1013 */
1014 static __owur
1015 int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1016 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1017 {
1018 uint16_t *curr = out->c, *end = curr + DEGREE;
1019 uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1020 uint16_t value, mask;
1021 uint8_t b;
1022
1023 if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1024 return 0;
1025
1026 do {
1027 b = *r++;
1028
1029 /*
1030 * Add |kPrime| if |value| underflowed. See |constish_time_non_zero|
1031 * for a discussion on why the value barrier is by default omitted.
1032 * While this could have been written reduce_once(value + kPrime), this
1033 * is one extra addition and small range of |value| tempts some
1034 * versions of Clang to emit a branch.
1035 */
1036 value = bit0(b) + bitn(1, b);
1037 value -= bitn(2, b) + bitn(3, b);
1038 mask = constish_time_non_zero(value >> 15);
1039 *curr++ = value + (kPrime & mask);
1040
1041 value = bitn(4, b) + bitn(5, b);
1042 value -= bitn(6, b) + bitn(7, b);
1043 mask = constish_time_non_zero(value >> 15);
1044 *curr++ = value + (kPrime & mask);
1045 } while (curr < end);
1046 return 1;
1047 }
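
/*
 * Worked example for cbd_2 (illustrative only): the byte 0x2d has bits
 * (lsb first) 1,0,1,1,0,1,0,0, so the first coefficient is (1 + 0) - (1 + 1)
 * = -1, stored as 3328 after the conditional addition of kPrime, and the
 * second coefficient is (0 + 1) - (0 + 0) = 1.
 */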
1048
1049 /*
1050 * FIPS 203, Algorithm 8: "SamplePolyCBD", with eta fixed to three and the PRF
1051 * call included. Creates binomially distributed elements by sampling 2*|eta|
1052 * bits, and setting the coefficient to the count of the first |eta| bits minus
1053 * the count of the second |eta| bits, resulting in a centered binomial distribution.
1054 */
1055 static __owur
1056 int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1057 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1058 {
1059 uint16_t *curr = out->c, *end = curr + DEGREE;
1060 uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1061 uint8_t b1, b2, b3;
1062 uint16_t value, mask;
1063
1064 if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1065 return 0;
1066
1067 do {
1068 b1 = *r++;
1069 b2 = *r++;
1070 b3 = *r++;
1071
1072 /*
1073 * Add |kPrime| if |value| underflowed. See |constish_time_non_zero|
1074 * for a discussion on why the value barrier is by default omitted.
1075 * While this could have been written reduce_once(value + kPrime), this
1076 * is one extra addition and small range of |value| tempts some
1077 * versions of Clang to emit a branch.
1078 */
1079 value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1080 value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1081 mask = constish_time_non_zero(value >> 15);
1082 *curr++ = value + (kPrime & mask);
1083
1084 value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1085 value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1086 mask = constish_time_non_zero(value >> 15);
1087 *curr++ = value + (kPrime & mask);
1088
1089 value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1090 value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1091 mask = constish_time_non_zero(value >> 15);
1092 *curr++ = value + (kPrime & mask);
1093
1094 value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1095 value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1096 mask = constish_time_non_zero(value >> 15);
1097 *curr++ = value + (kPrime & mask);
1098 } while (curr < end);
1099 return 1;
1100 }
1101
1102 /*
1103 * Generates a secret vector by using |cbd| with the given seed to generate
1104 * scalar elements and incrementing |counter| for each slot of the vector.
1105 */
1106 static __owur
1107 int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1108 const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1109 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1110 {
1111 uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1112
1113 memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1114 do {
1115 input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1116 if (!cbd(out++, input, mdctx, key))
1117 return 0;
1118 } while (--rank > 0);
1119 return 1;
1120 }
1121
1122 /*
1123 * As above plus NTT transform.
1124 */
1125 static __owur
1126 int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1127 const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1128 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1129 {
1130 uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1131
1132 memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1133 do {
1134 input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1135 if (!cbd(out, input, mdctx, key))
1136 return 0;
1137 scalar_ntt(out++);
1138 } while (--rank > 0);
1139 return 1;
1140 }
1141
1142 /* The |ETA1| value for ML-KEM-512 is 3; the other variants, and all |ETA2| values, use 2. */
1143 #define CBD1(evp_type) ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
1144
1145 /*
1146 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1147 *
1148 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1149 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1150 * secure scheme, since lattice schemes are vulnerable to decryption failure
1151 * oracles.
1152 *
1153 * The steps are re-ordered to make more efficient/localised use of storage.
1154 *
1155 * Note also that the input public key is assumed to hold a precomputed matrix
1156 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1157 * coefficient) key->t vector).
1158 *
1159 * Caller passes storage in |tmp| for two temporary vectors.
1160 */
1161 static __owur
1162 int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1163 const uint8_t message[DEGREE / 8],
1164 const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1165 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1166 {
1167 const ML_KEM_VINFO *vinfo = key->vinfo;
1168 CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1169 int rank = vinfo->rank;
1170 /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1171 scalar *y = &tmp[0], *e1 = y, *e2 = y;
1172 /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1173 scalar *u = &tmp[rank];
1174 scalar v;
1175 uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1176 uint8_t counter = 0;
1177 int du = vinfo->du;
1178 int dv = vinfo->dv;
1179
1180 /* FIPS 203 "y" vector */
1181 if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1182 return 0;
1183 /* FIPS 203 "v" scalar */
1184 inner_product(&v, key->t, y, rank);
1185 scalar_inverse_ntt(&v);
1186 /* FIPS 203 "u" vector */
1187 matrix_mult_intt(u, key->m, y, rank);
1188
1189 /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1190 if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1191 return 0;
1192 vector_add(u, e1, rank);
1193 vector_compress(u, du, rank);
1194 vector_encode(out, u, du, rank);
1195
1196 /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1197 memcpy(input, r, ML_KEM_RANDOM_BYTES);
1198 input[ML_KEM_RANDOM_BYTES] = counter;
1199 if (!cbd_2(e2, input, mdctx, key))
1200 return 0;
1201 scalar_add(&v, e2);
1202
1203 /* Combine message with |v| */
1204 scalar_decode_decompress_add(&v, message);
1205 scalar_compress(&v, dv);
1206 scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1207 return 1;
1208 }
1209
1210 /*
1211 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1212 */
1213 static void
1214 decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1215 const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1216 {
1217 const ML_KEM_VINFO *vinfo = key->vinfo;
1218 scalar v, mask;
1219 int rank = vinfo->rank;
1220 int du = vinfo->du;
1221 int dv = vinfo->dv;
1222
1223 vector_decode_decompress_ntt(u, ctext, du, rank);
1224 scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1225 scalar_decompress(&v, dv);
1226 inner_product(&mask, key->s, u, rank);
1227 scalar_inverse_ntt(&mask);
1228 scalar_sub(&v, &mask);
1229 scalar_compress(&v, 1);
1230 scalar_encode_1(out, &v);
1231 }
1232
1233 /*-
1234 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1235 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1236 *
1237 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1238 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1239 * wire-format of an ML-KEM public key.
1240 */
1241 static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1242 {
1243 const uint8_t *rho = key->rho;
1244 const ML_KEM_VINFO *vinfo = key->vinfo;
1245
1246 vector_encode(out, key->t, 12, vinfo->rank);
1247 memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1248 }
1249
1250 /*-
1251 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1252 *
1253 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1254 * This matches the input format of parse_prvkey() below.
1255 */
1256 static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1257 {
1258 const ML_KEM_VINFO *vinfo = key->vinfo;
1259
1260 vector_encode(out, key->s, 12, vinfo->rank);
1261 out += vinfo->vector_bytes;
1262 encode_pubkey(out, key);
1263 out += vinfo->pubkey_bytes;
1264 memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1265 out += ML_KEM_PKHASH_BYTES;
1266 memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1267 }
1268
1269 /*-
1270 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1271 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1272 *
1273 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1274 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1275 * wire-format of the ML-KEM public key.
1276 */
1277 static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1278 {
1279 const ML_KEM_VINFO *vinfo = key->vinfo;
1280
1281 /* Decode and check |t| */
1282 if (!vector_decode_12(key->t, in, vinfo->rank)) {
1283 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1284 "%s invalid public 't' vector",
1285 vinfo->algorithm_name);
1286 return 0;
1287 }
1288 /* Save the matrix |m| recovery seed |rho| */
1289 memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1290 /*
1291 * Pre-compute the public key hash, needed for both encap and decap.
1292 * Also pre-compute the matrix expansion, stored with the public key.
1293 */
1294 if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1295 || !matrix_expand(mdctx, key)) {
1296 ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1297 "internal error while parsing %s public key",
1298 vinfo->algorithm_name);
1299 return 0;
1300 }
1301 return 1;
1302 }
1303
1304 /*
1305 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1306 *
1307 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1308 * This matches the output format of encode_prvkey() above.
1309 */
1310 static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1311 {
1312 const ML_KEM_VINFO *vinfo = key->vinfo;
1313
1314 /* Decode and check |s|. */
1315 if (!vector_decode_12(key->s, in, vinfo->rank)) {
1316 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1317 "%s invalid private 's' vector",
1318 vinfo->algorithm_name);
1319 return 0;
1320 }
1321 in += vinfo->vector_bytes;
1322
1323 if (!parse_pubkey(in, mdctx, key))
1324 return 0;
1325 in += vinfo->pubkey_bytes;
1326
1327 /* Check public key hash. */
1328 if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1329 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1330 "%s public key hash mismatch",
1331 vinfo->algorithm_name);
1332 return 0;
1333 }
1334 in += ML_KEM_PKHASH_BYTES;
1335
1336 memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1337 return 1;
1338 }
1339
1340 /*
1341 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1342 *
1343 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1344 * inlined.
1345 *
1346 * The caller MUST pass a pre-allocated digest context that is not shared with
1347 * any concurrent computation.
1348 *
1349 * This function optionally outputs the serialised wire-form |ek| public key
1350 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1351 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1352 * have preallocated space for these).
1353 *
1354 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1355 * domain separation. These are concatenated and hashed to produce a pair of
1356 * 32-byte seeds: the public "rho", used to generate the matrix, and the private
1357 * "sigma", used to generate the secret vectors |s| and |e|.
1358 *
1359 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1360 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1361 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1362 * constant time, with no side channel leaks, on all well-formed (valid length,
1363 * and correctly encoded) ciphertext inputs.
1364 */
1365 static __owur
1366 int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1367 EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1368 {
1369 uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1370 const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1371 uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1372 const ML_KEM_VINFO *vinfo = key->vinfo;
1373 CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1374 int rank = vinfo->rank;
1375 uint8_t counter = 0;
1376 int ret = 0;
1377
1378 /*
1379 * Use the "d" seed salted with the rank to derive the public and private
1380 * seeds rho and sigma.
1381 */
1382 memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1383 augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
1384 if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1385 goto end;
1386 memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1387 /* The |rho| matrix seed is public */
1388 CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1389
1390 /* FIPS 203 |e| vector is initial value of key->t */
1391 if (!matrix_expand(mdctx, key)
1392 || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1393 || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1394 goto end;
1395
1396 /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1397 matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1398 /* The |t| vector is public */
1399 CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1400
1401 if (pubenc == NULL) {
1402 /* Incremental digest of public key without in-full serialisation. */
1403 if (!hash_h_pubkey(key->pkhash, mdctx, key))
1404 goto end;
1405 } else {
1406 encode_pubkey(pubenc, key);
1407 if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1408 goto end;
1409 }
1410
1411 /* Save |z| portion of seed for "implicit rejection" on failure. */
1412 memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1413
1414 /* Optionally save the |d| portion of the seed */
1415 key->d = key->z + ML_KEM_RANDOM_BYTES;
1416 if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1417 memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1418 } else {
1419 OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1420 key->d = NULL;
1421 }
1422
1423 ret = 1;
1424 end:
1425 OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1426 OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1427 if (ret == 0) {
1428 ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1429 "internal error while generating %s private key",
1430 vinfo->algorithm_name);
1431 }
1432 return ret;
1433 }
1434
1435 /*-
1436 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1437 * This is the deterministic version with randomness supplied externally.
1438 *
1439 * The caller must pass space for two vectors in |tmp|.
1440 * The |ctext| buffer must have space for the ciphertext of the ML-KEM variant
1441 * of the provided key.
1442 */
1443 static
1444 int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1445 const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1446 scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1447 {
1448 uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1449 uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1450 uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1451 int ret;
1452
1453 memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1454 memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1455 ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1456 && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1457 OPENSSL_cleanse((void *)input, sizeof(input));
1458
1459 if (ret)
1460 memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1461 else
1462 ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1463 "internal error while performing %s encapsulation",
1464 key->vinfo->algorithm_name);
1465 return ret;
1466 }
1467
1468 /*
1469 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1470 *
1471 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1472 * deterministic; the randomness for the FO transform was extracted during
1473 * private key generation.
1474 *
1475 * The caller must pass space for two vectors in |tmp|.
1476 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1477 * of the key's ML-KEM variant.
1478 */
1479 static
1480 int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1481 const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1482 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1483 {
1484 uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1485 uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1486 uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1487 uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1488 const uint8_t *pkhash = key->pkhash;
1489 const ML_KEM_VINFO *vinfo = key->vinfo;
1490 int i;
1491 uint8_t mask;
1492
1493 /*
1494 * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1495 * any further errors, returning success, and whatever we got for a shared
1496 * secret. The decrypt_cpa() function is just arithmetic on secret data,
1497 * so should not be subject to failure that makes its output predictable.
1498 *
1499 * We guard against "should never happen" catastrophic failure of the
1500 * "pure" function |hash_g| by overwriting the shared secret with the
1501 * content of the failure key and returning early, if nevertheless hash_g
1502 * fails. This is not constant-time, but a failure of |hash_g| already
1503 * implies loss of side-channel resistance.
1504 *
1505 * The same action is taken, if also |encrypt_cpa| should catastrophically
1506 * fail, due to failure of the |PRF| underlying the CBD functions.
1507 */
1508 if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1509 ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1510 "internal error while performing %s decapsulation",
1511 vinfo->algorithm_name);
1512 return 0;
1513 }
1514 decrypt_cpa(decrypted, ctext, tmp, key);
1515 memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1516 if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1517 || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1518 memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1519 OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1520 return 1;
1521 }
1522 mask = constant_time_eq_int_8(0,
1523 CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1524 for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1525 secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1526 OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1527 OPENSSL_cleanse(Kr, sizeof(Kr));
1528 return 1;
1529 }
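
/*-
 * For reference, the FIPS 203, Algorithm 18 steps realised above, with H(ek)
 * being the stored |pkhash|, G the SHA3-512 |hash_g|, and J the SHAKE256
 * based |kdf|:
 *
 *     m'       = K-PKE.Decrypt(dk_PKE, c)
 *     (K', r') = G(m' || H(ek))
 *     K_bar    = J(z || c)
 *     c'       = K-PKE.Encrypt(ek_PKE, m', r')
 *     if c != c' then K' = K_bar      (the constant-time select above)
 *     return K'
 */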
1530
1531 /*
1532 * After allocating storage for public or private key data, update the key
1533 * component pointers to reference that storage.
1534 */
1535 static __owur
1536 int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1537 {
1538 int rank = key->vinfo->rank;
1539
1540 if (p == NULL)
1541 return 0;
1542
1543 /*
1544 * We're adding key material; the seed buffer will now hold |rho| and
1545 * |pkhash|.
1546 */
1547 memset(key->seedbuf, 0, sizeof(key->seedbuf));
1548 key->rho = key->seedbuf;
1549 key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1550 key->d = key->z = NULL;
1551
1552 /* A public key needs space for |t| and |m| */
1553 key->m = (key->t = p) + rank;
1554
1555 /*
1556 * A private key also needs space for |s| and |z|.
1557 * The |z| buffer always includes additional space for |d|, but a key's |d|
1558 * pointer is left NULL when parsed from the NIST format, which omits that
1559 * information. Only keys generated from a (d, z) seed pair will have a
1560 * non-NULL |d| pointer.
1561 */
1562 if (private)
1563 key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1564 return 1;
1565 }
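
/*-
 * Resulting layout of the allocated block |p| (with r the rank), as implied
 * by the pointer arithmetic above:
 *
 *     t: r scalars | m: r*r scalars | s: r scalars | z | d
 *
 * where |z| and |d| are each ML_KEM_RANDOM_BYTES long.  Public-key
 * allocations (|puballoc|) cover only |t| and |m|; |s|, |z| and |d| exist
 * only in private-key allocations, and |d| is populated only for keys
 * generated from a (d, z) seed.
 */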
1566
1567 /*
1568 * After freeing the storage associated with a key that failed to be
1569 * constructed, reset the internal pointers back to NULL.
1570 */
1571 void
1572 ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1573 {
1574 if (key->t == NULL)
1575 return;
1576 /*-
1577 * Cleanse any sensitive data:
1578 * - The private vector |s| is immediately followed by the FO failure
1579 * secret |z| and the seed |d|, so we can cleanse all three in one call.
1580 *
1581 * - Otherwise, when key->d is set, cleanse the stashed seed.
1582 */
1583 if (ossl_ml_kem_have_prvkey(key))
1584 OPENSSL_cleanse(key->s,
1585 key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1586 OPENSSL_free(key->t);
1587 key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1588 }
1589
1590 /*
1591 * ----- API exported to the provider
1592 *
1593 * Parameters with an implicit fixed length in the internal static API of each
1594 * variant have an explicit checked length argument at this layer.
1595 */
1596
1597 /* Retrieve the parameters of one of the ML-KEM variants */
1598 const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1599 {
1600 switch (evp_type) {
1601 case EVP_PKEY_ML_KEM_512:
1602 return &vinfo_map[ML_KEM_512_VINFO];
1603 case EVP_PKEY_ML_KEM_768:
1604 return &vinfo_map[ML_KEM_768_VINFO];
1605 case EVP_PKEY_ML_KEM_1024:
1606 return &vinfo_map[ML_KEM_1024_VINFO];
1607 }
1608 return NULL;
1609 }
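
/*-
 * Illustrative sketch, not part of this file: callers of this layer are
 * expected to size buffers from the variant parameters rather than hard-code
 * per-variant lengths.  For example, to serialise the public key of an
 * ML-KEM-768 |key| (error handling abbreviated):
 *
 *     const ML_KEM_VINFO *v = ossl_ml_kem_get_vinfo(EVP_PKEY_ML_KEM_768);
 *     uint8_t *pub = (v != NULL) ? OPENSSL_malloc(v->pubkey_bytes) : NULL;
 *     int ok = pub != NULL
 *         && ossl_ml_kem_encode_public_key(pub, v->pubkey_bytes, key);
 */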
1610
1611 ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
1612 int evp_type)
1613 {
1614 const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
1615 ML_KEM_KEY *key;
1616
1617 if (vinfo == NULL) {
1618 ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
1619 "unsupported ML-KEM key type: %d", evp_type);
1620 return NULL;
1621 }
1622
1623 if ((key = OPENSSL_malloc(sizeof(*key))) == NULL)
1624 return NULL;
1625
1626 key->vinfo = vinfo;
1627 key->libctx = libctx;
1628 key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;
1629 key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
1630 key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
1631 key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
1632 key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);
1633 key->d = key->z = key->rho = key->pkhash = key->encoded_dk = NULL;
1634 key->s = key->m = key->t = NULL;
1635
1636 if (key->shake128_md != NULL
1637 && key->shake256_md != NULL
1638 && key->sha3_256_md != NULL
1639 && key->sha3_512_md != NULL)
1640 return key;
1641
1642 ossl_ml_kem_key_free(key);
1643 ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1644 "missing SHA3 digest algorithms while creating %s key",
1645 vinfo->algorithm_name);
1646 return NULL;
1647 }
1648
1649 ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1650 {
1651 int ok = 0;
1652 ML_KEM_KEY *ret;
1653
1654 /*
1655 * Partially decoded keys, not yet imported or loaded, should never be
1656 * duplicated.
1657 */
1658 if (ossl_ml_kem_decoded_key(key))
1659 return NULL;
1660
1661 if (key == NULL
1662 || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1663 return NULL;
1664 ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1665 ret->s = ret->m = ret->t = NULL;
1666
1667 /* Clear selection bits we can't fulfill */
1668 if (!ossl_ml_kem_have_pubkey(key))
1669 selection = 0;
1670 else if (!ossl_ml_kem_have_prvkey(key))
1671 selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1672
1673 switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1674 case 0:
1675 ok = 1;
1676 break;
1677 case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1678 ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1679 ret->rho = ret->seedbuf;
1680 ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1681 break;
1682 case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1683 ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1684 /* Duplicated keys retain |d|, if available */
1685 if (key->d != NULL)
1686 ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1687 break;
1688 }
1689
1690 if (!ok) {
1691 OPENSSL_free(ret);
1692 return NULL;
1693 }
1694
1695 EVP_MD_up_ref(ret->shake128_md);
1696 EVP_MD_up_ref(ret->shake256_md);
1697 EVP_MD_up_ref(ret->sha3_256_md);
1698 EVP_MD_up_ref(ret->sha3_512_md);
1699
1700 return ret;
1701 }
1702
1703 void ossl_ml_kem_key_free(ML_KEM_KEY *key)
1704 {
1705 if (key == NULL)
1706 return;
1707
1708 EVP_MD_free(key->shake128_md);
1709 EVP_MD_free(key->shake256_md);
1710 EVP_MD_free(key->sha3_256_md);
1711 EVP_MD_free(key->sha3_512_md);
1712
1713 if (ossl_ml_kem_decoded_key(key)) {
1714 OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
1715 if (ossl_ml_kem_have_dkenc(key)) {
1716 OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
1717 OPENSSL_free(key->encoded_dk);
1718 }
1719 }
1720 ossl_ml_kem_key_reset(key);
1721 OPENSSL_free(key);
1722 }
1723
1724 /* Serialise the public component of an ML-KEM key */
1725 int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1726 const ML_KEM_KEY *key)
1727 {
1728 if (!ossl_ml_kem_have_pubkey(key)
1729 || len != key->vinfo->pubkey_bytes)
1730 return 0;
1731 encode_pubkey(out, key);
1732 return 1;
1733 }
1734
1735 /* Serialise an ML-KEM private key */
1736 int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1737 const ML_KEM_KEY *key)
1738 {
1739 if (!ossl_ml_kem_have_prvkey(key)
1740 || len != key->vinfo->prvkey_bytes)
1741 return 0;
1742 encode_prvkey(out, key);
1743 return 1;
1744 }
1745
1746 int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1747 const ML_KEM_KEY *key)
1748 {
1749 if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1750 return 0;
1751 /*
1752 * Both in the seed buffer, and in the allocated storage, the |d| component
1753 * of the seed is stored last, so we must copy each separately.
1754 */
1755 memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1756 out += ML_KEM_RANDOM_BYTES;
1757 memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1758 return 1;
1759 }
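
/*-
 * The resulting external layout is thus the FIPS 203 (d, z) order, while the
 * internal storage keeps |z| first with |d| appended, hence the two separate
 * copies above:
 *
 *     encoded seed:     d || z   (ML_KEM_RANDOM_BYTES each)
 *     internal storage: z || d   (d present only for seed-born keys)
 */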
1760
1761 /*
1762 * Stash the seed without (yet) performing a keygen.  This is used during
1763 * decoding, to avoid an extra keygen if we're only going to export the key
1764 * again to load it into another provider.
1765 */
1766 ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1767 {
1768 if (key == NULL
1769 || ossl_ml_kem_have_pubkey(key)
1770 || ossl_ml_kem_have_seed(key)
1771 || seedlen != ML_KEM_SEED_BYTES)
1772 return NULL;
1773 /*
1774 * With no public or private key material on hand, we can use the seed
1775 * buffer for |z| and |d|, in that order.
1776 */
1777 key->z = key->seedbuf;
1778 key->d = key->z + ML_KEM_RANDOM_BYTES;
1779 memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1780 seed += ML_KEM_RANDOM_BYTES;
1781 memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1782 return key;
1783 }
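
/*-
 * Illustrative decoder-side sketch, not part of this file: stash the seed at
 * decode time, and expand it (via |ossl_ml_kem_genkey|, which prefers a
 * stashed seed over fresh RNG output) only once the key is actually needed.
 * Error handling abbreviated:
 *
 *     if (ossl_ml_kem_set_seed(seed, ML_KEM_SEED_BYTES, key) == NULL)
 *         goto err;
 *     ...
 *     if (!ossl_ml_kem_genkey(NULL, 0, key))
 *         goto err;
 */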
1784
1785 /* Parse input as a public key */
1786 int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1787 {
1788 EVP_MD_CTX *mdctx = NULL;
1789 const ML_KEM_VINFO *vinfo;
1790 int ret = 0;
1791
1792 /* Keys with key material are immutable */
1793 if (key == NULL
1794 || ossl_ml_kem_have_pubkey(key)
1795 || ossl_ml_kem_have_dkenc(key))
1796 return 0;
1797 vinfo = key->vinfo;
1798
1799 if (len != vinfo->pubkey_bytes
1800 || (mdctx = EVP_MD_CTX_new()) == NULL)
1801 return 0;
1802
1803 if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
1804 ret = parse_pubkey(in, mdctx, key);
1805
1806 if (!ret)
1807 ossl_ml_kem_key_reset(key);
1808 EVP_MD_CTX_free(mdctx);
1809 return ret;
1810 }
1811
1812 /* Parse input as a new private key */
1813 int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1814 ML_KEM_KEY *key)
1815 {
1816 EVP_MD_CTX *mdctx = NULL;
1817 const ML_KEM_VINFO *vinfo;
1818 int ret = 0;
1819
1820 /* Keys with key material are immutable */
1821 if (key == NULL
1822 || ossl_ml_kem_have_pubkey(key)
1823 || ossl_ml_kem_have_dkenc(key))
1824 return 0;
1825 vinfo = key->vinfo;
1826
1827 if (len != vinfo->prvkey_bytes
1828 || (mdctx = EVP_MD_CTX_new()) == NULL)
1829 return 0;
1830
1831 if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1832 ret = parse_prvkey(in, mdctx, key);
1833
1834 if (!ret)
1835 ossl_ml_kem_key_reset(key);
1836 EVP_MD_CTX_free(mdctx);
1837 return ret;
1838 }
1839
1840 /*
1841 * Generate a new keypair, either from the saved seed (when non-null), or from
1842 * the RNG.
1843 */
1844 int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1845 {
1846 uint8_t seed[ML_KEM_SEED_BYTES];
1847 EVP_MD_CTX *mdctx = NULL;
1848 const ML_KEM_VINFO *vinfo;
1849 int ret = 0;
1850
1851 if (key == NULL
1852 || ossl_ml_kem_have_pubkey(key)
1853 || ossl_ml_kem_have_dkenc(key))
1854 return 0;
1855 vinfo = key->vinfo;
1856
1857 if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1858 return 0;
1859
1860 if (ossl_ml_kem_have_seed(key)) {
1861 if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1862 return 0;
1863 key->d = key->z = NULL;
1864 } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1865 key->vinfo->secbits) <= 0) {
1866 return 0;
1867 }
1868
1869 if ((mdctx = EVP_MD_CTX_new()) == NULL)
1870 return 0;
1871
1872 /*
1873 * Data derived from (d, z) is secret by default and, to avoid side-channel
1874 * leaks, should not influence control flow.
1875 */
1876 CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1877
1878 if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1879 ret = genkey(seed, mdctx, pubenc, key);
1880 OPENSSL_cleanse(seed, sizeof(seed));
1881
1882 /* Declassify secret inputs and derived outputs before returning control */
1883 CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1884
1885 EVP_MD_CTX_free(mdctx);
1886 if (!ret) {
1887 ossl_ml_kem_key_reset(key);
1888 return 0;
1889 }
1890
1891 /* The public components are already declassified; declassify the private ones too */
1892 CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1893 CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1894 return 1;
1895 }
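
/*-
 * Illustrative sketch, not part of this file: generating a fresh ML-KEM-768
 * key pair and capturing the encoded public key in the same call, with
 * |libctx| assumed to be the caller's library context (error handling
 * abbreviated):
 *
 *     ML_KEM_KEY *key = ossl_ml_kem_key_new(libctx, NULL, EVP_PKEY_ML_KEM_768);
 *     uint8_t *pub = (key != NULL)
 *         ? OPENSSL_malloc(key->vinfo->pubkey_bytes) : NULL;
 *     int ok = pub != NULL
 *         && ossl_ml_kem_genkey(pub, key->vinfo->pubkey_bytes, key);
 */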
1896
1897 /*
1898 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1899 * This is the deterministic version with randomness supplied externally.
1900 */
1901 int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1902 uint8_t *shared_secret, size_t slen,
1903 const uint8_t *entropy, size_t elen,
1904 const ML_KEM_KEY *key)
1905 {
1906 const ML_KEM_VINFO *vinfo;
1907 EVP_MD_CTX *mdctx;
1908 int ret = 0;
1909
1910 if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1911 return 0;
1912 vinfo = key->vinfo;
1913
1914 if (ctext == NULL || clen != vinfo->ctext_bytes
1915 || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1916 || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1917 || (mdctx = EVP_MD_CTX_new()) == NULL)
1918 return 0;
1919 /*
1920 * Data derived from the encap entropy is secret by default and, to avoid
1921 * side-channel leaks, should not influence control flow.
1922 */
1923 CONSTTIME_SECRET(entropy, elen);
1924
1925 /*-
1926 * This avoids the need to handle allocation failures for two (max 2KB
1927 * each) vectors that are never retained on return from this function.
1928 * We stack-allocate these.
1929 */
1930 # define case_encap_seed(bits) \
1931 case EVP_PKEY_ML_KEM_##bits: \
1932 { \
1933 scalar tmp[2 * ML_KEM_##bits##_RANK]; \
1934 \
1935 ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
1936 OPENSSL_cleanse((void *)tmp, sizeof(tmp)); \
1937 break; \
1938 }
1939 switch (vinfo->evp_type) {
1940 case_encap_seed(512);
1941 case_encap_seed(768);
1942 case_encap_seed(1024);
1943 }
1944 # undef case_encap_seed
1945
1946 /* Declassify secret inputs and derived outputs before returning control */
1947 CONSTTIME_DECLASSIFY(entropy, elen);
1948 CONSTTIME_DECLASSIFY(ctext, clen);
1949 CONSTTIME_DECLASSIFY(shared_secret, slen);
1950
1951 EVP_MD_CTX_free(mdctx);
1952 return ret;
1953 }
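
/*-
 * For reference, the FIPS 203, Algorithm 17 steps performed by |encap| above,
 * with the caller-supplied entropy as the message m, H(ek) the stored
 * |pkhash| and G the SHA3-512 |hash_g|:
 *
 *     (K, r) = G(m || H(ek))
 *     c      = K-PKE.Encrypt(ek_PKE, m, r)
 *     return (K, c)
 */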
1954
1955 int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1956 uint8_t *shared_secret, size_t slen,
1957 const ML_KEM_KEY *key)
1958 {
1959 uint8_t r[ML_KEM_RANDOM_BYTES];
1960
1961 if (key == NULL)
1962 return 0;
1963
1964 if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1965 key->vinfo->secbits) < 1)
1966 return 0;
1967
1968 return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
1969 r, sizeof(r), key);
1970 }
1971
1972 int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
1973 const uint8_t *ctext, size_t clen,
1974 const ML_KEM_KEY *key)
1975 {
1976 const ML_KEM_VINFO *vinfo;
1977 EVP_MD_CTX *mdctx;
1978 int ret = 0;
1979 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1980 int classify_bytes;
1981 #endif
1982
1983 /* Need a private key here */
1984 if (!ossl_ml_kem_have_prvkey(key))
1985 return 0;
1986 vinfo = key->vinfo;
1987
1988 if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1989 || ctext == NULL || clen != vinfo->ctext_bytes
1990 || (mdctx = EVP_MD_CTX_new()) == NULL) {
1991 (void)RAND_bytes_ex(key->libctx, shared_secret,
1992 ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
1993 return 0;
1994 }
1995 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1996 /*
1997 * Data derived from |s| and |z| is secret by default and, to avoid
1998 * side-channel leaks, should not influence control flow.
1999 */
2000 classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2001 CONSTTIME_SECRET(key->s, classify_bytes);
2002 #endif
2003
2004 /*-
2005 * This avoids the need to handle allocation failures for two (max 2KB
2006 * each) vectors and an encoded ciphertext (max 1568 bytes) that are never
2007 * retained on return from this function.
2008 * We stack-allocate these.
2009 */
2010 # define case_decap(bits) \
2011 case EVP_PKEY_ML_KEM_##bits: \
2012 { \
2013 uint8_t cbuf[CTEXT_BYTES(bits)]; \
2014 scalar tmp[2 * ML_KEM_##bits##_RANK]; \
2015 \
2016 ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2017 OPENSSL_cleanse((void *)tmp, sizeof(tmp)); \
2018 break; \
2019 }
2020 switch (vinfo->evp_type) {
2021 case_decap(512);
2022 case_decap(768);
2023 case_decap(1024);
2024 }
2025
2026 /* Declassify secret inputs and derived outputs before returning control */
2027 CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2028 CONSTTIME_DECLASSIFY(shared_secret, slen);
2029 EVP_MD_CTX_free(mdctx);
2030
2031 return ret;
2032 # undef case_decap
2033 }
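
/*-
 * Illustrative round-trip sketch, not part of this file: with |key| holding
 * ML-KEM private key material, encapsulation followed by decapsulation of
 * the resulting ciphertext yields the same shared secret (error handling
 * abbreviated):
 *
 *     uint8_t ss1[ML_KEM_SHARED_SECRET_BYTES], ss2[ML_KEM_SHARED_SECRET_BYTES];
 *     const ML_KEM_VINFO *v = key->vinfo;
 *     uint8_t *ct = OPENSSL_malloc(v->ctext_bytes);
 *     int ok = ct != NULL
 *         && ossl_ml_kem_encap_rand(ct, v->ctext_bytes, ss1, sizeof(ss1), key)
 *         && ossl_ml_kem_decap(ss2, sizeof(ss2), ct, v->ctext_bytes, key)
 *         && memcmp(ss1, ss2, sizeof(ss1)) == 0;
 */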
2034
2035 int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2036 {
2037 /*
2038 * Comparing the fixed-size public key hashes handles any unexpected
2039 * differences in ML-KEM variant rank (and hence key component structure);
2040 * barring SHA3-256 collisions, equal hashes mean equal public keys.
2041 */
2042 if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2043 return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2044
2045 /*
2046 * No match if exactly one of the public keys is unavailable; otherwise both
2047 * are unavailable, and for now such keys are considered equal.
2048 */
2049 return !(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2));
2050 }
2051