1 /*
2 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #include <openssl/byteorder.h>
11 #include <openssl/err.h>
12 #include <openssl/evp.h>
13 #include <openssl/proverr.h>
14 #include "ml_dsa_hash.h"
15 #include "ml_dsa_key.h"
16 #include "ml_dsa_sign.h"
17 #include "internal/packet.h"
18
/*
 * Number of bytes needed to pack all ML_DSA_NUM_POLY_COEFFICIENTS
 * coefficients of one polynomial at |bits| bits per coefficient.
 */
#define POLY_COEFF_NUM_BYTES(bits) ((bits) * (ML_DSA_NUM_POLY_COEFFICIENTS / 8))
/*
 * Cast mod_sub result in support of left-shifts that create 64-bit values.
 * Without the cast the shift would be performed at mod_sub()'s own return
 * width (NOTE(review): presumably uint32_t, declared outside this file),
 * discarding the high bits.
 */
#define mod_sub_64(a, b) ((uint64_t) mod_sub(a, b))

/*
 * Common signatures of the single-polynomial encoders (which append to a
 * WPACKET) and decoders (which consume from a PACKET). Every implementation
 * returns 1 on success and 0 on failure. The key and signature routines pick
 * the variant required by the parameter set through a pointer of one of
 * these types.
 */
typedef int (ENCODE_FN)(const POLY *s, WPACKET *pkt);
typedef int (DECODE_FN)(POLY *s, PACKET *pkt);

