1 /* 2 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License 2.0 (the "License"). You may not use 5 * this file except in compliance with the License. You can obtain a copy 6 * in the file LICENSE in the source distribution or at 7 * https://www.openssl.org/source/license.html 8 */ 9 10 #include <openssl/core_dispatch.h> 11 #include <openssl/core_names.h> 12 #include <openssl/params.h> 13 #include <openssl/rand.h> 14 #include "ml_dsa_local.h" 15 #include "ml_dsa_key.h" 16 #include "ml_dsa_matrix.h" 17 #include "ml_dsa_sign.h" 18 #include "ml_dsa_hash.h" 19 20 #define ML_DSA_MAX_LAMBDA 256 /* bit strength for ML-DSA-87 */ 21 22 /* 23 * @brief Initialize a Signature object by pointing all of its objects to 24 * preallocated blocks. The values passed for hint, z and 25 * c_tilde values are not owned/freed by the |sig| object. 26 * 27 * @param sig The ML_DSA_SIG to initialize. 28 * @param hint A preallocated array of |k| polynomial blocks 29 * @param k The number of |hint| polynomials 30 * @param z A preallocated array of |l| polynomial blocks 31 * @param l The number of |z| polynomials 32 * @param c_tilde A preallocated buffer 33 * @param c_tilde_len The size of |c_tilde| 34 */ 35 static void signature_init(ML_DSA_SIG *sig, 36 POLY *hint, uint32_t k, POLY *z, uint32_t l, 37 uint8_t *c_tilde, size_t c_tilde_len) 38 { 39 vector_init(&sig->z, z, l); 40 vector_init(&sig->hint, hint, k); 41 sig->c_tilde = c_tilde; 42 sig->c_tilde_len = c_tilde_len; 43 } 44 45 /* 46 * FIPS 204, Algorithm 7, ML-DSA.Sign_internal() 47 * @returns 1 on success and 0 on failure. 
*/
/*
 * Parameter notes for ml_dsa_sign_internal():
 *
 * @param priv The private key; its params, rho, K, tr, s1, s2, t0 and the
 *             fetched SHAKE digests are read.
 * @param msg_is_mu If non-zero |encoded_msg| is used directly as mu and
 *                  |encoded_msg_len| must equal ML_DSA_MU_BYTES. Otherwise
 *                  mu is computed with shake_xof_2() from |priv->tr| and
 *                  |encoded_msg|.
 * @param encoded_msg The (already encoded) message, or mu.
 * @param encoded_msg_len The size of |encoded_msg|
 * @param rnd Randomness that is hashed together with |priv->K| and mu to
 *            produce rho_prime.
 * @param rnd_len The size of |rnd|
 * @param out_sig Receives the output of ossl_ml_dsa_sig_encode(). No size
 *                is passed in, so the caller must supply a large enough
 *                buffer (presumably params->sig_len bytes - confirm
 *                against ossl_ml_dsa_sig_encode()).
 */
