xref: /freebsd/crypto/openssl/crypto/ml_dsa/ml_dsa_key_compress.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include "ml_dsa_local.h"
11 
12 /* Key Compression related functions (Rounding & hints) */
13 
14 /**
15  * @brief Decompose r into (r1, r0) such that r == r1 * 2^13 + r0 mod q
16  * See FIPS 204, Algorithm 35, Power2Round()
17  *
18  * Note: that this code is more complex than the FIPS 204 spec since it keeps
19  * r0 as a positive number
20  *
21  * r mod +- 2^13 is defined as having a range of -4095..4096
22  *
23  * i.e for r = 0..4096 r1 = 0 and r0 = 0..4096
24  * at r = 4097..8191 r1 = 1 and r0 = -4095..0
25  * (but since r0 is kept positive it effectively adds q and then reduces by q if needed)
26  * Similarly for the range r = 8192..8192+4096 r1=1 and r0=0..4096
27  * & 12289..16383 r1=2 and r0=-4095..0
28  *
29  * @param r is in the range 0..q-1
30  * @param r1 The returned top 10 MSB (i.e it ranges from 0..1023)
31  * @param r0 The remainder in the range (0..4096 or q-4095..q-1)
32  *           So r0 has an effective range of 8192 (i.e. 13 bits).
33  */
ossl_ml_dsa_key_compress_power2_round(uint32_t r,uint32_t * r1,uint32_t * r0)34 void ossl_ml_dsa_key_compress_power2_round(uint32_t r, uint32_t *r1, uint32_t *r0)
35 {
36     unsigned int mask;
37     uint32_t r0_adjusted, r1_adjusted;
38 
39     *r1 = r >> ML_DSA_D_BITS;         /* top 13 bits */
40     *r0 = r - (*r1 << ML_DSA_D_BITS); /* The remainder mod q */
41 
42     r0_adjusted = mod_sub(*r0, 1 << ML_DSA_D_BITS);
43     r1_adjusted = *r1 + 1;
44 
45     /* Mask is set iff r0 > (2^(dropped_bits))/2. */
46     mask = constant_time_lt((uint32_t)(1 << (ML_DSA_D_BITS - 1)), *r0);
47     /* r0 = mask ? r0_adjusted : r0 */
48     *r0 = constant_time_select_int(mask, r0_adjusted, *r0);
49     /* r1 = mask ? r1_adjusted : r1 */
50     *r1 = constant_time_select_int(mask, r1_adjusted, *r1);
51 }
52 
53 /*
54  * @brief return the r1 component of Decomposing r into (r1, r0) such that
55  * r == r1 * (2 * gamma2) + r0 mod q
56  * See FIPS 204, Algorithm 37, HighBits()
57  *
58  * @param r A value to decompose in the range (0..q-1)
59  * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
60  * @returns r1 (The high order bits)
61  */
ossl_ml_dsa_key_compress_high_bits(uint32_t r,uint32_t gamma2)62 uint32_t ossl_ml_dsa_key_compress_high_bits(uint32_t r, uint32_t gamma2)
63 {
64     int32_t r1 = (r + 127) >> 7;
65 
66     if (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV32) {
67         r1 = (r1 * 1025 + (1 << 21)) >> 22;
68         r1 &= 15; /* mod 16 */
69         return r1;
70     } else {
71         r1 = (r1 * 11275 + (1 << 23)) >> 24;
72         r1 ^= ((43 - r1) >> 31) & r1;
73         return r1;
74     }
75 }
76 
77 /**
78  * @brief Decomposes r into (r1, r0) such that r == r1 * (2*gamma2) + r0 mod q.
79  * See FIPS 204, Algorithm 36, Decompose()
80  *
81  * @param r A value to decompose in the range (0..q-1)
82  * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
83  * @param r1 The returned high order bits
84  * @param r0 The returned low order bits
85  */
ossl_ml_dsa_key_compress_decompose(uint32_t r,uint32_t gamma2,uint32_t * r1,int32_t * r0)86 void ossl_ml_dsa_key_compress_decompose(uint32_t r, uint32_t gamma2,
87                                         uint32_t *r1, int32_t *r0)
88 {
89     *r1 = ossl_ml_dsa_key_compress_high_bits(r, gamma2);
90 
91     *r0 = r - *r1 * 2 * (int32_t)gamma2;
92     *r0 -= (((int32_t)ML_DSA_Q_MINUS1_DIV2 - *r0) >> 31) & (int32_t)ML_DSA_Q;
93 }
94 
/**
 * @brief return the r0 component of Decomposing r into (r1, r0) such that
 * r == r1 * (2 * gamma2) + r0 mod q
 * See FIPS 204, Algorithm 38, LowBits()
 *
 * @param r A value to decompose in the range (0..q-1)
 * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
 * @returns r0, the signed low order bits produced by
 *          ossl_ml_dsa_key_compress_decompose()
 */
