1 /*
2 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #include <openssl/core_dispatch.h>
11 #include <openssl/core_names.h>
12 #include <openssl/params.h>
13 #include <openssl/rand.h>
14 #include "ml_dsa_local.h"
15 #include "ml_dsa_key.h"
16 #include "ml_dsa_matrix.h"
17 #include "ml_dsa_sign.h"
18 #include "ml_dsa_hash.h"
19
/*
 * Bit strength of ML-DSA-87. Stack buffers holding c_tilde in this file are
 * sized as ML_DSA_MAX_LAMBDA / 4 bytes.
 */
#define ML_DSA_MAX_LAMBDA 256
21
22 /*
23 * @brief Initialize a Signature object by pointing all of its objects to
24 * preallocated blocks. The values passed for hint, z and
25 * c_tilde values are not owned/freed by the |sig| object.
26 *
27 * @param sig The ML_DSA_SIG to initialize.
28 * @param hint A preallocated array of |k| polynomial blocks
29 * @param k The number of |hint| polynomials
30 * @param z A preallocated array of |l| polynomial blocks
31 * @param l The number of |z| polynomials
32 * @param c_tilde A preallocated buffer
33 * @param c_tilde_len The size of |c_tilde|
34 */
signature_init(ML_DSA_SIG * sig,POLY * hint,uint32_t k,POLY * z,uint32_t l,uint8_t * c_tilde,size_t c_tilde_len)35 static void signature_init(ML_DSA_SIG *sig,
36 POLY *hint, uint32_t k, POLY *z, uint32_t l,
37 uint8_t *c_tilde, size_t c_tilde_len)
38 {
39 vector_init(&sig->z, z, l);
40 vector_init(&sig->hint, hint, k);
41 sig->c_tilde = c_tilde;
42 sig->c_tilde_len = c_tilde_len;
43 }
44
45 /*
46 * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
47 * @returns 1 on success and 0 on failure.
48 */
ml_dsa_sign_internal(const ML_DSA_KEY * priv,int msg_is_mu,const uint8_t * encoded_msg,size_t encoded_msg_len,const uint8_t * rnd,size_t rnd_len,uint8_t * out_sig)49 static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
50 const uint8_t *encoded_msg,
51 size_t encoded_msg_len,
52 const uint8_t *rnd, size_t rnd_len,
53 uint8_t *out_sig)
54 {
55 int ret = 0;
56 const ML_DSA_PARAMS *params = priv->params;
57 EVP_MD_CTX *md_ctx = NULL;
58 uint32_t k = params->k, l = params->l;
59 uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
60 uint8_t *alloc = NULL, *w1_encoded;
61 size_t alloc_len, w1_encoded_len;
62 size_t num_polys_sig_k = 2 * k;
63 size_t num_polys_k = 5 * k;
64 size_t num_polys_l = 3 * l;
65 size_t num_polys_k_by_l = k * l;
66 POLY *polys = NULL, *p, *c_ntt;
67 VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
68 MATRIX a_ntt;
69 ML_DSA_SIG sig;
70 uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
71 const size_t mu_len = sizeof(mu);
72 uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
73 uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
74 size_t c_tilde_len = params->bit_strength >> 2;
75 size_t kappa;
76
77 /*
78 * Allocate a single blob for most of the variable size temporary variables.
79 * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
80 */
81 w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
82 alloc_len = w1_encoded_len
83 + sizeof(*polys) * (1 + num_polys_k + num_polys_l + num_polys_k_by_l + num_polys_sig_k);
84 alloc = OPENSSL_malloc(alloc_len);
85 if (alloc == NULL)
86 return 0;
87 md_ctx = EVP_MD_CTX_new();
88 if (md_ctx == NULL)
89 goto err;
90
91 w1_encoded = alloc;
92 /* Init the temp vectors to point to the allocated polys blob */
93 p = (POLY *)(w1_encoded + w1_encoded_len);
94 c_ntt = p++;
95 matrix_init(&a_ntt, p, k, l);
96 p += num_polys_k_by_l;
97 vector_init(&s2_ntt, p, k);
98 vector_init(&t0_ntt, s2_ntt.poly + k, k);
99 vector_init(&w, t0_ntt.poly + k, k);
100 vector_init(&w1, w.poly + k, k);
101 vector_init(&cs2, w1.poly + k, k);
102 p += num_polys_k;
103 vector_init(&s1_ntt, p, l);
104 vector_init(&y, p + l, l);
105 vector_init(&cs1, p + 2 * l, l);
106 p += num_polys_l;
107 signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
108 /* End of the allocated blob setup */
109
110 if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
111 goto err;
112 if (msg_is_mu) {
113 if (encoded_msg_len != mu_len)
114 goto err;
115 mu_ptr = (uint8_t *)encoded_msg;
116 } else {
117 if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
118 encoded_msg, encoded_msg_len, mu_ptr, mu_len))
119 goto err;
120 }
121 if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
122 rnd, rnd_len, mu_ptr, mu_len,
123 rho_prime, sizeof(rho_prime)))
124 goto err;
125
126 vector_copy(&s1_ntt, &priv->s1);
127 vector_ntt(&s1_ntt);
128 vector_copy(&s2_ntt, &priv->s2);
129 vector_ntt(&s2_ntt);
130 vector_copy(&t0_ntt, &priv->t0);
131 vector_ntt(&t0_ntt);
132
133 /*
134 * kappa must not exceed 2^16. But the probability of it
135 * exceeding even 1000 iterations is vanishingly small.
136 */
137 for (kappa = 0;; kappa += l) {
138 VECTOR *y_ntt = &cs1;
139 VECTOR *r0 = &w1;
140 VECTOR *ct0 = &w1;
141 uint32_t z_max, r0_max, ct0_max, h_ones;
142
143 vector_expand_mask(&y, rho_prime, sizeof(rho_prime), kappa,
144 gamma1, md_ctx, priv->shake256_md);
145 vector_copy(y_ntt, &y);
146 vector_ntt(y_ntt);
147
148 matrix_mult_vector(&a_ntt, y_ntt, &w);
149 vector_ntt_inverse(&w);
150
151 vector_high_bits(&w, gamma2, &w1);
152 ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
153
154 if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
155 w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
156 break;
157
158 if (!poly_sample_in_ball_ntt(c_ntt, c_tilde, c_tilde_len,
159 md_ctx, priv->shake256_md, params->tau))
160 break;
161
162 vector_mult_scalar(&s1_ntt, c_ntt, &cs1);
163 vector_ntt_inverse(&cs1);
164 vector_mult_scalar(&s2_ntt, c_ntt, &cs2);
165 vector_ntt_inverse(&cs2);
166
167 vector_add(&y, &cs1, &sig.z);
168
169 /* r0 = lowbits(w - cs2) */
170 vector_sub(&w, &cs2, r0);
171 vector_low_bits(r0, gamma2, r0);
172
173 /*
174 * Leaking that the signature is rejected is fine as the next attempt at a
175 * signature will be (indistinguishable from) independent of this one.
176 */
177 z_max = vector_max(&sig.z);
178 r0_max = vector_max_signed(r0);
179 if (value_barrier_32(constant_time_ge(z_max, gamma1 - params->beta)
180 | constant_time_ge(r0_max, gamma2 - params->beta)))
181 continue;
182
183 vector_mult_scalar(&t0_ntt, c_ntt, ct0);
184 vector_ntt_inverse(ct0);
185 vector_make_hint(ct0, &cs2, &w, gamma2, &sig.hint);
186
187 ct0_max = vector_max(ct0);
188 h_ones = vector_count_ones(&sig.hint);
189 /* Same reasoning applies to the leak as above */
190 if (value_barrier_32(constant_time_ge(ct0_max, gamma2)
191 | constant_time_lt(params->omega, h_ones)))
192 continue;
193 ret = ossl_ml_dsa_sig_encode(&sig, params, out_sig);
194 break;
195 }
196 err:
197 EVP_MD_CTX_free(md_ctx);
198 OPENSSL_clear_free(alloc, alloc_len);
199 OPENSSL_cleanse(rho_prime, sizeof(rho_prime));
200 return ret;
201 }
202
203 /*
204 * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
205 */
ml_dsa_verify_internal(const ML_DSA_KEY * pub,int msg_is_mu,const uint8_t * msg_enc,size_t msg_enc_len,const uint8_t * sig_enc,size_t sig_enc_len)206 static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
207 const uint8_t *msg_enc, size_t msg_enc_len,
208 const uint8_t *sig_enc, size_t sig_enc_len)
209 {
210 int ret = 0;
211 uint8_t *alloc = NULL, *w1_encoded;
212 POLY *polys = NULL, *p, *c_ntt;
213 MATRIX a_ntt;
214 VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
215 ML_DSA_SIG sig;
216 const ML_DSA_PARAMS *params = pub->params;
217 uint32_t k = pub->params->k;
218 uint32_t l = pub->params->l;
219 uint32_t gamma2 = params->gamma2;
220 size_t w1_encoded_len;
221 size_t num_polys_sig = k + l;
222 size_t num_polys_k = 2 * k;
223 size_t num_polys_l = 1 * l;
224 size_t num_polys_k_by_l = k * l;
225 uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
226 const size_t mu_len = sizeof(mu);
227 uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
228 uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
229 EVP_MD_CTX *md_ctx = NULL;
230 size_t c_tilde_len = params->bit_strength >> 2;
231 uint32_t z_max;
232
233 /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
234 w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
235 alloc = OPENSSL_malloc(w1_encoded_len
236 + sizeof(*polys) * (1 + num_polys_k + num_polys_l + num_polys_k_by_l + num_polys_sig));
237 if (alloc == NULL)
238 return 0;
239 md_ctx = EVP_MD_CTX_new();
240 if (md_ctx == NULL)
241 goto err;
242
243 w1_encoded = alloc;
244 /* Init the temp vectors to point to the allocated polys blob */
245 p = (POLY *)(w1_encoded + w1_encoded_len);
246 c_ntt = p++;
247 matrix_init(&a_ntt, p, k, l);
248 p += num_polys_k_by_l;
249 signature_init(&sig, p, k, p + k, l, c_tilde_sig, c_tilde_len);
250 p += num_polys_sig;
251 vector_init(&az_ntt, p, k);
252 vector_init(&ct1_ntt, p + k, k);
253
254 if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
255 || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
256 goto err;
257 if (msg_is_mu) {
258 if (msg_enc_len != mu_len)
259 goto err;
260 mu_ptr = (uint8_t *)msg_enc;
261 } else {
262 if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
263 msg_enc, msg_enc_len, mu_ptr, mu_len))
264 goto err;
265 }
266 /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */
267 if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
268 md_ctx, pub->shake256_md, params->tau))
269 goto err;
270
271 /* ct1_ntt = NTT(c) * NTT(t1 * 2^d) */
272 vector_scale_power2_round_ntt(&pub->t1, &ct1_ntt);
273 vector_mult_scalar(&ct1_ntt, c_ntt, &ct1_ntt);
274
275 /* compute z_max early in order to reuse sig.z */
276 z_max = vector_max(&sig.z);
277
278 /* w_approx = NTT_inverse(A * NTT(z) - ct1_ntt) */
279 z_ntt = &sig.z;
280 vector_ntt(z_ntt);
281 matrix_mult_vector(&a_ntt, z_ntt, &az_ntt);
282 w_approx = &az_ntt;
283 vector_sub(&az_ntt, &ct1_ntt, w_approx);
284 vector_ntt_inverse(w_approx);
285
286 /* compute w1_encoded */
287 w1 = w_approx;
288 vector_use_hint(&sig.hint, w_approx, gamma2, w1);
289 ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
290
291 if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
292 w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
293 goto err;
294
295 ret = (z_max < (uint32_t)(params->gamma1 - params->beta))
296 && memcmp(c_tilde, sig.c_tilde, c_tilde_len) == 0;
297 err:
298 OPENSSL_free(alloc);
299 EVP_MD_CTX_free(md_ctx);
300 return ret;
301 }
302
303 /**
304 * @brief Encode a message
305 * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5).
306 *
307 * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
308 * Where ctx is the empty string by default and ctx_len <= 255.
309 *
310 * Note this code could be shared with SLH_DSA
311 *
312 * @param msg A message to encode
313 * @param msg_len The size of |msg|
314 * @param ctx An optional context to add to the message encoding.
315 * @param ctx_len The size of |ctx|. It must be in the range 0..255
316 * @param encode Use the Pure signature encoding if this is 1, and dont encode
317 * if this value is 0.
318 * @param tmp A small buffer that may be used if the message is small.
319 * @param tmp_len The size of |tmp|
320 * @param out_len The size of the returned encoded buffer.
321 * @returns A buffer containing the encoded message. If the passed in
322 * |tmp| buffer is big enough to hold the encoded message then it returns |tmp|
323 * otherwise it allocates memory which must be freed by the caller. If |encode|
324 * is 0 then it returns |msg|. NULL is returned if there is a failure.
325 */
msg_encode(const uint8_t * msg,size_t msg_len,const uint8_t * ctx,size_t ctx_len,int encode,uint8_t * tmp,size_t tmp_len,size_t * out_len)326 static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len,
327 const uint8_t *ctx, size_t ctx_len, int encode,
328 uint8_t *tmp, size_t tmp_len, size_t *out_len)
329 {
330 uint8_t *encoded = NULL;
331 size_t encoded_len;
332
333 if (encode == 0) {
334 /* Raw message */
335 *out_len = msg_len;
336 return (uint8_t *)msg;
337 }
338 if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
339 return NULL;
340
341 /* Pure encoding */
342 encoded_len = 1 + 1 + ctx_len + msg_len;
343 *out_len = encoded_len;
344 if (encoded_len <= tmp_len) {
345 encoded = tmp;
346 } else {
347 encoded = OPENSSL_malloc(encoded_len);
348 if (encoded == NULL)
349 return NULL;
350 }
351 encoded[0] = 0;
352 encoded[1] = (uint8_t)ctx_len;
353 memcpy(&encoded[2], ctx, ctx_len);
354 memcpy(&encoded[2 + ctx_len], msg, msg_len);
355 return encoded;
356 }
357
358 /**
359 * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign()
360 *
361 * @returns 1 on success, or 0 on error.
362 */
ossl_ml_dsa_sign(const ML_DSA_KEY * priv,int msg_is_mu,const uint8_t * msg,size_t msg_len,const uint8_t * context,size_t context_len,const uint8_t * rand,size_t rand_len,int encode,unsigned char * sig,size_t * sig_len,size_t sig_size)363 int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
364 const uint8_t *msg, size_t msg_len,
365 const uint8_t *context, size_t context_len,
366 const uint8_t *rand, size_t rand_len, int encode,
367 unsigned char *sig, size_t *sig_len, size_t sig_size)
368 {
369 int ret = 1;
370 uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL;
371 size_t m_len = 0;
372
373 if (ossl_ml_dsa_key_get_priv(priv) == NULL)
374 return 0;
375 if (sig != NULL) {
376 if (sig_size < priv->params->sig_len)
377 return 0;
378 if (msg_is_mu) {
379 m = (uint8_t *)msg;
380 m_len = msg_len;
381 } else {
382 m = msg_encode(msg, msg_len, context, context_len, encode,
383 m_tmp, sizeof(m_tmp), &m_len);
384 if (m == NULL)
385 return 0;
386 if (m != msg && m != m_tmp)
387 alloced_m = m;
388 }
389 ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig);
390 OPENSSL_free(alloced_m);
391 }
392 if (sig_len != NULL)
393 *sig_len = priv->params->sig_len;
394 return ret;
395 }
396
397 /**
398 * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify()
399 * @returns 1 on success, or 0 on error.
400 */
ossl_ml_dsa_verify(const ML_DSA_KEY * pub,int msg_is_mu,const uint8_t * msg,size_t msg_len,const uint8_t * context,size_t context_len,int encode,const uint8_t * sig,size_t sig_len)401 int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
402 const uint8_t *msg, size_t msg_len,
403 const uint8_t *context, size_t context_len, int encode,
404 const uint8_t *sig, size_t sig_len)
405 {
406 uint8_t *m, *alloced_m = NULL;
407 size_t m_len;
408 uint8_t m_tmp[1024];
409 int ret = 0;
410
411 if (ossl_ml_dsa_key_get_pub(pub) == NULL)
412 return 0;
413
414 if (msg_is_mu) {
415 m = (uint8_t *)msg;
416 m_len = msg_len;
417 } else {
418 m = msg_encode(msg, msg_len, context, context_len, encode,
419 m_tmp, sizeof(m_tmp), &m_len);
420 if (m == NULL)
421 return 0;
422 if (m != msg && m != m_tmp)
423 alloced_m = m;
424 }
425
426 ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len);
427 OPENSSL_free(alloced_m);
428 return ret;
429 }
430