/* Forward declarations of the variants selected via ENCODE_FN / DECODE_FN. */
static ENCODE_FN poly_encode_signed_2;
static ENCODE_FN poly_encode_signed_4;
static ENCODE_FN poly_encode_signed_two_to_power_17;
static ENCODE_FN poly_encode_signed_two_to_power_19;
static DECODE_FN poly_decode_signed_2;
static DECODE_FN poly_decode_signed_4;
static DECODE_FN poly_decode_signed_two_to_power_17;
static DECODE_FN poly_decode_signed_two_to_power_19;
34
35 /* Bit packing Algorithms */
36
37 /*
38 * Encodes a polynomial into a byte string, assuming that all coefficients are
39 * in the range 0..15 (4 bits).
40 *
41 * See FIPS 204, Algorithm 16, SimpleBitPack(w, b) where b = 4 bits
42 *
43 * i.e. Use 4 bits from each coefficient and pack them into bytes
44 * So every 2 coefficients fit into 1 byte.
45 *
46 * This is used to encode w1 when signing with ML-DSA-65 and ML-DSA-87
47 *
48 * @param p A polynomial with coefficients all in the range (0..15)
49 * @param pkt A packet object to write 128 bytes to.
50 *
51 * @returns 1 on success, or 0 on error.
52 */
poly_encode_4_bits(const POLY * p,WPACKET * pkt)53 static int poly_encode_4_bits(const POLY *p, WPACKET *pkt)
54 {
55 uint8_t *out;
56 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
57
58 if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(4), &out))
59 return 0;
60
61 do {
62 uint32_t z0 = *in++;
63 uint32_t z1 = *in++;
64
65 *out++ = z0 | (z1 << 4);
66 } while (in < end);
67 return 1;
68 }
69
70 /*
71 * Encodes a polynomial into a byte string, assuming that all coefficients are
72 * in the range 0..43 (6 bits).
73 *
74 * See FIPS 204, Algorithm 16, SimpleBitPack(w, b) where b = 43
75 *
76 * i.e. Use 6 bits from each coefficient and pack them into bytes
77 * So every 4 coefficients fit into 3 bytes.
78 *
79 * |c0||c1||c2||c3|
80 * | /| /\ /
81 * |6 2|4 4|2 6|
82 *
83 * This is used to encode w1 when signing with ML-DSA-44
84 *
85 * @param p A polynomial with coefficients all in the range (0..43)
86 * @param pkt A packet object to write 96 bytes to.
87 *
88 * @returns 1 on success, or 0 on error.
89 */
poly_encode_6_bits(const POLY * p,WPACKET * pkt)90 static int poly_encode_6_bits(const POLY *p, WPACKET *pkt)
91 {
92 uint8_t *out;
93 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
94
95 if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(6), &out))
96 return 0;
97
98 do {
99 uint32_t c0 = *in++;
100 uint32_t c1 = *in++;
101 uint32_t c2 = *in++;
102 uint32_t c3 = *in++;
103
104 *out++ = c0 | (c1 << 6);
105 *out++ = (c1 >> 2) | (c2 << 4);
106 *out++ = (c2 >> 4) | (c3 << 2);
107 } while (in < end);
108 return 1;
109 }
110
111 /*
112 * Encodes a polynomial into a byte string, assuming that all coefficients are
113 * unsigned 10 bit values.
114 *
115 * See FIPS 204, Algorithm 16, SimpleBitPack(w, b) where b = 10 bits
116 *
117 * i.e. Use 10 bits from each coefficient and pack them into bytes
118 * So every 4 coefficients (c0..c3) fit into 5 bytes.
119 * |c0||c1||c2||c3|
120 * |\ |\ |\ |\
121 * |8|2 6|4 4|6 2|8|
122 *
123 * This is used to save t1 (the high part of public key polynomial t)
124 *
125 * @param p A polynomial with coefficients all in the range (0..1023)
126 * @param pkt A packet object to write 320 bytes to.
127 *
128 * @returns 1 on success, or 0 on error.
129 */
poly_encode_10_bits(const POLY * p,WPACKET * pkt)130 static int poly_encode_10_bits(const POLY *p, WPACKET *pkt)
131 {
132 uint8_t *out;
133 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
134
135 if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(10), &out))
136 return 0;
137
138 do {
139 uint32_t c0 = *in++;
140 uint32_t c1 = *in++;
141 uint32_t c2 = *in++;
142 uint32_t c3 = *in++;
143
144 *out++ = (uint8_t)c0;
145 *out++ = (uint8_t)((c0 >> 8) | (c1 << 2));
146 *out++ = (uint8_t)((c1 >> 6) | (c2 << 4));
147 *out++ = (uint8_t)((c2 >> 4) | (c3 << 6));
148 *out++ = (uint8_t)(c3 >> 2);
149 } while (in < end);
150 return 1;
151 }
152
153 /*
154 * @brief Reverses the procedure of poly_encode_10_bits().
155 * See FIPS 204, Algorithm 18, SimpleBitUnpack(v, b) where b = 10.
156 *
157 * @param p A polynomial to write coefficients to.
158 * @param pkt A packet object to read 320 bytes from.
159 *
160 * @returns 1 on success, or 0 on error.
161 */
poly_decode_10_bits(POLY * p,PACKET * pkt)162 static int poly_decode_10_bits(POLY *p, PACKET *pkt)
163 {
164 const uint8_t *in = NULL;
165 uint32_t v, w, mask = 0x3ff; /* 10 bits */
166 uint32_t *out = p->coeff, *end = out + ML_DSA_NUM_POLY_COEFFICIENTS;
167
168 do {
169 if (!PACKET_get_bytes(pkt, &in, 5))
170 return 0;
171
172 in = OPENSSL_load_u32_le(&v, in);
173 w = *in;
174
175 *out++ = v & mask;
176 *out++ = (v >> 10) & mask;
177 *out++ = (v >> 20) & mask;
178 *out++ = (v >> 30) | (w << 2);
179 } while (out < end);
180 return 1;
181 }
182
183 /*
184 * @brief Encodes a polynomial into a byte string, assuming that all
185 * coefficients are in the range -4..4.
186 * See FIPS 204, Algorithm 17, BitPack(w, a, b). (a = 4, b = 4)
187 *
188 * It uses a nibble from each coefficient and packs them into bytes
189 * So every 2 coefficients fit into 1 byte.
190 *
191 * This is used to encode the private key polynomial elements of s1 and s2
192 * for ML-DSA-65 (i.e. eta = 4)
193 *
194 * @param p An array of 256 coefficients all in the range -4..4
195 * @param pkt A packet to write 128 bytes of encoded polynomial coefficients to.
196 *
197 * @returns 1 on success, or 0 on error.
198 */
poly_encode_signed_4(const POLY * p,WPACKET * pkt)199 static int poly_encode_signed_4(const POLY *p, WPACKET *pkt)
200 {
201 uint8_t *out;
202 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
203
204 if (!WPACKET_allocate_bytes(pkt, 32 * 4, &out))
205 return 0;
206
207 do {
208 uint32_t z = mod_sub(4, *in++);
209
210 *out++ = z | (mod_sub(4, *in++) << 4);
211 } while (in < end);
212 return 1;
213 }
214
215 /*
216 * @brief Reverses the procedure of poly_encode_signed_4().
217 * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = b = 4.
218 *
219 * @param p A polynomial to write coefficients to.
220 * @param pkt A packet object to read 128 bytes from.
221 *
222 * @returns 1 on success, or 0 on error. An error will occur if any of the
223 * coefficients are not in the correct range.
224 */
poly_decode_signed_4(POLY * p,PACKET * pkt)225 static int poly_decode_signed_4(POLY *p, PACKET *pkt)
226 {
227 int i, ret = 0;
228 uint32_t v, *out = p->coeff;
229 const uint8_t *in;
230 uint32_t msbs, mask;
231
232 for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
233 if (!PACKET_get_bytes(pkt, &in, 4))
234 goto err;
235 in = OPENSSL_load_u32_le(&v, in);
236
237 /*
238 * None of the nibbles may be >= 9. So if the MSB of any nibble is set,
239 * none of the other bits may be set. First, select all the MSBs.
240 */
241 msbs = v & 0x88888888u;
242 /* For each nibble where the MSB is set, form a mask of all the other bits. */
243 mask = (msbs >> 1) | (msbs >> 2) | (msbs >> 3);
244 /*
245 * A nibble is only out of range in the case of invalid input, in which case
246 * it is okay to leak the value.
247 */
248 if (value_barrier_32((mask & v) != 0))
249 goto err;
250
251 *out++ = mod_sub(4, v & 15);
252 *out++ = mod_sub(4, (v >> 4) & 15);
253 *out++ = mod_sub(4, (v >> 8) & 15);
254 *out++ = mod_sub(4, (v >> 12) & 15);
255 *out++ = mod_sub(4, (v >> 16) & 15);
256 *out++ = mod_sub(4, (v >> 20) & 15);
257 *out++ = mod_sub(4, (v >> 24) & 15);
258 *out++ = mod_sub(4, v >> 28);
259 }
260 ret = 1;
261 err:
262 return ret;
263 }
264
265 /*
266 * @brief Encodes a polynomial into a byte string, assuming that all
267 * coefficients are in the range -2..2.
268 * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = b = 2.
269 *
270 * This is used to encode the private key polynomial elements of s1 and s2
271 * for ML-DSA-44 and ML-DSA-87 (i.e. eta = 2)
272 *
273 * @param pkt A packet to write 128 bytes of encoded polynomial coefficients to.
274 * @param p An array of 256 coefficients all in the range -2..2
275 *
276 * Use 3 bits from each coefficient and pack them into bytes
277 * So every 8 coefficients fit into 3 bytes.
278 * |c0 c1 c2 c3 c4 c5 c6 c7|
279 * | / / | | / / | | /
280 * |3 3 2| 1 3 3 1| 2 3 3|
281 *
282 * @param p An array of 256 coefficients all in the range -2..2
283 * @param pkt A packet to write 64 bytes of encoded polynomial coefficients to.
284 *
285 * @returns 1 on success, or 0 on error.
286 */
poly_encode_signed_2(const POLY * p,WPACKET * pkt)287 static int poly_encode_signed_2(const POLY *p, WPACKET *pkt)
288 {
289 uint8_t *out;
290 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
291
292 if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(3), &out))
293 return 0;
294
295 do {
296 uint32_t z;
297
298 z = mod_sub(2, *in++);
299 z |= mod_sub(2, *in++) << 3;
300 z |= mod_sub(2, *in++) << 6;
301 z |= mod_sub(2, *in++) << 9;
302 z |= mod_sub(2, *in++) << 12;
303 z |= mod_sub(2, *in++) << 15;
304 z |= mod_sub(2, *in++) << 18;
305 z |= mod_sub(2, *in++) << 21;
306
307 out = OPENSSL_store_u16_le(out, (uint16_t) z);
308 *out++ = (uint8_t) (z >> 16);
309 } while (in < end);
310 return 1;
311 }
312
313 /*
314 * @brief Reverses the procedure of poly_encode_signed_2().
315 * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = b = 2.
316 *
317 * @param p A polynomial to write coefficients to.
318 * @param pkt A packet object to read 64 encoded bytes from.
319 *
320 * @returns 1 on success, or 0 on error. An error will occur if any of the
321 * coefficients are not in the correct range.
322 */
poly_decode_signed_2(POLY * p,PACKET * pkt)323 static int poly_decode_signed_2(POLY *p, PACKET *pkt)
324 {
325 int i, ret = 0;
326 uint32_t u = 0, v = 0, *out = p->coeff;
327 uint32_t msbs, mask;
328 const uint8_t *in;
329
330 for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
331 if (!PACKET_get_bytes(pkt, &in, 3))
332 goto err;
333 memcpy(&u, in, 3);
334 OPENSSL_load_u32_le(&v, (uint8_t *)&u);
335
336 /*
337 * Each octal value (3 bits) must be <= 4, So if the MSB is set then the
338 * bottom 2 bits must not be set.
339 * First, select all the MSBs (Use octal representation for the mask)
340 */
341 msbs = v & 044444444;
342 /* For each octal value where the MSB is set, form a mask of the 2 other bits. */
343 mask = (msbs >> 1) | (msbs >> 2);
344 /*
345 * A nibble is only out of range in the case of invalid input, in which
346 * case it is okay to leak the value.
347 */
348 if (value_barrier_32((mask & v) != 0))
349 goto err;
350
351 *out++ = mod_sub(2, v & 7);
352 *out++ = mod_sub(2, (v >> 3) & 7);
353 *out++ = mod_sub(2, (v >> 6) & 7);
354 *out++ = mod_sub(2, (v >> 9) & 7);
355 *out++ = mod_sub(2, (v >> 12) & 7);
356 *out++ = mod_sub(2, (v >> 15) & 7);
357 *out++ = mod_sub(2, (v >> 18) & 7);
358 *out++ = mod_sub(2, (v >> 21) & 7);
359 }
360 ret = 1;
361 err:
362 return ret;
363 }
364
365 /*
366 * @brief Encodes a polynomial into a byte string, assuming that all
367 * coefficients are in the range (-2^12 + 1)..2^12.
368 * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^12 - 1, b = 2^12.
369 *
370 * This is used to encode the LSB of the public key polynomial elements of t0
371 * (which are encoded as part of the encoded private key).
372 *
373 * Use 13 bits from each coefficient and pack them into bytes
374 *
375 * The code below packs them into 2 64 bits blocks by doing..
376 * z0 z1 z2 z3 z4 z5 z6 z7 0
377 * | | | | / \ | | | |
378 * |13 13 13 13 12 |1 13 13 13 24
379 *
380 * @param p An array of 256 coefficients all in the range -2^12+1..2^12
381 * @param pkt A packet to write 416 (13 * 256 / 8) bytes of encoded polynomial
382 * coefficients to.
383 *
384 * @returns 1 on success, or 0 on error.
385 */
poly_encode_signed_two_to_power_12(const POLY * p,WPACKET * pkt)386 static int poly_encode_signed_two_to_power_12(const POLY *p, WPACKET *pkt)
387 {
388 static const uint32_t range = 1u << 12;
389 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
390
391 do {
392 uint8_t *out;
393 uint64_t a1, a2;
394
395 if (!WPACKET_allocate_bytes(pkt, 13, &out))
396 return 0;
397
398 a1 = mod_sub_64(range, *in++);
399 a1 |= mod_sub_64(range, *in++) << 13;
400 a1 |= mod_sub_64(range, *in++) << 26;
401 a1 |= mod_sub_64(range, *in++) << 39;
402 a1 |= (a2 = mod_sub_64(range, *in++)) << 52;
403 a2 = (a2 >> 12) | (mod_sub_64(range, *in++) << 1);
404 a2 |= mod_sub_64(range, *in++) << 14;
405 a2 |= mod_sub_64(range, *in++) << 27;
406
407 out = OPENSSL_store_u64_le(out, a1);
408 out = OPENSSL_store_u32_le(out, (uint32_t) a2);
409 *out = (uint8_t) (a2 >> 32);
410 } while (in < end);
411 return 1;
412 }
413
414 /*
415 * @brief Reverses the procedure of poly_encode_signed_two_to_power_12().
416 * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = 2^12 - 1, b = 2^12.
417 *
418 * @param p A polynomial to write coefficients to.
419 * @param pkt A packet object to read 416 encoded bytes from.
420 *
421 * @returns 1 on success, or 0 on error.
422 */
poly_decode_signed_two_to_power_12(POLY * p,PACKET * pkt)423 static int poly_decode_signed_two_to_power_12(POLY *p, PACKET *pkt)
424 {
425 int i, ret = 0;
426 uint32_t *out = p->coeff;
427 const uint8_t *in;
428 static const uint32_t range = 1u << 12;
429 static const uint32_t mask_13_bits = (1u << 13) - 1;
430
431 for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
432 uint64_t a1;
433 uint32_t a2, b13;
434
435 if (!PACKET_get_bytes(pkt, &in, 13))
436 goto err;
437 in = OPENSSL_load_u64_le(&a1, in);
438 in = OPENSSL_load_u32_le(&a2, in);
439 b13 = (uint32_t) *in;
440
441 *out++ = mod_sub(range, a1 & mask_13_bits);
442 *out++ = mod_sub(range, (a1 >> 13) & mask_13_bits);
443 *out++ = mod_sub(range, (a1 >> 26) & mask_13_bits);
444 *out++ = mod_sub(range, (a1 >> 39) & mask_13_bits);
445 *out++ = mod_sub(range, (a1 >> 52) | ((a2 << 12) & mask_13_bits));
446 *out++ = mod_sub(range, (a2 >> 1) & mask_13_bits);
447 *out++ = mod_sub(range, (a2 >> 14) & mask_13_bits);
448 *out++ = mod_sub(range, (a2 >> 27) | (b13 << 5));
449 }
450 ret = 1;
451 err:
452 return ret;
453 }
454
455 /*
456 * @brief Encodes a polynomial into a byte string, assuming that all
457 * coefficients are in the range (-2^19 + 1)..2^19.
458 * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^19 - 1, b = 2^19.
459 *
460 * This is used to encode signatures for ML-DSA-65 & ML-DSA-87 (gamma1 = 2^19)
461 *
462 * Use 20 bits from each coefficient and pack them into bytes
463 *
464 * The code below packs every 4 (20 bit) coefficients into 10 bytes
465 * z0 z1 z2 z3
466 * | |\ | | \
467 * |20 12|8 20 4|16
468 *
469 * @param p An array of 256 coefficients all in the range -2^19+1..2^19
470 * @param pkt A packet to write 640 (20 * 256 / 8) bytes of encoded polynomial
471 * coefficients to.
472 *
473 * @returns 1 on success, or 0 on error.
474 */
poly_encode_signed_two_to_power_19(const POLY * p,WPACKET * pkt)475 static int poly_encode_signed_two_to_power_19(const POLY *p, WPACKET *pkt)
476 {
477 static const uint32_t range = 1u << 19;
478 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
479
480 do {
481 uint32_t z0, z1, z2;
482 uint8_t *out;
483
484 if (!WPACKET_allocate_bytes(pkt, 10, &out))
485 return 0;
486
487 z0 = mod_sub(range, *in++);
488 z0 |= (z1 = mod_sub(range, *in++)) << 20;
489 z1 = (z1 >> 12) | (mod_sub(range, *in++) << 8);
490 z1 |= (z2 = mod_sub(range, *in++)) << 28;
491
492 out = OPENSSL_store_u32_le(out, z0);
493 out = OPENSSL_store_u32_le(out, z1);
494 out = OPENSSL_store_u16_le(out, (uint16_t) (z2 >> 4));
495 } while (in < end);
496 return 1;
497 }
498
499 /*
500 * @brief Reverses the procedure of poly_encode_signed_two_to_power_19().
501 * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = 2^19 - 1, b = 2^19.
502 *
503 * @param p A polynomial to write coefficients to.
504 * @param pkt A packet object to read 640 encoded bytes from.
505 *
506 * @returns 1 on success, or 0 on error.
507 */
poly_decode_signed_two_to_power_19(POLY * p,PACKET * pkt)508 static int poly_decode_signed_two_to_power_19(POLY *p, PACKET *pkt)
509 {
510 int i, ret = 0;
511 uint32_t *out = p->coeff;
512 const uint8_t *in;
513 static const uint32_t range = 1u << 19;
514 static const uint32_t mask_20_bits = (1u << 20) - 1;
515
516 for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 4); i++) {
517 uint32_t a1, a2;
518 uint16_t a3;
519
520 if (!PACKET_get_bytes(pkt, &in, 10))
521 goto err;
522 in = OPENSSL_load_u32_le(&a1, in);
523 in = OPENSSL_load_u32_le(&a2, in);
524 in = OPENSSL_load_u16_le(&a3, in);
525
526 *out++ = mod_sub(range, a1 & mask_20_bits);
527 *out++ = mod_sub(range, (a1 >> 20) | ((a2 & 0xFF) << 12));
528 *out++ = mod_sub(range, (a2 >> 8) & mask_20_bits);
529 *out++ = mod_sub(range, (a2 >> 28) | (a3 << 4));
530 }
531 ret = 1;
532 err:
533 return ret;
534 }
535
536 /*
537 * @brief Encodes a polynomial into a byte string, assuming that all
538 * coefficients are in the range (-2^17 + 1)..2^17.
539 * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^17 - 1, b = 2^17.
540 *
541 * This is used to encode signatures for ML-DSA-44 (where gamma1 = 2^17)
542 *
543 * Use 18 bits from each coefficient and pack them into bytes
544 *
545 * The code below packs every 4 (18 bit) coefficients into 9 bytes
546 * z0 z1 z2 z3
547 * | |\ | | \
548 * |18 14|4 18 10| 8
549 *
550 * @param p An array of 256 coefficients all in the range -2^17+1..2^17
551 * @param pkt A packet to write 576 (18 * 256 / 8) bytes of encoded polynomial
552 * coefficients to.
553 *
554 * @returns 1 on success, or 0 on error.
555 */
poly_encode_signed_two_to_power_17(const POLY * p,WPACKET * pkt)556 static int poly_encode_signed_two_to_power_17(const POLY *p, WPACKET *pkt)
557 {
558 static const uint32_t range = 1u << 17;
559 const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
560
561 do {
562 uint8_t *out;
563 uint32_t z0, z1, z2;
564
565 if (!WPACKET_allocate_bytes(pkt, 9, &out))
566 return 0;
567
568 z0 = mod_sub(range, *in++);
569 z0 |= (z1 = mod_sub(range, *in++)) << 18;
570 z1 = (z1 >> 14) | (mod_sub(range, *in++) << 4);
571 z1 |= (z2 = mod_sub(range, *in++)) << 22;
572
573 out = OPENSSL_store_u32_le(out, z0);
574 out = OPENSSL_store_u32_le(out, z1);
575 *out = z2 >> 10;
576 } while (in < end);
577 return 1;
578 }
579
580 /*
581 * @brief Reverses the procedure of poly_encode_signed_two_to_power_17().
582 * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = 2^17 - 1, b = 2^17.
583 *
584 * @param p A polynomial to write coefficients to.
585 * @param pkt A packet object to read 576 encoded bytes from.
586 *
587 * @returns 1 on success, or 0 on error.
588 */
poly_decode_signed_two_to_power_17(POLY * p,PACKET * pkt)589 static int poly_decode_signed_two_to_power_17(POLY *p, PACKET *pkt)
590 {
591 uint32_t *out = p->coeff;
592 const uint32_t *end = out + ML_DSA_NUM_POLY_COEFFICIENTS;
593 const uint8_t *in;
594 static const uint32_t range = 1u << 17;
595 static const uint32_t mask_18_bits = (1u << 18) - 1;
596
597 do {
598 uint32_t a1, a2, a3;
599
600 if (!PACKET_get_bytes(pkt, &in, 9))
601 return 0;
602 in = OPENSSL_load_u32_le(&a1, in);
603 in = OPENSSL_load_u32_le(&a2, in);
604 a3 = (uint32_t) *in;
605
606 *out++ = mod_sub(range, a1 & mask_18_bits);
607 *out++ = mod_sub(range, (a1 >> 18) | ((a2 & 0xF) << 14));
608 *out++ = mod_sub(range, (a2 >> 4) & mask_18_bits);
609 *out++ = mod_sub(range, (a2 >> 22) | (a3 << 10));
610 } while (out < end);
611 return 1;
612 }
613
614 /*
615 * @brief Encode the public key as an array of bytes.
616 * See FIPS 204, Algorithm 22, pkEncode().
617 *
618 * @param key A key object containing public key values. The encoded public
619 * key data is stored in this key.
620 * @returns 1 if the public key was encoded successfully or 0 otherwise.
621 */
ossl_ml_dsa_pk_encode(ML_DSA_KEY * key)622 int ossl_ml_dsa_pk_encode(ML_DSA_KEY *key)
623 {
624 int ret = 0;
625 size_t i, written = 0;
626 const POLY *t1 = key->t1.poly;
627 size_t t1_len = key->t1.num_poly;
628 size_t enc_len = key->params->pk_len;
629 uint8_t *enc = OPENSSL_malloc(enc_len);
630 WPACKET pkt;
631
632 if (enc == NULL)
633 return 0;
634
635 if (!WPACKET_init_static_len(&pkt, enc, enc_len, 0)
636 || !WPACKET_memcpy(&pkt, key->rho, sizeof(key->rho)))
637 goto err;
638 for (i = 0; i < t1_len; i++)
639 if (!poly_encode_10_bits(t1 + i, &pkt))
640 goto err;
641 if (!WPACKET_get_total_written(&pkt, &written)
642 || written != enc_len)
643 goto err;
644 OPENSSL_free(key->pub_encoding);
645 key->pub_encoding = enc;
646 ret = 1;
647 err:
648 WPACKET_finish(&pkt);
649 if (ret == 0)
650 OPENSSL_free(enc);
651 return ret;
652 }
653
654 /*
655 * @brief The reverse of ossl_ml_dsa_pk_encode().
656 * See FIPS 204, Algorithm 23, pkDecode().
657 *
658 * @param in An encoded public key.
659 * @param in_len The size of |in|
660 * @param key A key object to store the decoded public key into.
661 *
662 * @returns 1 if the public key was decoded successfully or 0 otherwise.
663 */
ossl_ml_dsa_pk_decode(ML_DSA_KEY * key,const uint8_t * in,size_t in_len)664 int ossl_ml_dsa_pk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len)
665 {
666 int ret = 0;
667 size_t i;
668 PACKET pkt;
669 EVP_MD_CTX *ctx;
670
671 if (key->priv_encoding != NULL || key->pub_encoding != NULL)
672 return 0; /* Do not allow key mutation */
673 if (in_len != key->params->pk_len)
674 return 0;
675
676 if (!ossl_ml_dsa_key_pub_alloc(key))
677 return 0;
678 ctx = EVP_MD_CTX_new();
679 if (ctx == NULL)
680 goto err;
681 if (!PACKET_buf_init(&pkt, in, in_len)
682 || !PACKET_copy_bytes(&pkt, key->rho, sizeof(key->rho)))
683 goto err;
684 for (i = 0; i < key->t1.num_poly; i++)
685 if (!poly_decode_10_bits(key->t1.poly + i, &pkt))
686 goto err;
687
688 /* cache the hash of the encoded public key */
689 if (!shake_xof(ctx, key->shake256_md, in, in_len, key->tr, sizeof(key->tr)))
690 goto err;
691
692 key->pub_encoding = OPENSSL_memdup(in, in_len);
693 ret = (key->pub_encoding != NULL);
694 err:
695 EVP_MD_CTX_free(ctx);
696 return ret;
697 }
698
699 /*
700 * @brief Encode the private key as an array of bytes.
701 * See FIPS 204, Algorithm 24, skEncode().
702 *
703 * @param key A key object containing private key values. The encoded private
704 * key data is stored in this key.
705 * @returns 1 if the private key was encoded successfully or 0 otherwise.
706 */
ossl_ml_dsa_sk_encode(ML_DSA_KEY * key)707 int ossl_ml_dsa_sk_encode(ML_DSA_KEY *key)
708 {
709 int ret = 0;
710 const ML_DSA_PARAMS *params = key->params;
711 size_t i, written = 0, k = params->k, l = params->l;
712 ENCODE_FN *encode_fn;
713 size_t enc_len = params->sk_len;
714 const POLY *t0 = key->t0.poly;
715 WPACKET pkt;
716 uint8_t *enc = OPENSSL_malloc(enc_len);
717
718 if (enc == NULL)
719 return 0;
720
721 /* eta is the range of private key coefficients (-eta...eta) */
722 if (params->eta == ML_DSA_ETA_4)
723 encode_fn = poly_encode_signed_4;
724 else
725 encode_fn = poly_encode_signed_2;
726
727 if (!WPACKET_init_static_len(&pkt, enc, enc_len, 0)
728 || !WPACKET_memcpy(&pkt, key->rho, sizeof(key->rho))
729 || !WPACKET_memcpy(&pkt, key->K, sizeof(key->K))
730 || !WPACKET_memcpy(&pkt, key->tr, sizeof(key->tr)))
731 goto err;
732 for (i = 0; i < l; ++i)
733 if (!encode_fn(key->s1.poly + i, &pkt))
734 goto err;
735 for (i = 0; i < k; ++i)
736 if (!encode_fn(key->s2.poly + i, &pkt))
737 goto err;
738 for (i = 0; i < k; ++i)
739 if (!poly_encode_signed_two_to_power_12(t0++, &pkt))
740 goto err;
741 if (!WPACKET_get_total_written(&pkt, &written)
742 || written != enc_len)
743 goto err;
744 OPENSSL_clear_free(key->priv_encoding, enc_len);
745 key->priv_encoding = enc;
746 ret = 1;
747 err:
748 WPACKET_finish(&pkt);
749 if (ret == 0)
750 OPENSSL_clear_free(enc, enc_len);
751 return ret;
752 }
753
754 /*
755 * @brief The reverse of ossl_ml_dsa_sk_encode().
756 * See FIPS 204, Algorithm 24, skDecode().
757 *
758 * @param in An encoded private key.
759 * @param in_len The size of |in|
760 * @param key A key object to store the decoded private key into.
761 *
762 * @returns 1 if the private key was decoded successfully or 0 otherwise.
763 */
ossl_ml_dsa_sk_decode(ML_DSA_KEY * key,const uint8_t * in,size_t in_len)764 int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len)
765 {
766 DECODE_FN *decode_fn;
767 const ML_DSA_PARAMS *params = key->params;
768 size_t i, k = params->k, l = params->l;
769 uint8_t input_tr[ML_DSA_TR_BYTES];
770 PACKET pkt;
771
772 /* When loading from an explicit key, drop the seed. */
773 OPENSSL_clear_free(key->seed, ML_DSA_SEED_BYTES);
774 key->seed = NULL;
775
776 /* Allow the key encoding to be already set to the provided pointer */
777 if ((key->priv_encoding != NULL && key->priv_encoding != in)
778 || key->pub_encoding != NULL)
779 return 0; /* Do not allow key mutation */
780 if (in_len != key->params->sk_len)
781 return 0;
782 if (!ossl_ml_dsa_key_priv_alloc(key))
783 return 0;
784
785 /* eta is the range of private key coefficients (-eta...eta) */
786 if (params->eta == ML_DSA_ETA_4)
787 decode_fn = poly_decode_signed_4;
788 else
789 decode_fn = poly_decode_signed_2;
790
791 if (!PACKET_buf_init(&pkt, in, in_len)
792 || !PACKET_copy_bytes(&pkt, key->rho, sizeof(key->rho))
793 || !PACKET_copy_bytes(&pkt, key->K, sizeof(key->K))
794 || !PACKET_copy_bytes(&pkt, input_tr, sizeof(input_tr)))
795 return 0;
796
797 for (i = 0; i < l; ++i)
798 if (!decode_fn(key->s1.poly + i, &pkt))
799 goto err;
800 for (i = 0; i < k; ++i)
801 if (!decode_fn(key->s2.poly + i, &pkt))
802 goto err;
803 for (i = 0; i < k; ++i)
804 if (!poly_decode_signed_two_to_power_12(key->t0.poly + i, &pkt))
805 goto err;
806 if (PACKET_remaining(&pkt) != 0)
807 goto err;
808 if (key->priv_encoding == NULL
809 && (key->priv_encoding = OPENSSL_memdup(in, in_len)) == NULL)
810 goto err;
811 /*
812 * Computing the public key also computes its hash, which must be equal to
813 * the |tr| value in the private key, else the key was corrupted.
814 */
815 if (!ossl_ml_dsa_key_public_from_private(key)
816 || memcmp(input_tr, key->tr, sizeof(input_tr)) != 0) {
817 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
818 "%s private key does not match its pubkey part",
819 key->params->alg);
820 ossl_ml_dsa_key_reset(key);
821 goto err;
822 }
823
824 return 1;
825 err:
826 return 0;
827 }
828
829 /*
830 * See FIPS 204, Algorithm 20, HintBitPack().
831 * Hint is composed of k polynomials with binary coefficients where only 'omega'
832 * of all the coefficients are set to 1.
833 * This can be encoded as a byte array of 'omega' polynomial coefficient index
834 * positions for the coefficients that are set, followed by
835 * k values of the last coefficient index used in each polynomial.
836 */
hint_bits_encode(const VECTOR * hint,WPACKET * pkt,uint32_t omega)837 static int hint_bits_encode(const VECTOR *hint, WPACKET *pkt, uint32_t omega)
838 {
839 int i, j, k = hint->num_poly;
840 size_t coeff_index = 0;
841 POLY *p = hint->poly;
842 uint8_t *data;
843
844 if (!WPACKET_allocate_bytes(pkt, omega + k, &data))
845 return 0;
846 memset(data, 0, omega + k);
847
848 for (i = 0; i < k; i++, p++) {
849 for (j = 0; j < ML_DSA_NUM_POLY_COEFFICIENTS; j++)
850 if (p->coeff[j] != 0)
851 data[coeff_index++] = j;
852 data[omega + i] = (uint8_t)coeff_index;
853 }
854 return 1;
855 }
856
857 /*
858 * @brief Reverse the process of hint_bits_encode()
859 * See FIPS 204, Algorithm 21, HintBitUnpack()
860 *
861 * @returns 1 if the hints were successfully unpacked, or 0
862 * if 'pkt' is too small or malformed.
863 */
hint_bits_decode(VECTOR * hint,PACKET * pkt,uint32_t omega)864 static int hint_bits_decode(VECTOR *hint, PACKET *pkt, uint32_t omega)
865 {
866 size_t coeff_index = 0, k = hint->num_poly;
867 const uint8_t *in, *limits;
868 POLY *p = hint->poly, *end = p + k;
869
870 if (!PACKET_get_bytes(pkt, &in, omega)
871 || !PACKET_get_bytes(pkt, &limits, k))
872 return 0;
873
874 vector_zero(hint); /* Set all coefficients to zero */
875
876 do {
877 const uint32_t limit = *limits++;
878 int last = -1;
879
880 if (limit < coeff_index || limit > omega)
881 return 0;
882
883 while (coeff_index < limit) {
884 int byte = in[coeff_index++];
885
886 if (last >= 0 && byte <= last)
887 return 0;
888 last = byte;
889 p->coeff[byte] = 1;
890 }
891 } while (++p < end);
892
893 for (; coeff_index < omega; coeff_index++)
894 if (in[coeff_index] != 0)
895 return 0;
896 return 1;
897 }
898
899 /*
900 * @brief Encode a ML_DSA signature as an array of bytes.
901 * See FIPS 204, Algorithm 26, sigEncode().
902 *
903 * @param
904 * @param
905 * @returns 1 if the signature was encoded successfully or 0 otherwise.
906 */
ossl_ml_dsa_sig_encode(const ML_DSA_SIG * sig,const ML_DSA_PARAMS * params,uint8_t * out)907 int ossl_ml_dsa_sig_encode(const ML_DSA_SIG *sig, const ML_DSA_PARAMS *params,
908 uint8_t *out)
909 {
910 int ret = 0;
911 size_t i;
912 ENCODE_FN *encode_fn;
913 WPACKET pkt;
914
915 if (out == NULL)
916 return 0;
917
918 if (params->gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
919 encode_fn = poly_encode_signed_two_to_power_19;
920 else
921 encode_fn = poly_encode_signed_two_to_power_17;
922
923 if (!WPACKET_init_static_len(&pkt, out, params->sig_len, 0)
924 || !WPACKET_memcpy(&pkt, sig->c_tilde, sig->c_tilde_len))
925 goto err;
926
927 for (i = 0; i < sig->z.num_poly; ++i)
928 if (!encode_fn(sig->z.poly + i, &pkt))
929 goto err;
930 if (!hint_bits_encode(&sig->hint, &pkt, params->omega))
931 goto err;
932 ret = 1;
933 err:
934 WPACKET_finish(&pkt);
935 return ret;
936 }
937
938 /*
939 * @param sig is a initialized signature object to decode into.
940 * @param in An encoded signature
941 * @param in_len The size of |in|
942 * @param params contains constants for an ML-DSA algorithm (such as gamma1)
943 * @returns 1 if the signature was successfully decoded or 0 otherwise.
944 */
ossl_ml_dsa_sig_decode(ML_DSA_SIG * sig,const uint8_t * in,size_t in_len,const ML_DSA_PARAMS * params)945 int ossl_ml_dsa_sig_decode(ML_DSA_SIG *sig, const uint8_t *in, size_t in_len,
946 const ML_DSA_PARAMS *params)
947 {
948 int ret = 0;
949 size_t i;
950 DECODE_FN *decode_fn;
951 PACKET pkt;
952
953 if (params->gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
954 decode_fn = poly_decode_signed_two_to_power_19;
955 else
956 decode_fn = poly_decode_signed_two_to_power_17;
957
958 if (!PACKET_buf_init(&pkt, in, in_len)
959 || !PACKET_copy_bytes(&pkt, sig->c_tilde, sig->c_tilde_len))
960 goto err;
961 for (i = 0; i < sig->z.num_poly; ++i)
962 if (!decode_fn(sig->z.poly + i, &pkt))
963 goto err;
964
965 if (!hint_bits_decode(&sig->hint, &pkt, params->omega)
966 || PACKET_remaining(&pkt) != 0)
967 goto err;
968 ret = 1;
969 err:
970 return ret;
971 }
972
ossl_ml_dsa_poly_decode_expand_mask(POLY * out,const uint8_t * in,size_t in_len,uint32_t gamma1)973 int ossl_ml_dsa_poly_decode_expand_mask(POLY *out,
974 const uint8_t *in, size_t in_len,
975 uint32_t gamma1)
976 {
977 PACKET pkt;
978
979 if (!PACKET_buf_init(&pkt, in, in_len))
980 return 0;
981 if (gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
982 return poly_decode_signed_two_to_power_19(out, &pkt);
983 else
984 return poly_decode_signed_two_to_power_17(out, &pkt);
985 }
986
987 /*
988 * @brief Encode a polynomial vector as an array of bytes.
989 * Where the polynomial coefficients have a range of [0..15] or [0..43]
990 * depending on the value of gamma2.
991 *
992 * See FIPS 204, Algorithm 28, w1Encode().
993 *
994 * @param w1 The vector to convert to bytes
995 * @param gamma2 either ML_DSA_GAMMA2_Q_MINUS1_DIV32 or ML_DSA_GAMMA2_Q_MINUS1_DIV88
996 * @returns 1 if the signature was encoded successfully or 0 otherwise.
997 */
ossl_ml_dsa_w1_encode(const VECTOR * w1,uint32_t gamma2,uint8_t * out,size_t out_len)998 int ossl_ml_dsa_w1_encode(const VECTOR *w1, uint32_t gamma2,
999 uint8_t *out, size_t out_len)
1000 {
1001 WPACKET pkt;
1002 ENCODE_FN *encode_fn;
1003 int ret = 0;
1004 size_t i;
1005
1006 if (!WPACKET_init_static_len(&pkt, out, out_len, 0))
1007 return 0;
1008 if (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV32)
1009 encode_fn = poly_encode_4_bits;
1010 else
1011 encode_fn = poly_encode_6_bits;
1012 for (i = 0; i < w1->num_poly; ++i)
1013 if (!encode_fn(w1->poly + i, &pkt))
1014 goto err;
1015 ret = 1;
1016 err:
1017 WPACKET_finish(&pkt);
1018 return ret;
1019 }
1020