xref: /freebsd/crypto/openssl/crypto/ml_dsa/ml_dsa_encoders.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <openssl/byteorder.h>
11 #include <openssl/err.h>
12 #include <openssl/evp.h>
13 #include <openssl/proverr.h>
14 #include "ml_dsa_hash.h"
15 #include "ml_dsa_key.h"
16 #include "ml_dsa_sign.h"
17 #include "internal/packet.h"
18 
19 #define POLY_COEFF_NUM_BYTES(bits)  ((bits) * (ML_DSA_NUM_POLY_COEFFICIENTS / 8))
20 /* Cast mod_sub result in support of left-shifts that create 64-bit values. */
21 #define mod_sub_64(a, b) ((uint64_t) mod_sub(a, b))
22 
23 typedef int (ENCODE_FN)(const POLY *s, WPACKET *pkt);
24 typedef int (DECODE_FN)(POLY *s, PACKET *pkt);
25 
26 static ENCODE_FN poly_encode_signed_2;
27 static ENCODE_FN poly_encode_signed_4;
28 static ENCODE_FN poly_encode_signed_two_to_power_17;
29 static ENCODE_FN poly_encode_signed_two_to_power_19;
30 static DECODE_FN poly_decode_signed_2;
31 static DECODE_FN poly_decode_signed_4;
32 static DECODE_FN poly_decode_signed_two_to_power_17;
33 static DECODE_FN poly_decode_signed_two_to_power_19;
34 
35 /* Bit packing Algorithms */
36 
37 /*
38  * Encodes a polynomial into a byte string, assuming that all coefficients are
39  * in the range 0..15 (4 bits).
40  *
41  * See FIPS 204, Algorithm 16, SimpleBitPack(w, b) where b = 4 bits
42  *
43  * i.e. Use 4 bits from each coefficient and pack them into bytes
44  * So every 2 coefficients fit into 1 byte.
45  *
46  * This is used to encode w1 when signing with ML-DSA-65 and ML-DSA-87
47  *
48  * @param p A polynomial with coefficients all in the range (0..15)
49  * @param pkt A packet object to write 128 bytes to.
50  *
51  * @returns 1 on success, or 0 on error.
52  */
poly_encode_4_bits(const POLY * p,WPACKET * pkt)53 static int poly_encode_4_bits(const POLY *p, WPACKET *pkt)
54 {
55     uint8_t *out;
56     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
57 
58     if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(4), &out))
59         return 0;
60 
61     do {
62         uint32_t z0 = *in++;
63         uint32_t z1 = *in++;
64 
65         *out++ = z0 | (z1 << 4);
66     } while (in < end);
67     return 1;
68 }
69 
70 /*
71  * Encodes a polynomial into a byte string, assuming that all coefficients are
72  * in the range 0..43 (6 bits).
73  *
74  * See FIPS 204, Algorithm 16, SimpleBitPack(w, b) where b = 43
75  *
76  * i.e. Use 6 bits from each coefficient and pack them into bytes
77  * So every 4 coefficients fit into 3 bytes.
78  *
79  *  |c0||c1||c2||c3|
80  *   |  /|  /\  /
81  *  |6 2|4 4|2 6|
82  *
83  * This is used to encode w1 when signing with ML-DSA-44
84  *
85  * @param p A polynomial with coefficients all in the range (0..43)
86  * @param pkt A packet object to write 96 bytes to.
87  *
88  * @returns 1 on success, or 0 on error.
89  */
poly_encode_6_bits(const POLY * p,WPACKET * pkt)90 static int poly_encode_6_bits(const POLY *p, WPACKET *pkt)
91 {
92     uint8_t *out;
93     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
94 
95     if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(6), &out))
96         return 0;
97 
98     do {
99         uint32_t c0 = *in++;
100         uint32_t c1 = *in++;
101         uint32_t c2 = *in++;
102         uint32_t c3 = *in++;
103 
104         *out++ = c0 | (c1 << 6);
105         *out++ = (c1 >> 2) | (c2 << 4);
106         *out++ = (c2 >> 4) | (c3 << 2);
107     } while (in < end);
108     return 1;
109 }
110 
111 /*
112  * Encodes a polynomial into a byte string, assuming that all coefficients are
113  * unsigned 10 bit values.
114  *
115  * See FIPS 204, Algorithm 16, SimpleBitPack(w, b) where b = 10 bits
116  *
117  * i.e. Use 10 bits from each coefficient and pack them into bytes
118  * So every 4 coefficients (c0..c3) fit into 5 bytes.
119  *  |c0||c1||c2||c3|
120  *   |\  |\  |\  |\
121  *  |8|2 6|4 4|6 2|8|
122  *
123  * This is used to save t1 (the high part of public key polynomial t)
124  *
125  * @param p A polynomial with coefficients all in the range (0..1023)
126  * @param pkt A packet object to write 320 bytes to.
127  *
128  * @returns 1 on success, or 0 on error.
129  */
poly_encode_10_bits(const POLY * p,WPACKET * pkt)130 static int poly_encode_10_bits(const POLY *p, WPACKET *pkt)
131 {
132     uint8_t *out;
133     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
134 
135     if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(10), &out))
136         return 0;
137 
138     do {
139         uint32_t c0 = *in++;
140         uint32_t c1 = *in++;
141         uint32_t c2 = *in++;
142         uint32_t c3 = *in++;
143 
144         *out++ = (uint8_t)c0;
145         *out++ = (uint8_t)((c0 >> 8) | (c1 << 2));
146         *out++ = (uint8_t)((c1 >> 6) | (c2 << 4));
147         *out++ = (uint8_t)((c2 >> 4) | (c3 << 6));
148         *out++ = (uint8_t)(c3 >> 2);
149     } while (in < end);
150     return 1;
151 }
152 
153 /*
154  * @brief Reverses the procedure of poly_encode_10_bits().
155  * See FIPS 204, Algorithm 18, SimpleBitUnpack(v, b) where b = 10.
156  *
157  * @param p A polynomial to write coefficients to.
158  * @param pkt A packet object to read 320 bytes from.
159  *
160  * @returns 1 on success, or 0 on error.
161  */
poly_decode_10_bits(POLY * p,PACKET * pkt)162 static int poly_decode_10_bits(POLY *p, PACKET *pkt)
163 {
164     const uint8_t *in = NULL;
165     uint32_t v, w, mask = 0x3ff; /* 10 bits */
166     uint32_t *out = p->coeff, *end = out + ML_DSA_NUM_POLY_COEFFICIENTS;
167 
168     do {
169         if (!PACKET_get_bytes(pkt, &in, 5))
170             return 0;
171 
172         in = OPENSSL_load_u32_le(&v, in);
173         w = *in;
174 
175         *out++ = v & mask;
176         *out++ = (v >> 10) & mask;
177         *out++ = (v >> 20) & mask;
178         *out++ = (v >> 30) | (w << 2);
179     } while (out < end);
180     return 1;
181 }
182 
183 /*
184  * @brief Encodes a polynomial into a byte string, assuming that all
185  * coefficients are in the range -4..4.
186  * See FIPS 204, Algorithm 17, BitPack(w, a, b). (a = 4, b = 4)
187  *
188  * It uses a nibble from each coefficient and packs them into bytes
189  * So every 2 coefficients fit into 1 byte.
190  *
191  * This is used to encode the private key polynomial elements of s1 and s2
192  * for ML-DSA-65 (i.e. eta = 4)
193  *
194  * @param p An array of 256 coefficients all in the range -4..4
195  * @param pkt A packet to write 128 bytes of encoded polynomial coefficients to.
196  *
197  * @returns 1 on success, or 0 on error.
198  */
poly_encode_signed_4(const POLY * p,WPACKET * pkt)199 static int poly_encode_signed_4(const POLY *p, WPACKET *pkt)
200 {
201     uint8_t *out;
202     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
203 
204     if (!WPACKET_allocate_bytes(pkt, 32 * 4, &out))
205         return 0;
206 
207     do {
208         uint32_t z = mod_sub(4, *in++);
209 
210         *out++ = z | (mod_sub(4, *in++) << 4);
211     } while (in < end);
212     return 1;
213 }
214 
215 /*
216  * @brief Reverses the procedure of poly_encode_signed_4().
217  * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = b = 4.
218  *
219  * @param p A polynomial to write coefficients to.
220  * @param pkt A packet object to read 128 bytes from.
221  *
222  * @returns 1 on success, or 0 on error. An error will occur if any of the
223  *          coefficients are not in the correct range.
224  */
poly_decode_signed_4(POLY * p,PACKET * pkt)225 static int poly_decode_signed_4(POLY *p, PACKET *pkt)
226 {
227     int i, ret = 0;
228     uint32_t v, *out = p->coeff;
229     const uint8_t *in;
230     uint32_t msbs, mask;
231 
232     for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
233         if (!PACKET_get_bytes(pkt, &in, 4))
234             goto err;
235         in = OPENSSL_load_u32_le(&v, in);
236 
237         /*
238          * None of the nibbles may be >= 9. So if the MSB of any nibble is set,
239          * none of the other bits may be set. First, select all the MSBs.
240          */
241         msbs = v & 0x88888888u;
242         /* For each nibble where the MSB is set, form a mask of all the other bits. */
243         mask = (msbs >> 1) | (msbs >> 2) | (msbs >> 3);
244         /*
245          * A nibble is only out of range in the case of invalid input, in which case
246          * it is okay to leak the value.
247          */
248         if (value_barrier_32((mask & v) != 0))
249             goto err;
250 
251         *out++ = mod_sub(4, v & 15);
252         *out++ = mod_sub(4, (v >> 4) & 15);
253         *out++ = mod_sub(4, (v >> 8) & 15);
254         *out++ = mod_sub(4, (v >> 12) & 15);
255         *out++ = mod_sub(4, (v >> 16) & 15);
256         *out++ = mod_sub(4, (v >> 20) & 15);
257         *out++ = mod_sub(4, (v >> 24) & 15);
258         *out++ = mod_sub(4, v >> 28);
259     }
260     ret = 1;
261  err:
262     return ret;
263 }
264 
265 /*
266  * @brief Encodes a polynomial into a byte string, assuming that all
267  * coefficients are in the range -2..2.
268  * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = b = 2.
269  *
270  * This is used to encode the private key polynomial elements of s1 and s2
271  * for ML-DSA-44 and ML-DSA-87 (i.e. eta = 2)
272  *
273  * @param pkt A packet to write 128 bytes of encoded polynomial coefficients to.
274  * @param p An array of 256 coefficients all in the range -2..2
275  *
276  * Use 3 bits from each coefficient and pack them into bytes
277  * So every 8 coefficients fit into 3 bytes.
278  *  |c0 c1 c2 c3 c4 c5 c6 c7|
279  *   | /  / | |  / / | |  /
280  *  |3 3 2| 1 3 3 1| 2 3 3|
281  *
282  * @param p An array of 256 coefficients all in the range -2..2
283  * @param pkt A packet to write 64 bytes of encoded polynomial coefficients to.
284  *
285  * @returns 1 on success, or 0 on error.
286  */
poly_encode_signed_2(const POLY * p,WPACKET * pkt)287 static int poly_encode_signed_2(const POLY *p, WPACKET *pkt)
288 {
289     uint8_t *out;
290     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
291 
292     if (!WPACKET_allocate_bytes(pkt, POLY_COEFF_NUM_BYTES(3), &out))
293         return 0;
294 
295     do {
296         uint32_t z;
297 
298         z = mod_sub(2, *in++);
299         z |= mod_sub(2, *in++) << 3;
300         z |= mod_sub(2, *in++) << 6;
301         z |= mod_sub(2, *in++) << 9;
302         z |= mod_sub(2, *in++) << 12;
303         z |= mod_sub(2, *in++) << 15;
304         z |= mod_sub(2, *in++) << 18;
305         z |= mod_sub(2, *in++) << 21;
306 
307         out = OPENSSL_store_u16_le(out, (uint16_t) z);
308         *out++ = (uint8_t) (z >> 16);
309     } while (in < end);
310     return 1;
311 }
312 
313 /*
314  * @brief Reverses the procedure of poly_encode_signed_2().
315  * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = b = 2.
316  *
317  * @param p A polynomial to write coefficients to.
318  * @param pkt A packet object to read 64 encoded bytes from.
319  *
320  * @returns 1 on success, or 0 on error. An error will occur if any of the
321  *          coefficients are not in the correct range.
322  */
poly_decode_signed_2(POLY * p,PACKET * pkt)323 static int poly_decode_signed_2(POLY *p, PACKET *pkt)
324 {
325     int i, ret = 0;
326     uint32_t u = 0, v = 0, *out = p->coeff;
327     uint32_t msbs, mask;
328     const uint8_t *in;
329 
330     for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
331         if (!PACKET_get_bytes(pkt, &in, 3))
332             goto err;
333         memcpy(&u, in, 3);
334         OPENSSL_load_u32_le(&v, (uint8_t *)&u);
335 
336         /*
337          * Each octal value (3 bits) must be <= 4, So if the MSB is set then the
338          * bottom 2 bits must not be set.
339          * First, select all the MSBs (Use octal representation for the mask)
340          */
341         msbs = v & 044444444;
342         /* For each octal value where the MSB is set, form a mask of the 2 other bits. */
343         mask = (msbs >> 1) | (msbs >> 2);
344         /*
345          * A nibble is only out of range in the case of invalid input, in which
346          * case it is okay to leak the value.
347          */
348         if (value_barrier_32((mask & v) != 0))
349             goto err;
350 
351         *out++ = mod_sub(2, v & 7);
352         *out++ = mod_sub(2, (v >> 3) & 7);
353         *out++ = mod_sub(2, (v >> 6) & 7);
354         *out++ = mod_sub(2, (v >> 9) & 7);
355         *out++ = mod_sub(2, (v >> 12) & 7);
356         *out++ = mod_sub(2, (v >> 15) & 7);
357         *out++ = mod_sub(2, (v >> 18) & 7);
358         *out++ = mod_sub(2, (v >> 21) & 7);
359     }
360     ret = 1;
361  err:
362     return ret;
363 }
364 
365 /*
366  * @brief Encodes a polynomial into a byte string, assuming that all
367  * coefficients are in the range (-2^12 + 1)..2^12.
368  * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^12 - 1, b = 2^12.
369  *
370  * This is used to encode the LSB of the public key polynomial elements of t0
371  * (which are encoded as part of the encoded private key).
372  *
373  * Use 13 bits from each coefficient and pack them into bytes
374  *
375  * The code below packs them into 2 64 bits blocks by doing..
376  *  z0 z1 z2 z3  z4  z5 z6  z7 0
377  *  |   |  | |   / \  |  |  |  |
378  * |13 13 13 13 12 |1 13 13 13 24
379  *
380  * @param p An array of 256 coefficients all in the range -2^12+1..2^12
381  * @param pkt A packet to write 416 (13 * 256 / 8) bytes of encoded polynomial
382  *            coefficients to.
383  *
384  * @returns 1 on success, or 0 on error.
385  */
poly_encode_signed_two_to_power_12(const POLY * p,WPACKET * pkt)386 static int poly_encode_signed_two_to_power_12(const POLY *p, WPACKET *pkt)
387 {
388     static const uint32_t range = 1u << 12;
389     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
390 
391     do {
392         uint8_t *out;
393         uint64_t a1, a2;
394 
395         if (!WPACKET_allocate_bytes(pkt, 13, &out))
396             return 0;
397 
398         a1 = mod_sub_64(range, *in++);
399         a1 |= mod_sub_64(range, *in++) << 13;
400         a1 |= mod_sub_64(range, *in++) << 26;
401         a1 |= mod_sub_64(range, *in++) << 39;
402         a1 |= (a2 = mod_sub_64(range, *in++)) << 52;
403         a2 = (a2 >> 12) | (mod_sub_64(range, *in++) << 1);
404         a2 |= mod_sub_64(range, *in++) << 14;
405         a2 |= mod_sub_64(range, *in++) << 27;
406 
407         out = OPENSSL_store_u64_le(out, a1);
408         out = OPENSSL_store_u32_le(out, (uint32_t) a2);
409         *out = (uint8_t) (a2 >> 32);
410     } while (in < end);
411     return 1;
412 }
413 
414 /*
415  * @brief Reverses the procedure of poly_encode_signed_two_to_power_12().
416  * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = 2^12 - 1, b = 2^12.
417  *
418  * @param p A polynomial to write coefficients to.
419  * @param pkt A packet object to read 416 encoded bytes from.
420  *
421  * @returns 1 on success, or 0 on error.
422  */
poly_decode_signed_two_to_power_12(POLY * p,PACKET * pkt)423 static int poly_decode_signed_two_to_power_12(POLY *p, PACKET *pkt)
424 {
425     int i, ret = 0;
426     uint32_t *out = p->coeff;
427     const uint8_t *in;
428     static const uint32_t range = 1u << 12;
429     static const uint32_t mask_13_bits = (1u << 13) - 1;
430 
431     for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
432         uint64_t a1;
433         uint32_t a2, b13;
434 
435         if (!PACKET_get_bytes(pkt, &in, 13))
436             goto err;
437         in = OPENSSL_load_u64_le(&a1, in);
438         in = OPENSSL_load_u32_le(&a2, in);
439         b13 = (uint32_t) *in;
440 
441         *out++ = mod_sub(range, a1 & mask_13_bits);
442         *out++ = mod_sub(range, (a1 >> 13) & mask_13_bits);
443         *out++ = mod_sub(range, (a1 >> 26) & mask_13_bits);
444         *out++ = mod_sub(range, (a1 >> 39) & mask_13_bits);
445         *out++ = mod_sub(range, (a1 >> 52) | ((a2 << 12) & mask_13_bits));
446         *out++ = mod_sub(range, (a2 >> 1) & mask_13_bits);
447         *out++ = mod_sub(range, (a2 >> 14) & mask_13_bits);
448         *out++ = mod_sub(range, (a2 >> 27) | (b13 << 5));
449     }
450     ret = 1;
451  err:
452     return ret;
453 }
454 
455 /*
456  * @brief Encodes a polynomial into a byte string, assuming that all
457  * coefficients are in the range (-2^19 + 1)..2^19.
458  * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^19 - 1, b = 2^19.
459  *
460  * This is used to encode signatures for ML-DSA-65 & ML-DSA-87 (gamma1 = 2^19)
461  *
462  * Use 20 bits from each coefficient and pack them into bytes
463  *
464  * The code below packs every 4 (20 bit) coefficients into 10 bytes
465  *  z0  z1  z2 z3
466  *  |   |\  |  | \
467  * |20 12|8 20 4|16
468  *
469  * @param p An array of 256 coefficients all in the range -2^19+1..2^19
470  * @param pkt A packet to write 640 (20 * 256 / 8) bytes of encoded polynomial
471  *            coefficients to.
472  *
473  * @returns 1 on success, or 0 on error.
474  */
poly_encode_signed_two_to_power_19(const POLY * p,WPACKET * pkt)475 static int poly_encode_signed_two_to_power_19(const POLY *p, WPACKET *pkt)
476 {
477     static const uint32_t range = 1u << 19;
478     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
479 
480     do {
481         uint32_t z0, z1, z2;
482         uint8_t *out;
483 
484         if (!WPACKET_allocate_bytes(pkt, 10, &out))
485             return 0;
486 
487         z0 = mod_sub(range, *in++);
488         z0 |= (z1 = mod_sub(range, *in++)) << 20;
489         z1 = (z1 >> 12) | (mod_sub(range, *in++) << 8);
490         z1 |= (z2 = mod_sub(range, *in++)) << 28;
491 
492         out = OPENSSL_store_u32_le(out, z0);
493         out = OPENSSL_store_u32_le(out, z1);
494         out = OPENSSL_store_u16_le(out, (uint16_t) (z2 >> 4));
495     } while (in < end);
496     return 1;
497 }
498 
499 /*
500  * @brief Reverses the procedure of poly_encode_signed_two_to_power_19().
501  * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = 2^19 - 1, b = 2^19.
502  *
503  * @param p A polynomial to write coefficients to.
504  * @param pkt A packet object to read 640 encoded bytes from.
505  *
506  * @returns 1 on success, or 0 on error.
507  */
poly_decode_signed_two_to_power_19(POLY * p,PACKET * pkt)508 static int poly_decode_signed_two_to_power_19(POLY *p, PACKET *pkt)
509 {
510     int i, ret = 0;
511     uint32_t *out = p->coeff;
512     const uint8_t *in;
513     static const uint32_t range = 1u << 19;
514     static const uint32_t mask_20_bits = (1u << 20) - 1;
515 
516     for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 4); i++) {
517         uint32_t a1, a2;
518         uint16_t a3;
519 
520         if (!PACKET_get_bytes(pkt, &in, 10))
521             goto err;
522         in = OPENSSL_load_u32_le(&a1, in);
523         in = OPENSSL_load_u32_le(&a2, in);
524         in = OPENSSL_load_u16_le(&a3, in);
525 
526         *out++ = mod_sub(range, a1 & mask_20_bits);
527         *out++ = mod_sub(range, (a1 >> 20) | ((a2 & 0xFF) << 12));
528         *out++ = mod_sub(range, (a2 >> 8) & mask_20_bits);
529         *out++ = mod_sub(range, (a2 >> 28) | (a3 << 4));
530     }
531     ret = 1;
532  err:
533     return ret;
534 }
535 
536 /*
537  * @brief Encodes a polynomial into a byte string, assuming that all
538  * coefficients are in the range (-2^17 + 1)..2^17.
539  * See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^17 - 1, b = 2^17.
540  *
541  * This is used to encode signatures for ML-DSA-44 (where gamma1 = 2^17)
542  *
543  * Use 18 bits from each coefficient and pack them into bytes
544  *
545  * The code below packs every 4 (18 bit) coefficients into 9 bytes
546  *  z0  z1  z2 z3
547  *  |   |\  |  | \
548  * |18 14|4 18 10| 8
549  *
550  * @param p An array of 256 coefficients all in the range -2^17+1..2^17
551  * @param pkt A packet to write 576 (18 * 256 / 8) bytes of encoded polynomial
552  *            coefficients to.
553  *
554  * @returns 1 on success, or 0 on error.
555  */
poly_encode_signed_two_to_power_17(const POLY * p,WPACKET * pkt)556 static int poly_encode_signed_two_to_power_17(const POLY *p, WPACKET *pkt)
557 {
558     static const uint32_t range = 1u << 17;
559     const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
560 
561     do {
562         uint8_t *out;
563         uint32_t z0, z1, z2;
564 
565         if (!WPACKET_allocate_bytes(pkt, 9, &out))
566             return 0;
567 
568         z0 = mod_sub(range, *in++);
569         z0 |= (z1 = mod_sub(range, *in++)) << 18;
570         z1 = (z1 >> 14) | (mod_sub(range, *in++) << 4);
571         z1 |= (z2 = mod_sub(range, *in++)) << 22;
572 
573         out = OPENSSL_store_u32_le(out, z0);
574         out = OPENSSL_store_u32_le(out, z1);
575         *out = z2 >> 10;
576     } while (in < end);
577     return 1;
578 }
579 
580 /*
581  * @brief Reverses the procedure of poly_encode_signed_two_to_power_17().
582  * See FIPS 204, Algorithm 19, BitUnpack(v, a, b) where a = 2^17 - 1, b = 2^17.
583  *
584  * @param p A polynomial to write coefficients to.
585  * @param pkt A packet object to read 576 encoded bytes from.
586  *
587  * @returns 1 on success, or 0 on error.
588  */
poly_decode_signed_two_to_power_17(POLY * p,PACKET * pkt)589 static int poly_decode_signed_two_to_power_17(POLY *p, PACKET *pkt)
590 {
591     uint32_t *out = p->coeff;
592     const uint32_t *end = out + ML_DSA_NUM_POLY_COEFFICIENTS;
593     const uint8_t *in;
594     static const uint32_t range = 1u << 17;
595     static const uint32_t mask_18_bits = (1u << 18) - 1;
596 
597     do {
598         uint32_t a1, a2, a3;
599 
600         if (!PACKET_get_bytes(pkt, &in, 9))
601             return 0;
602         in = OPENSSL_load_u32_le(&a1, in);
603         in = OPENSSL_load_u32_le(&a2, in);
604         a3 = (uint32_t) *in;
605 
606         *out++ = mod_sub(range, a1 & mask_18_bits);
607         *out++ = mod_sub(range, (a1 >> 18) | ((a2 & 0xF) << 14));
608         *out++ = mod_sub(range, (a2 >> 4) & mask_18_bits);
609         *out++ = mod_sub(range, (a2 >> 22) | (a3 << 10));
610     } while (out < end);
611     return 1;
612 }
613 
614 /*
615  * @brief Encode the public key as an array of bytes.
616  * See FIPS 204, Algorithm 22, pkEncode().
617  *
618  * @param key A key object containing public key values. The encoded public
619  *            key data is stored in this key.
620  * @returns 1 if the public key was encoded successfully or 0 otherwise.
621  */
ossl_ml_dsa_pk_encode(ML_DSA_KEY * key)622 int ossl_ml_dsa_pk_encode(ML_DSA_KEY *key)
623 {
624     int ret = 0;
625     size_t i, written = 0;
626     const POLY *t1 = key->t1.poly;
627     size_t t1_len = key->t1.num_poly;
628     size_t enc_len = key->params->pk_len;
629     uint8_t *enc = OPENSSL_malloc(enc_len);
630     WPACKET pkt;
631 
632     if (enc == NULL)
633         return 0;
634 
635     if (!WPACKET_init_static_len(&pkt, enc, enc_len, 0)
636             || !WPACKET_memcpy(&pkt, key->rho, sizeof(key->rho)))
637         goto err;
638     for (i = 0; i < t1_len; i++)
639         if (!poly_encode_10_bits(t1 + i, &pkt))
640             goto err;
641     if (!WPACKET_get_total_written(&pkt, &written)
642             || written != enc_len)
643         goto err;
644     OPENSSL_free(key->pub_encoding);
645     key->pub_encoding = enc;
646     ret = 1;
647 err:
648     WPACKET_finish(&pkt);
649     if (ret == 0)
650         OPENSSL_free(enc);
651     return ret;
652 }
653 
654 /*
655  * @brief The reverse of ossl_ml_dsa_pk_encode().
656  * See FIPS 204, Algorithm 23, pkDecode().
657  *
658  * @param in An encoded public key.
659  * @param in_len The size of |in|
660  * @param key A key object to store the decoded public key into.
661  *
662  * @returns 1 if the public key was decoded successfully or 0 otherwise.
663  */
ossl_ml_dsa_pk_decode(ML_DSA_KEY * key,const uint8_t * in,size_t in_len)664 int ossl_ml_dsa_pk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len)
665 {
666     int ret = 0;
667     size_t i;
668     PACKET pkt;
669     EVP_MD_CTX *ctx;
670 
671     if (key->priv_encoding != NULL || key->pub_encoding != NULL)
672         return 0; /* Do not allow key mutation */
673     if (in_len != key->params->pk_len)
674         return 0;
675 
676     if (!ossl_ml_dsa_key_pub_alloc(key))
677         return 0;
678     ctx = EVP_MD_CTX_new();
679     if (ctx == NULL)
680         goto err;
681     if (!PACKET_buf_init(&pkt, in, in_len)
682             || !PACKET_copy_bytes(&pkt, key->rho, sizeof(key->rho)))
683         goto err;
684     for (i = 0; i < key->t1.num_poly; i++)
685         if (!poly_decode_10_bits(key->t1.poly + i, &pkt))
686             goto err;
687 
688     /* cache the hash of the encoded public key */
689     if (!shake_xof(ctx, key->shake256_md, in, in_len, key->tr, sizeof(key->tr)))
690         goto err;
691 
692     key->pub_encoding = OPENSSL_memdup(in, in_len);
693     ret = (key->pub_encoding != NULL);
694 err:
695     EVP_MD_CTX_free(ctx);
696     return ret;
697 }
698 
699 /*
700  * @brief Encode the private key as an array of bytes.
701  * See FIPS 204, Algorithm 24, skEncode().
702  *
703  * @param key A key object containing private key values. The encoded private
704  *            key data is stored in this key.
705  * @returns 1 if the private key was encoded successfully or 0 otherwise.
706  */
ossl_ml_dsa_sk_encode(ML_DSA_KEY * key)707 int ossl_ml_dsa_sk_encode(ML_DSA_KEY *key)
708 {
709     int ret = 0;
710     const ML_DSA_PARAMS *params = key->params;
711     size_t i, written = 0, k = params->k, l = params->l;
712     ENCODE_FN *encode_fn;
713     size_t enc_len = params->sk_len;
714     const POLY *t0 = key->t0.poly;
715     WPACKET pkt;
716     uint8_t *enc = OPENSSL_malloc(enc_len);
717 
718     if (enc == NULL)
719         return 0;
720 
721     /* eta is the range of private key coefficients (-eta...eta) */
722     if (params->eta == ML_DSA_ETA_4)
723         encode_fn = poly_encode_signed_4;
724     else
725         encode_fn = poly_encode_signed_2;
726 
727     if (!WPACKET_init_static_len(&pkt, enc, enc_len, 0)
728             || !WPACKET_memcpy(&pkt, key->rho, sizeof(key->rho))
729             || !WPACKET_memcpy(&pkt, key->K, sizeof(key->K))
730             || !WPACKET_memcpy(&pkt, key->tr, sizeof(key->tr)))
731         goto err;
732     for (i = 0; i < l; ++i)
733         if (!encode_fn(key->s1.poly + i, &pkt))
734             goto err;
735     for (i = 0; i < k; ++i)
736         if (!encode_fn(key->s2.poly + i, &pkt))
737             goto err;
738     for (i = 0; i < k; ++i)
739         if (!poly_encode_signed_two_to_power_12(t0++, &pkt))
740             goto err;
741     if (!WPACKET_get_total_written(&pkt, &written)
742             || written != enc_len)
743         goto err;
744     OPENSSL_clear_free(key->priv_encoding, enc_len);
745     key->priv_encoding = enc;
746     ret = 1;
747 err:
748     WPACKET_finish(&pkt);
749     if (ret == 0)
750         OPENSSL_clear_free(enc, enc_len);
751     return ret;
752 }
753 
754 /*
755  * @brief The reverse of ossl_ml_dsa_sk_encode().
756  * See FIPS 204, Algorithm 24, skDecode().
757  *
758  * @param in An encoded private key.
759  * @param in_len The size of |in|
760  * @param key A key object to store the decoded private key into.
761  *
762  * @returns 1 if the private key was decoded successfully or 0 otherwise.
763  */
ossl_ml_dsa_sk_decode(ML_DSA_KEY * key,const uint8_t * in,size_t in_len)764 int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len)
765 {
766     DECODE_FN *decode_fn;
767     const ML_DSA_PARAMS *params = key->params;
768     size_t i, k = params->k, l = params->l;
769     uint8_t input_tr[ML_DSA_TR_BYTES];
770     PACKET pkt;
771 
772     /* When loading from an explicit key, drop the seed. */
773     OPENSSL_clear_free(key->seed, ML_DSA_SEED_BYTES);
774     key->seed = NULL;
775 
776     /* Allow the key encoding to be already set to the provided pointer */
777     if ((key->priv_encoding != NULL && key->priv_encoding != in)
778         || key->pub_encoding != NULL)
779         return 0; /* Do not allow key mutation */
780     if (in_len != key->params->sk_len)
781         return 0;
782     if (!ossl_ml_dsa_key_priv_alloc(key))
783         return 0;
784 
785     /* eta is the range of private key coefficients (-eta...eta) */
786     if (params->eta == ML_DSA_ETA_4)
787         decode_fn = poly_decode_signed_4;
788     else
789         decode_fn = poly_decode_signed_2;
790 
791     if (!PACKET_buf_init(&pkt, in, in_len)
792             || !PACKET_copy_bytes(&pkt, key->rho, sizeof(key->rho))
793             || !PACKET_copy_bytes(&pkt, key->K, sizeof(key->K))
794             || !PACKET_copy_bytes(&pkt, input_tr, sizeof(input_tr)))
795         return 0;
796 
797     for (i = 0; i < l; ++i)
798         if (!decode_fn(key->s1.poly + i, &pkt))
799             goto err;
800     for (i = 0; i < k; ++i)
801         if (!decode_fn(key->s2.poly + i, &pkt))
802             goto err;
803     for (i = 0; i < k; ++i)
804         if (!poly_decode_signed_two_to_power_12(key->t0.poly + i, &pkt))
805             goto err;
806     if (PACKET_remaining(&pkt) != 0)
807         goto err;
808     if (key->priv_encoding == NULL
809         && (key->priv_encoding = OPENSSL_memdup(in, in_len)) == NULL)
810         goto err;
811     /*
812      * Computing the public key also computes its hash, which must be equal to
813      * the |tr| value in the private key, else the key was corrupted.
814      */
815     if (!ossl_ml_dsa_key_public_from_private(key)
816             || memcmp(input_tr, key->tr, sizeof(input_tr)) != 0) {
817         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
818                        "%s private key does not match its pubkey part",
819                        key->params->alg);
820         ossl_ml_dsa_key_reset(key);
821         goto err;
822     }
823 
824     return 1;
825  err:
826     return 0;
827 }
828 
829 /*
830  * See FIPS 204, Algorithm 20, HintBitPack().
831  * Hint is composed of k polynomials with binary coefficients where only 'omega'
832  * of all the coefficients are set to 1.
833  * This can be encoded as a byte array of 'omega' polynomial coefficient index
834  * positions for the coefficients that are set, followed by
835  * k values of the last coefficient index used in each polynomial.
836  */
hint_bits_encode(const VECTOR * hint,WPACKET * pkt,uint32_t omega)837 static int hint_bits_encode(const VECTOR *hint, WPACKET *pkt, uint32_t omega)
838 {
839     int i, j, k = hint->num_poly;
840     size_t coeff_index = 0;
841     POLY *p = hint->poly;
842     uint8_t *data;
843 
844     if (!WPACKET_allocate_bytes(pkt, omega + k, &data))
845         return 0;
846     memset(data, 0, omega + k);
847 
848     for (i = 0; i < k; i++, p++) {
849         for (j = 0; j < ML_DSA_NUM_POLY_COEFFICIENTS; j++)
850             if (p->coeff[j] != 0)
851                 data[coeff_index++] = j;
852         data[omega + i] = (uint8_t)coeff_index;
853     }
854     return 1;
855 }
856 
857 /*
858  * @brief Reverse the process of hint_bits_encode()
859  * See FIPS 204, Algorithm 21, HintBitUnpack()
860  *
861  * @returns 1 if the hints were successfully unpacked, or 0
862  * if 'pkt' is too small or malformed.
863  */
hint_bits_decode(VECTOR * hint,PACKET * pkt,uint32_t omega)864 static int hint_bits_decode(VECTOR *hint, PACKET *pkt, uint32_t omega)
865 {
866     size_t coeff_index = 0, k = hint->num_poly;
867     const uint8_t *in, *limits;
868     POLY *p = hint->poly, *end = p + k;
869 
870     if (!PACKET_get_bytes(pkt, &in, omega)
871             || !PACKET_get_bytes(pkt, &limits, k))
872         return 0;
873 
874     vector_zero(hint); /* Set all coefficients to zero */
875 
876     do {
877         const uint32_t limit = *limits++;
878         int last = -1;
879 
880         if (limit < coeff_index || limit > omega)
881             return 0;
882 
883         while (coeff_index < limit) {
884             int byte = in[coeff_index++];
885 
886             if (last >= 0 && byte <= last)
887                 return 0;
888             last = byte;
889             p->coeff[byte] = 1;
890         }
891     } while (++p < end);
892 
893     for (; coeff_index < omega; coeff_index++)
894         if (in[coeff_index] != 0)
895             return 0;
896     return 1;
897 }
898 
899 /*
900  * @brief Encode a ML_DSA signature as an array of bytes.
901  * See FIPS 204, Algorithm 26, sigEncode().
902  *
903  * @param
904  * @param
905  * @returns 1 if the signature was encoded successfully or 0 otherwise.
906  */
ossl_ml_dsa_sig_encode(const ML_DSA_SIG * sig,const ML_DSA_PARAMS * params,uint8_t * out)907 int ossl_ml_dsa_sig_encode(const ML_DSA_SIG *sig, const ML_DSA_PARAMS *params,
908                            uint8_t *out)
909 {
910     int ret = 0;
911     size_t i;
912     ENCODE_FN *encode_fn;
913     WPACKET pkt;
914 
915     if (out == NULL)
916         return 0;
917 
918     if (params->gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
919         encode_fn = poly_encode_signed_two_to_power_19;
920     else
921         encode_fn = poly_encode_signed_two_to_power_17;
922 
923     if (!WPACKET_init_static_len(&pkt, out, params->sig_len, 0)
924             || !WPACKET_memcpy(&pkt, sig->c_tilde, sig->c_tilde_len))
925         goto err;
926 
927     for (i = 0; i < sig->z.num_poly; ++i)
928         if (!encode_fn(sig->z.poly + i, &pkt))
929             goto err;
930     if (!hint_bits_encode(&sig->hint, &pkt, params->omega))
931         goto err;
932     ret = 1;
933 err:
934     WPACKET_finish(&pkt);
935     return ret;
936 }
937 
938 /*
939  * @param sig is a initialized signature object to decode into.
940  * @param in An encoded signature
941  * @param in_len The size of |in|
942  * @param params contains constants for an ML-DSA algorithm (such as gamma1)
943  * @returns 1 if the signature was successfully decoded or 0 otherwise.
944  */
ossl_ml_dsa_sig_decode(ML_DSA_SIG * sig,const uint8_t * in,size_t in_len,const ML_DSA_PARAMS * params)945 int ossl_ml_dsa_sig_decode(ML_DSA_SIG *sig, const uint8_t *in, size_t in_len,
946                            const ML_DSA_PARAMS *params)
947 {
948     int ret = 0;
949     size_t i;
950     DECODE_FN *decode_fn;
951     PACKET pkt;
952 
953     if (params->gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
954         decode_fn = poly_decode_signed_two_to_power_19;
955     else
956         decode_fn = poly_decode_signed_two_to_power_17;
957 
958     if (!PACKET_buf_init(&pkt, in, in_len)
959             || !PACKET_copy_bytes(&pkt, sig->c_tilde, sig->c_tilde_len))
960         goto err;
961     for (i = 0; i < sig->z.num_poly; ++i)
962         if (!decode_fn(sig->z.poly + i, &pkt))
963             goto err;
964 
965     if (!hint_bits_decode(&sig->hint, &pkt, params->omega)
966             || PACKET_remaining(&pkt) != 0)
967         goto err;
968     ret = 1;
969 err:
970     return ret;
971 }
972 
/*
 * @brief Decode one polynomial from a byte buffer.
 *
 * The signed 2^19 decoder is used when gamma1 is ML_DSA_GAMMA1_TWO_POWER_19;
 * otherwise the signed 2^17 decoder is used.
 * NOTE(review): the name suggests this is the bit-unpacking step of FIPS 204
 * ExpandMask; confirm against the callers.
 *
 * @param out The polynomial to decode into.
 * @param in The encoded bytes.
 * @param in_len The size of |in|.
 * @param gamma1 Selects the coefficient packing width.
 * @returns the decoder's result, or 0 if the packet could not be set up.
 */
int ossl_ml_dsa_poly_decode_expand_mask(POLY *out,
                                        const uint8_t *in, size_t in_len,
                                        uint32_t gamma1)
{
    PACKET pkt;
    DECODE_FN *decode_fn;

    if (gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
        decode_fn = poly_decode_signed_two_to_power_19;
    else
        decode_fn = poly_decode_signed_two_to_power_17;

    if (!PACKET_buf_init(&pkt, in, in_len))
        return 0;
    return decode_fn(out, &pkt);
}
986 
987 /*
988  * @brief Encode a polynomial vector as an array of bytes.
989  * Where the polynomial coefficients have a range of [0..15] or [0..43]
990  * depending on the value of gamma2.
991  *
992  * See FIPS 204, Algorithm 28, w1Encode().
993  *
994  * @param w1 The vector to convert to bytes
995  * @param gamma2 either ML_DSA_GAMMA2_Q_MINUS1_DIV32 or ML_DSA_GAMMA2_Q_MINUS1_DIV88
996  * @returns 1 if the signature was encoded successfully or 0 otherwise.
997  */
ossl_ml_dsa_w1_encode(const VECTOR * w1,uint32_t gamma2,uint8_t * out,size_t out_len)998 int ossl_ml_dsa_w1_encode(const VECTOR *w1, uint32_t gamma2,
999                           uint8_t *out, size_t out_len)
1000 {
1001     WPACKET pkt;
1002     ENCODE_FN *encode_fn;
1003     int ret = 0;
1004     size_t i;
1005 
1006     if (!WPACKET_init_static_len(&pkt, out, out_len, 0))
1007         return 0;
1008     if (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV32)
1009         encode_fn = poly_encode_4_bits;
1010     else
1011         encode_fn = poly_encode_6_bits;
1012     for (i = 0; i < w1->num_poly; ++i)
1013         if (!encode_fn(w1->poly + i, &pkt))
1014             goto err;
1015     ret = 1;
1016 err:
1017     WPACKET_finish(&pkt);
1018     return ret;
1019 }
1020