int32_t ossl_ml_dsa_key_compress_low_bits(uint32_t r, uint32_t gamma2)
{
    uint32_t r1; /* high bits: produced by decompose but not needed here */
    int32_t r0;

    ossl_ml_dsa_key_compress_decompose(r, gamma2, &r1, &r0);
    return r0;
}
112 
/**
 * @brief Computes the hint bit recording whether adding z to r changes the
 * high bits of r.
 * See FIPS 204, Algorithm 39, MakeHint().
 *
 * The spec evaluates MakeHint(z, r) with
 *   z = -ct0
 *   r = w - cs2 + ct0
 * and compares HighBits(r) against HighBits(r + z). Because r + z is just
 * w - cs2, this variant takes the three ingredients directly and saves one
 * modular addition.
 *
 * @param ct0 A polynomial c (with coefficients of (-1,0,1)) multiplied by the
 *            polynomial vector t0 (the least significant bits of each
 *            coefficient of the uncompressed public-key polynomial t)
 * @param cs2 A polynomial c (with coefficients of (-1,0,1)) multiplied by s2
 *            (a secret polynomial)
 * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
 * @param w  (A * y)
 * @returns 1 if the two high-bit values differ, 0 otherwise.
 */
int32_t ossl_ml_dsa_key_compress_make_hint(uint32_t ct0, uint32_t cs2,
                                           uint32_t gamma2, uint32_t w)
{
    uint32_t w_minus_cs2;          /* r + z in the spec */
    uint32_t w_minus_cs2_plus_ct0; /* r in the spec */
    uint32_t high_r, high_r_plus_z;

    w_minus_cs2 = mod_sub(w, cs2);
    w_minus_cs2_plus_ct0 = reduce_once(w_minus_cs2 + ct0);

    high_r = ossl_ml_dsa_key_compress_high_bits(w_minus_cs2_plus_ct0, gamma2);
    high_r_plus_z = ossl_ml_dsa_key_compress_high_bits(w_minus_cs2, gamma2);

    return high_r != high_r_plus_z;
}
142 
143 /*
144  * @brief Returns the high bits of |r| adjusted according to hint |h|.
145  * FIPS 204, Algorithm 40, UseHint().
146  * This is not constant time.
147  *
148  * @param hint The hint bit which is either 0 or 1
149  * @param r A value to decompose in the range (0..q-1)
150  * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
151  *
152  * @returns The adjusted high bits or r.
153  */
ossl_ml_dsa_key_compress_use_hint(uint32_t hint,uint32_t r,uint32_t gamma2)154 uint32_t ossl_ml_dsa_key_compress_use_hint(uint32_t hint, uint32_t r,
155                                            uint32_t gamma2)
156 {
157     uint32_t r1;
158     int32_t r0;
159 
160     ossl_ml_dsa_key_compress_decompose(r, gamma2, &r1, &r0);
161 
162     if (hint == 0)
163         return r1;
164 
165     if (gamma2 == ((ML_DSA_Q - 1) / 32)) {
166         /* m = 16, thus |mod m| in the spec turns into |& 15| */
167         return r0 > 0 ? (r1 + 1) & 15 : (r1 - 1) & 15;
168     } else {
169         /* m = 44 if gamma2 = ((q - 1) / 88) */
170         if (r0 > 0)
171             return (r1 == 43) ? 0 : r1 + 1;
172         else
173             return (r1 == 0) ? 43 : r1 - 1;
174     }
175 }
176