static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
                                const uint8_t *encoded_msg,
                                size_t encoded_msg_len,
                                const uint8_t *rnd, size_t rnd_len,
                                uint8_t *out_sig)
{
    int ret = 0;
    const ML_DSA_PARAMS *params = priv->params;
    EVP_MD_CTX *md_ctx = NULL;
    uint32_t k = params->k, l = params->l;
    uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
    uint8_t *alloc = NULL, *w1_encoded;
    size_t alloc_len, w1_encoded_len;
    size_t num_polys_sig_k = 2 * k;
    size_t num_polys_k = 5 * k;
    size_t num_polys_l = 3 * l;
    size_t num_polys_k_by_l = k * l;
    POLY *polys = NULL, *p, *c_ntt;
    VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
    MATRIX a_ntt;
    ML_DSA_SIG sig;
    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
    const size_t mu_len = sizeof(mu);
    uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
    uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
    size_t c_tilde_len = params->bit_strength >> 2;
    size_t kappa;

    /*
     * Allocate a single blob for most of the variable size temporary variables.
     * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
     *
     * Layout of the blob, in order:
     *   w1_encoded                     : w1_encoded_len bytes
     *   c_ntt                          : 1 poly
     *   a_ntt                          : k * l polys
     *   s2_ntt, t0_ntt, w, w1, cs2     : 5 * k polys
     *   s1_ntt, y, cs1                 : 3 * l polys
     *   sig.hint (k) followed by sig.z (l), for which 2 * k polys are reserved
     *
     * NOTE(review): the last region only fits if l <= k - confirm this holds
     * for every supported parameter set.
     */
    w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
    alloc_len = w1_encoded_len
        + sizeof(*polys) * (1 + num_polys_k + num_polys_l
                            + num_polys_k_by_l + num_polys_sig_k);
    alloc = OPENSSL_malloc(alloc_len);
    if (alloc == NULL)
        return 0;
    md_ctx = EVP_MD_CTX_new();
    if (md_ctx == NULL)
        goto err;

    w1_encoded = alloc;
    /* Init the temp vectors to point to the allocated polys blob */
    p = (POLY *)(w1_encoded + w1_encoded_len);
    c_ntt = p++;
    matrix_init(&a_ntt, p, k, l);
    p += num_polys_k_by_l;
    vector_init(&s2_ntt, p, k);
    vector_init(&t0_ntt, s2_ntt.poly + k, k);
    vector_init(&w, t0_ntt.poly + k, k);
    vector_init(&w1, w.poly + k, k);
    vector_init(&cs2, w1.poly + k, k);
    p += num_polys_k;
    vector_init(&s1_ntt, p, l);
    vector_init(&y, p + l, l);
    vector_init(&cs1, p + 2 * l, l);
    p += num_polys_l;
    signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
    /* End of the allocated blob setup */

    if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
        goto err;
    if (msg_is_mu) {
        /* The caller supplied mu directly, so only its length is checked */
        if (encoded_msg_len != mu_len)
            goto err;
        /* mu_ptr is only read from here on, so dropping const is harmless */
        mu_ptr = (uint8_t *)encoded_msg;
    } else {
        /* mu is computed from the key's tr value and the encoded message */
        if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
                         encoded_msg, encoded_msg_len, mu_ptr, mu_len))
            goto err;
    }
    /* rho_prime is computed from K, rnd and mu. It is cleansed on exit */
    if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
                     rnd, rnd_len, mu_ptr, mu_len,
                     rho_prime, sizeof(rho_prime)))
        goto err;

    /* Work on NTT transformed copies so that the key itself is not modified */
    vector_copy(&s1_ntt, &priv->s1);
    vector_ntt(&s1_ntt);
    vector_copy(&s2_ntt, &priv->s2);
    vector_ntt(&s2_ntt);
    vector_copy(&t0_ntt, &priv->t0);
    vector_ntt(&t0_ntt);

    /*
     * kappa must not exceed 2^16. But the probability of it
     * exceeding even 1000 iterations is vanishingly small.
     *
     * Note that this bound is not enforced by the loop: it only ends when
     * a candidate signature is accepted or a hash operation fails.
     */
    for (kappa = 0; ; kappa += l) {
        /*
         * To save memory some temporaries share storage:
         *   y_ntt uses cs1, which is not needed until after A * y_ntt.
         *   r0 and ct0 use w1, which is not needed once w1 has been encoded.
         */
        VECTOR *y_ntt = &cs1;
        VECTOR *r0 = &w1;
        VECTOR *ct0 = &w1;
        uint32_t z_max, r0_max, ct0_max, h_ones;

        vector_expand_mask(&y, rho_prime, sizeof(rho_prime), kappa,
                           gamma1, md_ctx, priv->shake256_md);
        vector_copy(y_ntt, &y);
        vector_ntt(y_ntt);

        /* w = NTT_inverse(A * NTT(y)) */
        matrix_mult_vector(&a_ntt, y_ntt, &w);
        vector_ntt_inverse(&w);

        vector_high_bits(&w, gamma2, &w1);
        ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);

        /* c_tilde is the hash of mu and the encoded w1 */
        if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
                         w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
            break;

        if (!poly_sample_in_ball_ntt(c_ntt, c_tilde, c_tilde_len,
                                     md_ctx, priv->shake256_md, params->tau))
            break;

        /* This overwrites y_ntt, which shares its storage with cs1 */
        vector_mult_scalar(&s1_ntt, c_ntt, &cs1);
        vector_ntt_inverse(&cs1);
        vector_mult_scalar(&s2_ntt, c_ntt, &cs2);
        vector_ntt_inverse(&cs2);

        /* z = y + cs1 */
        vector_add(&y, &cs1, &sig.z);

        /* r0 = lowbits(w - cs2) */
        vector_sub(&w, &cs2, r0);
        vector_low_bits(r0, gamma2, r0);

        /*
         * Leaking that the signature is rejected is fine as the next attempt at a
         * signature will be (indistinguishable from) independent of this one.
         */
        z_max = vector_max(&sig.z);
        r0_max = vector_max_signed(r0);
        if (value_barrier_32(constant_time_ge(z_max, gamma1 - params->beta)
                             | constant_time_ge(r0_max, gamma2 - params->beta)))
            continue;

        /* ct0 = NTT_inverse(NTT(t0) * c_ntt). This overwrites r0 */
        vector_mult_scalar(&t0_ntt, c_ntt, ct0);
        vector_ntt_inverse(ct0);
        vector_make_hint(ct0, &cs2, &w, gamma2, &sig.hint);

        ct0_max = vector_max(ct0);
        h_ones = vector_count_ones(&sig.hint);
        /* Same reasoning applies to the leak as above */
        if (value_barrier_32(constant_time_ge(ct0_max, gamma2)
                             | constant_time_lt(params->omega, h_ones)))
            continue;
        ret = ossl_ml_dsa_sig_encode(&sig, params, out_sig);
        break;
    }
