xref: /freebsd/crypto/openssl/crypto/ml_dsa/ml_dsa_sign.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <openssl/core_dispatch.h>
11 #include <openssl/core_names.h>
12 #include <openssl/params.h>
13 #include <openssl/rand.h>
14 #include "ml_dsa_local.h"
15 #include "ml_dsa_key.h"
16 #include "ml_dsa_matrix.h"
17 #include "ml_dsa_sign.h"
18 #include "ml_dsa_hash.h"
19 
/* Largest supported security strength, in bits (that of ML-DSA-87). */
#define ML_DSA_MAX_LAMBDA 256
21 
22 /*
23  * @brief Initialize a Signature object by pointing all of its objects to
24  * preallocated blocks. The values passed for hint, z and
25  * c_tilde values are not owned/freed by the |sig| object.
26  *
27  * @param sig The ML_DSA_SIG to initialize.
28  * @param hint A preallocated array of |k| polynomial blocks
29  * @param k The number of |hint| polynomials
30  * @param z A preallocated array of |l| polynomial blocks
31  * @param l The number of |z| polynomials
32  * @param c_tilde A preallocated buffer
33  * @param c_tilde_len The size of |c_tilde|
34  */
signature_init(ML_DSA_SIG * sig,POLY * hint,uint32_t k,POLY * z,uint32_t l,uint8_t * c_tilde,size_t c_tilde_len)35 static void signature_init(ML_DSA_SIG *sig,
36                            POLY *hint, uint32_t k, POLY *z, uint32_t l,
37                            uint8_t *c_tilde, size_t c_tilde_len)
38 {
39     vector_init(&sig->z, z, l);
40     vector_init(&sig->hint, hint, k);
41     sig->c_tilde = c_tilde;
42     sig->c_tilde_len = c_tilde_len;
43 }
44 
45 /*
46  * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
47  * @returns 1 on success and 0 on failure.
48  */
ml_dsa_sign_internal(const ML_DSA_KEY * priv,int msg_is_mu,const uint8_t * encoded_msg,size_t encoded_msg_len,const uint8_t * rnd,size_t rnd_len,uint8_t * out_sig)49 static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
50                                 const uint8_t *encoded_msg,
51                                 size_t encoded_msg_len,
52                                 const uint8_t *rnd, size_t rnd_len,
53                                 uint8_t *out_sig)
54 {
55     int ret = 0;
56     const ML_DSA_PARAMS *params = priv->params;
57     EVP_MD_CTX *md_ctx = NULL;
58     uint32_t k = params->k, l = params->l;
59     uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
60     uint8_t *alloc = NULL, *w1_encoded;
61     size_t alloc_len, w1_encoded_len;
62     size_t num_polys_sig_k = 2 * k;
63     size_t num_polys_k = 5 * k;
64     size_t num_polys_l = 3 * l;
65     size_t num_polys_k_by_l = k * l;
66     POLY *polys = NULL, *p, *c_ntt;
67     VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
68     MATRIX a_ntt;
69     ML_DSA_SIG sig;
70     uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
71     const size_t mu_len = sizeof(mu);
72     uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
73     uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
74     size_t c_tilde_len = params->bit_strength >> 2;
75     size_t kappa;
76 
77     /*
78      * Allocate a single blob for most of the variable size temporary variables.
79      * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
80      */
81     w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
82     alloc_len = w1_encoded_len
83         + sizeof(*polys) * (1 + num_polys_k + num_polys_l
84                             + num_polys_k_by_l + num_polys_sig_k);
85     alloc = OPENSSL_malloc(alloc_len);
86     if (alloc == NULL)
87         return 0;
88     md_ctx = EVP_MD_CTX_new();
89     if (md_ctx == NULL)
90         goto err;
91 
92     w1_encoded = alloc;
93     /* Init the temp vectors to point to the allocated polys blob */
94     p = (POLY *)(w1_encoded + w1_encoded_len);
95     c_ntt = p++;
96     matrix_init(&a_ntt, p, k, l);
97     p += num_polys_k_by_l;
98     vector_init(&s2_ntt, p, k);
99     vector_init(&t0_ntt, s2_ntt.poly + k, k);
100     vector_init(&w, t0_ntt.poly + k, k);
101     vector_init(&w1, w.poly + k, k);
102     vector_init(&cs2, w1.poly + k, k);
103     p += num_polys_k;
104     vector_init(&s1_ntt, p, l);
105     vector_init(&y, p + l, l);
106     vector_init(&cs1, p + 2 * l, l);
107     p += num_polys_l;
108     signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
109     /* End of the allocated blob setup */
110 
111     if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
112         goto err;
113     if (msg_is_mu) {
114         if (encoded_msg_len != mu_len)
115             goto err;
116         mu_ptr = (uint8_t *)encoded_msg;
117     } else {
118         if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
119                          encoded_msg, encoded_msg_len, mu_ptr, mu_len))
120             goto err;
121     }
122     if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
123                      rnd, rnd_len, mu_ptr, mu_len,
124                      rho_prime, sizeof(rho_prime)))
125         goto err;
126 
127     vector_copy(&s1_ntt, &priv->s1);
128     vector_ntt(&s1_ntt);
129     vector_copy(&s2_ntt, &priv->s2);
130     vector_ntt(&s2_ntt);
131     vector_copy(&t0_ntt, &priv->t0);
132     vector_ntt(&t0_ntt);
133 
134     /*
135      * kappa must not exceed 2^16. But the probability of it
136      * exceeding even 1000 iterations is vanishingly small.
137      */
138     for (kappa = 0; ; kappa += l) {
139         VECTOR *y_ntt = &cs1;
140         VECTOR *r0 = &w1;
141         VECTOR *ct0 = &w1;
142         uint32_t z_max, r0_max, ct0_max, h_ones;
143 
144         vector_expand_mask(&y, rho_prime, sizeof(rho_prime), kappa,
145                            gamma1, md_ctx, priv->shake256_md);
146         vector_copy(y_ntt, &y);
147         vector_ntt(y_ntt);
148 
149         matrix_mult_vector(&a_ntt, y_ntt, &w);
150         vector_ntt_inverse(&w);
151 
152         vector_high_bits(&w, gamma2, &w1);
153         ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
154 
155         if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
156                          w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
157             break;
158 
159         if (!poly_sample_in_ball_ntt(c_ntt, c_tilde, c_tilde_len,
160                                      md_ctx, priv->shake256_md, params->tau))
161             break;
162 
163         vector_mult_scalar(&s1_ntt, c_ntt, &cs1);
164         vector_ntt_inverse(&cs1);
165         vector_mult_scalar(&s2_ntt, c_ntt, &cs2);
166         vector_ntt_inverse(&cs2);
167 
168         vector_add(&y, &cs1, &sig.z);
169 
170         /* r0 = lowbits(w - cs2) */
171         vector_sub(&w, &cs2, r0);
172         vector_low_bits(r0, gamma2, r0);
173 
174         /*
175          * Leaking that the signature is rejected is fine as the next attempt at a
176          * signature will be (indistinguishable from) independent of this one.
177          */
178         z_max = vector_max(&sig.z);
179         r0_max = vector_max_signed(r0);
180         if (value_barrier_32(constant_time_ge(z_max, gamma1 - params->beta)
181                              | constant_time_ge(r0_max, gamma2 - params->beta)))
182             continue;
183 
184         vector_mult_scalar(&t0_ntt, c_ntt, ct0);
185         vector_ntt_inverse(ct0);
186         vector_make_hint(ct0, &cs2, &w, gamma2, &sig.hint);
187 
188         ct0_max = vector_max(ct0);
189         h_ones = vector_count_ones(&sig.hint);
190         /* Same reasoning applies to the leak as above */
191         if (value_barrier_32(constant_time_ge(ct0_max, gamma2)
192                              | constant_time_lt(params->omega, h_ones)))
193             continue;
194         ret = ossl_ml_dsa_sig_encode(&sig, params, out_sig);
195         break;
196     }
197 err:
198     EVP_MD_CTX_free(md_ctx);
199     OPENSSL_clear_free(alloc, alloc_len);
200     OPENSSL_cleanse(rho_prime, sizeof(rho_prime));
201     return ret;
202 }
203 
204 /*
205  * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
206  */
ml_dsa_verify_internal(const ML_DSA_KEY * pub,int msg_is_mu,const uint8_t * msg_enc,size_t msg_enc_len,const uint8_t * sig_enc,size_t sig_enc_len)207 static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
208                                   const uint8_t *msg_enc, size_t msg_enc_len,
209                                   const uint8_t *sig_enc, size_t sig_enc_len)
210 {
211     int ret = 0;
212     uint8_t *alloc = NULL, *w1_encoded;
213     POLY *polys = NULL, *p, *c_ntt;
214     MATRIX a_ntt;
215     VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
216     ML_DSA_SIG sig;
217     const ML_DSA_PARAMS *params = pub->params;
218     uint32_t k = pub->params->k;
219     uint32_t l = pub->params->l;
220     uint32_t gamma2 = params->gamma2;
221     size_t w1_encoded_len;
222     size_t num_polys_sig = k + l;
223     size_t num_polys_k = 2 * k;
224     size_t num_polys_l = 1 * l;
225     size_t num_polys_k_by_l = k * l;
226     uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
227     const size_t mu_len = sizeof(mu);
228     uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
229     uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
230     EVP_MD_CTX *md_ctx = NULL;
231     size_t c_tilde_len = params->bit_strength >> 2;
232     uint32_t z_max;
233 
234     /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
235     w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
236     alloc = OPENSSL_malloc(w1_encoded_len
237                            + sizeof(*polys) * (1 + num_polys_k
238                                                + num_polys_l
239                                                + num_polys_k_by_l
240                                                + num_polys_sig));
241     if (alloc == NULL)
242         return 0;
243     md_ctx = EVP_MD_CTX_new();
244     if (md_ctx == NULL)
245         goto err;
246 
247     w1_encoded = alloc;
248     /* Init the temp vectors to point to the allocated polys blob */
249     p = (POLY *)(w1_encoded + w1_encoded_len);
250     c_ntt = p++;
251     matrix_init(&a_ntt, p, k, l);
252     p += num_polys_k_by_l;
253     signature_init(&sig, p, k, p + k, l, c_tilde_sig, c_tilde_len);
254     p += num_polys_sig;
255     vector_init(&az_ntt, p, k);
256     vector_init(&ct1_ntt, p + k, k);
257 
258     if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
259             || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
260         goto err;
261     if (msg_is_mu) {
262         if (msg_enc_len != mu_len)
263             goto err;
264         mu_ptr = (uint8_t *)msg_enc;
265     } else {
266         if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
267                          msg_enc, msg_enc_len, mu_ptr, mu_len))
268             goto err;
269     }
270     /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */
271     if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
272                                  md_ctx, pub->shake256_md, params->tau))
273         goto err;
274 
275     /* ct1_ntt = NTT(c) * NTT(t1 * 2^d) */
276     vector_scale_power2_round_ntt(&pub->t1, &ct1_ntt);
277     vector_mult_scalar(&ct1_ntt, c_ntt, &ct1_ntt);
278 
279     /* compute z_max early in order to reuse sig.z */
280     z_max = vector_max(&sig.z);
281 
282     /* w_approx = NTT_inverse(A * NTT(z) - ct1_ntt) */
283     z_ntt = &sig.z;
284     vector_ntt(z_ntt);
285     matrix_mult_vector(&a_ntt, z_ntt, &az_ntt);
286     w_approx = &az_ntt;
287     vector_sub(&az_ntt, &ct1_ntt, w_approx);
288     vector_ntt_inverse(w_approx);
289 
290     /* compute w1_encoded */
291     w1 = w_approx;
292     vector_use_hint(&sig.hint, w_approx, gamma2, w1);
293     ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
294 
295     if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
296                      w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
297         goto err;
298 
299     ret = (z_max < (uint32_t)(params->gamma1 - params->beta))
300         && memcmp(c_tilde, sig.c_tilde, c_tilde_len) == 0;
301 err:
302     OPENSSL_free(alloc);
303     EVP_MD_CTX_free(md_ctx);
304     return ret;
305 }
306 
307 /**
308  * @brief Encode a message
309  * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5).
310  *
311  * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
312  * Where ctx is the empty string by default and ctx_len <= 255.
313  *
314  * Note this code could be shared with SLH_DSA
315  *
316  * @param msg A message to encode
317  * @param msg_len The size of |msg|
318  * @param ctx An optional context to add to the message encoding.
319  * @param ctx_len The size of |ctx|. It must be in the range 0..255
320  * @param encode Use the Pure signature encoding if this is 1, and dont encode
321  *               if this value is 0.
322  * @param tmp A small buffer that may be used if the message is small.
323  * @param tmp_len The size of |tmp|
324  * @param out_len The size of the returned encoded buffer.
325  * @returns A buffer containing the encoded message. If the passed in
326  * |tmp| buffer is big enough to hold the encoded message then it returns |tmp|
327  * otherwise it allocates memory which must be freed by the caller. If |encode|
328  * is 0 then it returns |msg|. NULL is returned if there is a failure.
329  */
msg_encode(const uint8_t * msg,size_t msg_len,const uint8_t * ctx,size_t ctx_len,int encode,uint8_t * tmp,size_t tmp_len,size_t * out_len)330 static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len,
331                            const uint8_t *ctx, size_t ctx_len, int encode,
332                            uint8_t *tmp, size_t tmp_len, size_t *out_len)
333 {
334     uint8_t *encoded = NULL;
335     size_t encoded_len;
336 
337     if (encode == 0) {
338         /* Raw message */
339         *out_len = msg_len;
340         return (uint8_t *)msg;
341     }
342     if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
343         return NULL;
344 
345     /* Pure encoding */
346     encoded_len = 1 + 1 + ctx_len + msg_len;
347     *out_len = encoded_len;
348     if (encoded_len <= tmp_len) {
349         encoded = tmp;
350     } else {
351         encoded = OPENSSL_malloc(encoded_len);
352         if (encoded == NULL)
353             return NULL;
354     }
355     encoded[0] = 0;
356     encoded[1] = (uint8_t)ctx_len;
357     memcpy(&encoded[2], ctx, ctx_len);
358     memcpy(&encoded[2 + ctx_len], msg, msg_len);
359     return encoded;
360 }
361 
362 /**
363  * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign()
364  *
365  * @returns 1 on success, or 0 on error.
366  */
ossl_ml_dsa_sign(const ML_DSA_KEY * priv,int msg_is_mu,const uint8_t * msg,size_t msg_len,const uint8_t * context,size_t context_len,const uint8_t * rand,size_t rand_len,int encode,unsigned char * sig,size_t * sig_len,size_t sig_size)367 int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
368                      const uint8_t *msg, size_t msg_len,
369                      const uint8_t *context, size_t context_len,
370                      const uint8_t *rand, size_t rand_len, int encode,
371                      unsigned char *sig, size_t *sig_len, size_t sig_size)
372 {
373     int ret = 1;
374     uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL;
375     size_t m_len = 0;
376 
377     if (ossl_ml_dsa_key_get_priv(priv) == NULL)
378         return 0;
379     if (sig != NULL) {
380         if (sig_size < priv->params->sig_len)
381             return 0;
382         if (msg_is_mu) {
383             m = (uint8_t *)msg;
384             m_len = msg_len;
385         } else {
386             m = msg_encode(msg, msg_len, context, context_len, encode,
387                            m_tmp, sizeof(m_tmp), &m_len);
388             if (m == NULL)
389                 return 0;
390             if (m != msg && m != m_tmp)
391                 alloced_m = m;
392         }
393         ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig);
394         OPENSSL_free(alloced_m);
395     }
396     if (sig_len != NULL)
397         *sig_len = priv->params->sig_len;
398     return ret;
399 }
400 
401 /**
402  * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify()
403  * @returns 1 on success, or 0 on error.
404  */
ossl_ml_dsa_verify(const ML_DSA_KEY * pub,int msg_is_mu,const uint8_t * msg,size_t msg_len,const uint8_t * context,size_t context_len,int encode,const uint8_t * sig,size_t sig_len)405 int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
406                        const uint8_t *msg, size_t msg_len,
407                        const uint8_t *context, size_t context_len, int encode,
408                        const uint8_t *sig, size_t sig_len)
409 {
410     uint8_t *m, *alloced_m = NULL;
411     size_t m_len;
412     uint8_t m_tmp[1024];
413     int ret = 0;
414 
415     if (ossl_ml_dsa_key_get_pub(pub) == NULL)
416         return 0;
417 
418     if (msg_is_mu) {
419         m = (uint8_t *)msg;
420         m_len = msg_len;
421     } else {
422         m = msg_encode(msg, msg_len, context, context_len, encode,
423                        m_tmp, sizeof(m_tmp), &m_len);
424         if (m == NULL)
425             return 0;
426         if (m != msg && m != m_tmp)
427             alloced_m = m;
428     }
429 
430     ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len);
431     OPENSSL_free(alloced_m);
432     return ret;
433 }
434