err:
    EVP_MD_CTX_free(md_ctx);
    /* The blob held values derived from the private key, so clear it */
    OPENSSL_clear_free(alloc, alloc_len);
    OPENSSL_cleanse(rho_prime, sizeof(rho_prime));
    return ret;
}

/*
 * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
 *
 * @param pub The public key; its params, rho, tr, t1 and the fetched SHAKE
 *            digests are read.
 * @param msg_is_mu If non-zero |msg_enc| is used directly as mu and
 *                  |msg_enc_len| must equal ML_DSA_MU_BYTES. Otherwise mu
 *                  is computed with shake_xof_2() from |pub->tr| and
 *                  |msg_enc|.
 * @param msg_enc The (already encoded) message, or mu.
 * @param msg_enc_len The size of |msg_enc|
 * @param sig_enc The encoded signature to check
 * @param sig_enc_len The size of |sig_enc|
 * @returns 1 if the signature is valid, or 0 if it is invalid or an error
 *          (allocation, decoding or hashing failure) occurred.
 */
static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
                                  const uint8_t *msg_enc, size_t msg_enc_len,
                                  const uint8_t *sig_enc, size_t sig_enc_len)
{
    int ret = 0;
    uint8_t *alloc = NULL, *w1_encoded;
    POLY *polys = NULL, *p, *c_ntt;
    MATRIX a_ntt;
    VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
    ML_DSA_SIG sig;
    const ML_DSA_PARAMS *params = pub->params;
    uint32_t k = pub->params->k;
    uint32_t l = pub->params->l;
    uint32_t gamma2 = params->gamma2;
    size_t w1_encoded_len;
    size_t num_polys_sig = k + l;
    size_t num_polys_k = 2 * k;
    size_t num_polys_l = 1 * l;
    size_t num_polys_k_by_l = k * l;
    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
    const size_t mu_len = sizeof(mu);
    uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
    uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
    EVP_MD_CTX *md_ctx = NULL;
    size_t c_tilde_len = params->bit_strength >> 2;
    uint32_t z_max;

    /*
     * Allocate space for all the POLYNOMIALS used by temporary VECTORS
     *
     * Layout of the blob, in order:
     *   w1_encoded                     : w1_encoded_len bytes
     *   c_ntt                          : 1 poly
     *   a_ntt                          : k * l polys
     *   sig.hint (k) followed by sig.z : k + l polys
     *   az_ntt, ct1_ntt                : 2 * k polys
     * The remaining num_polys_l polys are allocated, but no vector in this
     * function is pointed at them.
     */
    w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
    alloc = OPENSSL_malloc(w1_encoded_len
                           + sizeof(*polys) * (1 + num_polys_k
                                               + num_polys_l
                                               + num_polys_k_by_l
                                               + num_polys_sig));
    if (alloc == NULL)
        return 0;
    md_ctx = EVP_MD_CTX_new();
    if (md_ctx == NULL)
        goto err;

    w1_encoded = alloc;
    /* Init the temp vectors to point to the allocated polys blob */
    p = (POLY *)(w1_encoded + w1_encoded_len);
    c_ntt = p++;
    matrix_init(&a_ntt, p, k, l);
    p += num_polys_k_by_l;
    signature_init(&sig, p, k, p + k, l, c_tilde_sig, c_tilde_len);
    p += num_polys_sig;
    vector_init(&az_ntt, p, k);
    vector_init(&ct1_ntt, p + k, k);

    /* Decode into sig (c_tilde_sig, z and hint), then expand A from rho */
    if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
            || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
        goto err;
    if (msg_is_mu) {
        /* The caller supplied mu directly, so only its length is checked */
        if (msg_enc_len != mu_len)
            goto err;
        /* mu_ptr is only read from here on, so dropping const is harmless */
        mu_ptr = (uint8_t *)msg_enc;
    } else {
        if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
                         msg_enc, msg_enc_len, mu_ptr, mu_len))
            goto err;
    }
    /* Compute the verifier's challenge c_ntt = NTT(SampleInBall(c_tilde)) */
    if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
                                 md_ctx, pub->shake256_md, params->tau))
        goto err;

    /* ct1_ntt = NTT(c) * NTT(t1 * 2^d) */
    vector_scale_power2_round_ntt(&pub->t1, &ct1_ntt);
    vector_mult_scalar(&ct1_ntt, c_ntt, &ct1_ntt);

    /* compute z_max early in order to reuse sig.z */
    z_max = vector_max(&sig.z);

    /* w_approx = NTT_inverse(A * NTT(z) - ct1_ntt) */
    z_ntt = &sig.z;
    vector_ntt(z_ntt);
    matrix_mult_vector(&a_ntt, z_ntt, &az_ntt);
    /* w_approx shares the storage of az_ntt, so the subtraction is in place */
    w_approx = &az_ntt;
    vector_sub(&az_ntt, &ct1_ntt, w_approx);
    vector_ntt_inverse(w_approx);

    /* compute w1_encoded (w1 shares storage with w_approx) */
    w1 = w_approx;
    vector_use_hint(&sig.hint, w_approx, gamma2, w1);
    ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);

    /* Recompute c_tilde from mu and w1_encoded. The third input is empty */
    if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
                     w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
        goto err;

    /* Valid only if z is in range and the recomputed c_tilde matches */
    ret = (z_max < (uint32_t)(params->gamma1 - params->beta))
        && memcmp(c_tilde, sig.c_tilde, c_tilde_len) == 0;
err:
    OPENSSL_free(alloc);
    EVP_MD_CTX_free(md_ctx);
    return ret;
}

/**
 * @brief Encode a message
 * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5).
 *
 * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
 * Where ctx is the empty string by default and ctx_len <= 255.
 *
 * Note this code could be shared with SLH_DSA
 *
 * @param msg A message to encode
 * @param msg_len The size of |msg|
 * @param ctx An optional context to add to the message encoding.
 * @param ctx_len The size of |ctx|. It must be in the range 0..255
 * @param encode Use the Pure signature encoding if this is non-zero, and do
 *               not encode if this value is 0.
 * @param tmp A small buffer that may be used if the message is small.
 * @param tmp_len The size of |tmp|
 * @param out_len The size of the returned encoded buffer.
 * @returns A buffer containing the encoded message. If the passed in
 * |tmp| buffer is big enough to hold the encoded message then it returns |tmp|
 * otherwise it allocates memory which must be freed by the caller. If |encode|
 * is 0 then it returns |msg|. NULL is returned if there is a failure.
329 */ 330 static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len, 331 const uint8_t *ctx, size_t ctx_len, int encode, 332 uint8_t *tmp, size_t tmp_len, size_t *out_len) 333 { 334 uint8_t *encoded = NULL; 335 size_t encoded_len; 336 337 if (encode == 0) { 338 /* Raw message */ 339 *out_len = msg_len; 340 return (uint8_t *)msg; 341 } 342 if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN) 343 return NULL; 344 345 /* Pure encoding */ 346 encoded_len = 1 + 1 + ctx_len + msg_len; 347 *out_len = encoded_len; 348 if (encoded_len <= tmp_len) { 349 encoded = tmp; 350 } else { 351 encoded = OPENSSL_malloc(encoded_len); 352 if (encoded == NULL) 353 return NULL; 354 } 355 encoded[0] = 0; 356 encoded[1] = (uint8_t)ctx_len; 357 memcpy(&encoded[2], ctx, ctx_len); 358 memcpy(&encoded[2 + ctx_len], msg, msg_len); 359 return encoded; 360 } 361 362 /** 363 * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign() 364 * 365 * @returns 1 on success, or 0 on error. 366 */ 367 int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu, 368 const uint8_t *msg, size_t msg_len, 369 const uint8_t *context, size_t context_len, 370 const uint8_t *rand, size_t rand_len, int encode, 371 unsigned char *sig, size_t *sig_len, size_t sig_size) 372 { 373 int ret = 1; 374 uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL; 375 size_t m_len = 0; 376 377 if (ossl_ml_dsa_key_get_priv(priv) == NULL) 378 return 0; 379 if (sig != NULL) { 380 if (sig_size < priv->params->sig_len) 381 return 0; 382 if (msg_is_mu) { 383 m = (uint8_t *)msg; 384 m_len = msg_len; 385 } else { 386 m = msg_encode(msg, msg_len, context, context_len, encode, 387 m_tmp, sizeof(m_tmp), &m_len); 388 if (m == NULL) 389 return 0; 390 if (m != msg && m != m_tmp) 391 alloced_m = m; 392 } 393 ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig); 394 OPENSSL_free(alloced_m); 395 } 396 if (sig_len != NULL) 397 *sig_len = priv->params->sig_len; 398 return ret; 399 } 400 401 /** 402 * See FIPS 203 Section 5.3 Algorithm 3 
ML-DSA.Verify() 403 * @returns 1 on success, or 0 on error. 404 */ 405 int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu, 406 const uint8_t *msg, size_t msg_len, 407 const uint8_t *context, size_t context_len, int encode, 408 const uint8_t *sig, size_t sig_len) 409 { 410 uint8_t *m, *alloced_m = NULL; 411 size_t m_len; 412 uint8_t m_tmp[1024]; 413 int ret = 0; 414 415 if (ossl_ml_dsa_key_get_pub(pub) == NULL) 416 return 0; 417 418 if (msg_is_mu) { 419 m = (uint8_t *)msg; 420 m_len = msg_len; 421 } else { 422 m = msg_encode(msg, msg_len, context, context_len, encode, 423 m_tmp, sizeof(m_tmp), &m_len); 424 if (m == NULL) 425 return 0; 426 if (m != msg && m != m_tmp) 427 alloced_m = m; 428 } 429 430 ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len); 431 OPENSSL_free(alloced_m); 432 return ret; 433 } 434