1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * Support for verifying ML-DSA signatures 4 * 5 * Copyright 2025 Google LLC 6 */ 7 8 #include <crypto/mldsa.h> 9 #include <crypto/sha3.h> 10 #include <kunit/visibility.h> 11 #include <linux/export.h> 12 #include <linux/module.h> 13 #include <linux/slab.h> 14 #include <linux/string.h> 15 #include <linux/unaligned.h> 16 17 #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */ 18 #define QINV_MOD_2_32 58728449 /* Multiplicative inverse of q mod 2^32 */ 19 #define N 256 /* Number of components per ring element */ 20 #define D 13 /* Number of bits dropped from the public key vector t */ 21 #define RHO_LEN 32 /* Length of the public random seed in bytes */ 22 #define MAX_W1_ENCODED_LEN 192 /* Max encoded length of one element of w'_1 */ 23 24 /* 25 * The zetas array in Montgomery form, i.e. with extra factor of 2^32. 26 * Reference: FIPS 204 Section 7.5 "NTT and NTT^-1" 27 * Generated by the following Python code: 28 * q=8380417; [a%q - q*(a%q > q//2) for a in [1753**(int(f'{i:08b}'[::-1], 2)) << 32 for i in range(256)]] 29 */ 30 static const s32 zetas_times_2_32[N] = { 31 -4186625, 25847, -2608894, -518909, 237124, -777960, -876248, 32 466468, 1826347, 2353451, -359251, -2091905, 3119733, -2884855, 33 3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488, 34 -1119584, 2619752, -2108549, -2118186, -3859737, -1399561, -3277672, 35 1757237, -19422, 4010497, 280005, 2706023, 95776, 3077325, 36 3530437, -1661693, -3592148, -2537516, 3915439, -3861115, -3043716, 37 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267, 38 -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, 39 811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, 40 -2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950, 41 2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144, 42 -3157330, -3632928, 126922, 3412210, -983419, 2147896, 2715295, 43 -2967645, -3693493, -411027, -2477047, -671102, -1228525, -22981, 44 -1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944, 45 508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342, 46 -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, 47 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 48 1341330, 1285669, -1584928, -812732, -1439742, -3019102, -3881060, 49 -3628969, 3839961, 2091667, 3407706, 2316500, 3817976, -3342478, 50 2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181, 51 -3520352, -3759364, -1197226, -3193378, 900702, 1859098, 909542, 52 819034, 495491, -1613174, -43260, -522500, -655327, -3122442, 53 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297, 54 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 55 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 56 1595974, -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 57 1903435, -1050970, -1333058, 1237275, -3318210, -1430225, -451100, 58 1312455, 3306115, -1962642, -1279661, 1917081, -2546312, -1374803, 59 1500165, 777191, 2235880, 3406031, -542412, -2831860, -1671176, 60 -1846953, -2584293, -3724270, 594136, -3776993, -2013608, 2432395, 61 2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426, 62 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, 63 -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 64 472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333, 65 -260646, -3833893, -2939036, -2235985, -420899, -2286327, 183443, 66 -976891, 1612842, -3545687, -554416, 3919660, -48306, -1362209, 67 3937738, 1400424, -846154, 1976782 68 }; 69 70 /* Reference: FIPS 204 Section 4 "Parameter Sets" */ 71 static const struct mldsa_parameter_set { 72 u8 k; /* num rows in the matrix A */ 73 u8 l; /* num columns in the matrix A */ 74 u8 ctilde_len; /* length of commitment hash ctilde in bytes; lambda/4 */ 75 u8 omega; /* max num of 1's in the hint vector h */ 76 u8 tau; /* num of +-1's in challenge c */ 77 u8 beta; /* tau times eta */ 78 u16 pk_len; /* length of public keys in bytes */ 79 u16 sig_len; /* length of signatures in bytes */ 80 s32 gamma1; /* coefficient range of y */ 81 } mldsa_parameter_sets[] = { 82 [MLDSA44] = { 83 .k = 4, 84 .l = 4, 85 .ctilde_len = 32, 86 .omega = 80, 87 .tau = 39, 88 .beta = 78, 89 .pk_len = MLDSA44_PUBLIC_KEY_SIZE, 90 .sig_len = MLDSA44_SIGNATURE_SIZE, 91 .gamma1 = 1 << 17, 92 }, 93 [MLDSA65] = { 94 .k = 6, 95 .l = 5, 96 .ctilde_len = 48, 97 .omega = 55, 98 .tau = 49, 99 .beta = 196, 100 .pk_len = MLDSA65_PUBLIC_KEY_SIZE, 101 .sig_len = MLDSA65_SIGNATURE_SIZE, 102 .gamma1 = 1 << 19, 103 }, 104 [MLDSA87] = { 105 .k = 8, 106 .l = 7, 107 .ctilde_len = 64, 108 .omega = 75, 109 .tau = 60, 110 .beta = 120, 111 .pk_len = MLDSA87_PUBLIC_KEY_SIZE, 112 .sig_len = MLDSA87_SIGNATURE_SIZE, 113 .gamma1 = 1 << 19, 114 }, 115 }; 116 117 /* 118 * An element of the ring R_q (normal form) or the ring T_q (NTT form). It 119 * consists of N integers mod q: either the polynomial coefficients of the R_q 120 * element or the components of the T_q element. In either case, whether they 121 * are fully reduced to [0, q - 1] varies in the different parts of the code. 122 */ 123 struct mldsa_ring_elem { 124 s32 x[N]; 125 }; 126 127 struct mldsa_verification_workspace { 128 /* SHAKE context for computing c, mu, and ctildeprime */ 129 struct shake_ctx shake; 130 /* The fields in this union are used in their order of declaration. */ 131 union { 132 /* The hash of the public key */ 133 u8 tr[64]; 134 /* The message representative mu */ 135 u8 mu[64]; 136 /* Temporary space for rej_ntt_poly() */ 137 u8 block[SHAKE128_BLOCK_SIZE + 1]; 138 /* Encoded element of w'_1 */ 139 u8 w1_encoded[MAX_W1_ENCODED_LEN]; 140 /* The commitment hash. Real length is params->ctilde_len */ 141 u8 ctildeprime[64]; 142 }; 143 /* SHAKE context for generating elements of the matrix A */ 144 struct shake_ctx a_shake; 145 /* 146 * An element of the matrix A generated from the public seed, or an 147 * element of the vector t_1 decoded from the public key and pre-scaled 148 * by 2^d. Both are in NTT form. To reduce memory usage, we generate 149 * or decode these elements only as needed. 150 */ 151 union { 152 struct mldsa_ring_elem a; 153 struct mldsa_ring_elem t1_scaled; 154 }; 155 /* The challenge c, generated from ctilde */ 156 struct mldsa_ring_elem c; 157 /* A temporary element used during calculations */ 158 struct mldsa_ring_elem tmp; 159 160 /* The following fields are variable-length: */ 161 162 /* The signer's response vector */ 163 struct mldsa_ring_elem z[/* l */]; 164 165 /* The signer's hint vector */ 166 /* u8 h[k * N]; */ 167 }; 168 169 /* 170 * Compute a * b * 2^-32 mod q. a * b must be in the range [-2^31 * q, 2^31 * q 171 * - 1] before reduction. The return value is in the range [-q + 1, q - 1]. 172 * 173 * To reduce mod q efficiently, this uses Montgomery reduction with R=2^32. 174 * That's where the factor of 2^-32 comes from. The caller must include a 175 * factor of 2^32 at some point to compensate for that. 176 * 177 * To keep the input and output ranges very close to symmetric, this 178 * specifically does a "signed" Montgomery reduction. That is, when computing 179 * d = c * q^-1 mod 2^32, this chooses a representative in [S32_MIN, S32_MAX] 180 * rather than [0, U32_MAX], i.e. s32 rather than u32. This matters in the 181 * wider multiplication d * Q when d keeps its value via sign extension. 182 * 183 * Reference: FIPS 204 Appendix A "Montgomery Multiplication". But, it doesn't 184 * explain it properly: it has an off-by-one error in the upper end of the input 185 * range, it doesn't clarify that the signed version should be used, and it 186 * gives an unnecessarily large output range. A better citation is perhaps the 187 * Dilithium reference code, which functionally matches the below code and 188 * merely has the (benign) off-by-one error in its documentation. 189 */ 190 static inline s32 Zq_mult(s32 a, s32 b) 191 { 192 /* Compute the unreduced product c. */ 193 s64 c = (s64)a * b; 194 195 /* 196 * Compute d = c * q^-1 mod 2^32. Generate a signed result, as 197 * explained above, but do the actual multiplication using an unsigned 198 * type to avoid signed integer overflow which is undefined behavior. 199 */ 200 s32 d = (u32)c * QINV_MOD_2_32; 201 202 /* 203 * Compute e = c - d * q. This makes the low 32 bits zero, since 204 * c - (c * q^-1) * q mod 2^32 205 * = c - c * (q^-1 * q) mod 2^32 206 * = c - c * 1 mod 2^32 207 * = c - c mod 2^32 208 * = 0 mod 2^32 209 */ 210 s64 e = c - (s64)d * Q; 211 212 /* Finally, return e * 2^-32. */ 213 return e >> 32; 214 } 215 216 /* 217 * Convert @w to its number-theoretically-transformed representation in-place. 218 * Reference: FIPS 204 Algorithm 41, NTT 219 * 220 * To prevent intermediate overflows, all input coefficients must have absolute 221 * value < q. All output components have absolute value < 9*q. 222 */ 223 static void ntt(struct mldsa_ring_elem *w) 224 { 225 int m = 0; /* index in zetas_times_2_32 */ 226 227 for (int len = 128; len >= 1; len /= 2) { 228 for (int start = 0; start < 256; start += 2 * len) { 229 const s32 z = zetas_times_2_32[++m]; 230 231 for (int j = start; j < start + len; j++) { 232 s32 t = Zq_mult(z, w->x[j + len]); 233 234 w->x[j + len] = w->x[j] - t; 235 w->x[j] += t; 236 } 237 } 238 } 239 } 240 241 /* 242 * Convert @w from its number-theoretically-transformed representation in-place. 243 * Reference: FIPS 204 Algorithm 42, NTT^-1 244 * 245 * This also multiplies the coefficients by 2^32, undoing an extra factor of 246 * 2^-32 introduced earlier, and reduces the coefficients to [0, q - 1]. 247 */ 248 static void invntt_and_mul_2_32(struct mldsa_ring_elem *w) 249 { 250 int m = 256; /* index in zetas_times_2_32 */ 251 252 /* Prevent intermediate overflows. */ 253 for (int j = 0; j < 256; j++) 254 w->x[j] %= Q; 255 256 for (int len = 1; len < 256; len *= 2) { 257 for (int start = 0; start < 256; start += 2 * len) { 258 const s32 z = -zetas_times_2_32[--m]; 259 260 for (int j = start; j < start + len; j++) { 261 s32 t = w->x[j]; 262 263 w->x[j] = t + w->x[j + len]; 264 w->x[j + len] = Zq_mult(z, t - w->x[j + len]); 265 } 266 } 267 } 268 /* 269 * Multiply by 2^32 * 256^-1. 2^32 cancels the factor of 2^-32 from 270 * earlier Montgomery multiplications. 256^-1 is for NTT^-1. This 271 * itself uses Montgomery multiplication, so *another* 2^32 is needed. 272 * Thus the actual multiplicand is 2^32 * 2^32 * 256^-1 mod q = 41978. 273 * 274 * Finally, also reduce from [-q + 1, q - 1] to [0, q - 1]. 275 */ 276 for (int j = 0; j < 256; j++) { 277 w->x[j] = Zq_mult(w->x[j], 41978); 278 w->x[j] += (w->x[j] >> 31) & Q; 279 } 280 } 281 282 /* 283 * Decode an element of t_1, i.e. the high d bits of t = A*s_1 + s_2. 284 * Reference: FIPS 204 Algorithm 23, pkDecode. 285 * Also multiply it by 2^d and convert it to NTT form. 286 */ 287 static const u8 *decode_t1_elem(struct mldsa_ring_elem *out, 288 const u8 *t1_encoded) 289 { 290 for (int j = 0; j < N; j += 4, t1_encoded += 5) { 291 u32 v = get_unaligned_le32(t1_encoded); 292 293 out->x[j + 0] = ((v >> 0) & 0x3ff) << D; 294 out->x[j + 1] = ((v >> 10) & 0x3ff) << D; 295 out->x[j + 2] = ((v >> 20) & 0x3ff) << D; 296 out->x[j + 3] = ((v >> 30) | (t1_encoded[4] << 2)) << D; 297 static_assert(0x3ff << D < Q); /* All coefficients < q. */ 298 } 299 ntt(out); 300 return t1_encoded; /* Return updated pointer. */ 301 } 302 303 /* 304 * Decode the signer's response vector 'z' from the signature. 305 * Reference: FIPS 204 Algorithm 27, sigDecode. 306 * 307 * This also validates that the coefficients of z are in range, corresponding 308 * the infinity norm check at the end of Algorithm 8, ML-DSA.Verify_internal. 309 * 310 * Finally, this also converts z to NTT form. 311 */ 312 static bool decode_z(struct mldsa_ring_elem z[/* l */], int l, s32 gamma1, 313 int beta, const u8 **sig_ptr) 314 { 315 const u8 *sig = *sig_ptr; 316 317 for (int i = 0; i < l; i++) { 318 if (l == 4) { /* ML-DSA-44? */ 319 /* 18-bit coefficients: decode 4 from 9 bytes. */ 320 for (int j = 0; j < N; j += 4, sig += 9) { 321 u64 v = get_unaligned_le64(sig); 322 323 z[i].x[j + 0] = (v >> 0) & 0x3ffff; 324 z[i].x[j + 1] = (v >> 18) & 0x3ffff; 325 z[i].x[j + 2] = (v >> 36) & 0x3ffff; 326 z[i].x[j + 3] = (v >> 54) | (sig[8] << 10); 327 } 328 } else { 329 /* 20-bit coefficients: decode 4 from 10 bytes. */ 330 for (int j = 0; j < N; j += 4, sig += 10) { 331 u64 v = get_unaligned_le64(sig); 332 333 z[i].x[j + 0] = (v >> 0) & 0xfffff; 334 z[i].x[j + 1] = (v >> 20) & 0xfffff; 335 z[i].x[j + 2] = (v >> 40) & 0xfffff; 336 z[i].x[j + 3] = 337 (v >> 60) | 338 (get_unaligned_le16(&sig[8]) << 4); 339 } 340 } 341 for (int j = 0; j < N; j++) { 342 z[i].x[j] = gamma1 - z[i].x[j]; 343 if (z[i].x[j] <= -(gamma1 - beta) || 344 z[i].x[j] >= gamma1 - beta) 345 return false; 346 } 347 ntt(&z[i]); 348 } 349 *sig_ptr = sig; /* Return updated pointer. */ 350 return true; 351 } 352 353 /* 354 * Decode the signer's hint vector 'h' from the signature. 355 * Reference: FIPS 204 Algorithm 21, HintBitUnpack 356 * 357 * Note that there are several ways in which the hint vector can be malformed. 358 */ 359 static bool decode_hint_vector(u8 h[/* k * N */], int k, int omega, const u8 *y) 360 { 361 int index = 0; 362 363 memset(h, 0, k * N); 364 for (int i = 0; i < k; i++) { 365 int count = y[omega + i]; /* num 1's in elems 0 through i */ 366 int prev = -1; 367 368 /* Cumulative count mustn't decrease or exceed omega. */ 369 if (count < index || count > omega) 370 return false; 371 for (; index < count; index++) { 372 if (prev >= y[index]) /* Coefficients out of order? */ 373 return false; 374 prev = y[index]; 375 h[i * N + y[index]] = 1; 376 } 377 } 378 return mem_is_zero(&y[index], omega - index); 379 } 380 381 /* 382 * Expand @seed into an element of R_q @c with coefficients in {-1, 0, 1}, 383 * exactly @tau of them nonzero. Reference: FIPS 204 Algorithm 29, SampleInBall 384 */ 385 static void sample_in_ball(struct mldsa_ring_elem *c, const u8 *seed, 386 size_t seed_len, int tau, struct shake_ctx *shake) 387 { 388 u64 signs; 389 u8 j; 390 391 shake256_init(shake); 392 shake_update(shake, seed, seed_len); 393 shake_squeeze(shake, (u8 *)&signs, sizeof(signs)); 394 le64_to_cpus(&signs); 395 *c = (struct mldsa_ring_elem){}; 396 for (int i = N - tau; i < N; i++, signs >>= 1) { 397 do { 398 shake_squeeze(shake, &j, 1); 399 } while (j > i); 400 c->x[i] = c->x[j]; 401 c->x[j] = 1 - 2 * (s32)(signs & 1); 402 } 403 } 404 405 /* 406 * Expand the public seed @rho and @row_and_column into an element of T_q @out. 407 * Reference: FIPS 204 Algorithm 30, RejNTTPoly 408 * 409 * @shake and @block are temporary space used by the expansion. @block has 410 * space for one SHAKE128 block, plus an extra byte to allow reading a u32 from 411 * the final 3-byte group without reading out-of-bounds. 412 */ 413 static void rej_ntt_poly(struct mldsa_ring_elem *out, const u8 rho[RHO_LEN], 414 __le16 row_and_column, struct shake_ctx *shake, 415 u8 block[SHAKE128_BLOCK_SIZE + 1]) 416 { 417 shake128_init(shake); 418 shake_update(shake, rho, RHO_LEN); 419 shake_update(shake, (u8 *)&row_and_column, sizeof(row_and_column)); 420 for (int i = 0; i < N;) { 421 shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE); 422 block[SHAKE128_BLOCK_SIZE] = 0; /* for KMSAN */ 423 static_assert(SHAKE128_BLOCK_SIZE % 3 == 0); 424 for (int j = 0; j < SHAKE128_BLOCK_SIZE && i < N; j += 3) { 425 u32 x = get_unaligned_le32(&block[j]) & 0x7fffff; 426 427 if (x < Q) /* Ignore values >= q. */ 428 out->x[i++] = x; 429 } 430 } 431 } 432 433 /* 434 * Return the HighBits of r adjusted according to hint h 435 * Reference: FIPS 204 Algorithm 40, UseHint 436 * 437 * This is needed because of the public key compression in ML-DSA. 438 * 439 * h is either 0 or 1, r is in [0, q - 1], and gamma2 is either (q - 1) / 88 or 440 * (q - 1) / 32. Except when invoked via the unit test interface, gamma2 is a 441 * compile-time constant, so compilers will optimize the code accordingly. 442 */ 443 static __always_inline s32 use_hint(u8 h, s32 r, const s32 gamma2) 444 { 445 const s32 m = (Q - 1) / (2 * gamma2); /* 44 or 16, compile-time const */ 446 s32 r1; 447 448 /* 449 * Handle the special case where r - (r mod+- (2 * gamma2)) == q - 1, 450 * i.e. r >= q - gamma2. This is also exactly where the computation of 451 * r1 below would produce 'm' and would need a correction. 452 */ 453 if (r >= Q - gamma2) 454 return h == 0 ? 0 : m - 1; 455 456 /* 457 * Compute the (non-hint-adjusted) HighBits r1 as: 458 * 459 * r1 = (r - (r mod+- (2 * gamma2))) / (2 * gamma2) 460 * = floor((r + gamma2 - 1) / (2 * gamma2)) 461 * 462 * Note that when '2 * gamma2' is a compile-time constant, compilers 463 * optimize the division to a reciprocal multiplication and shift. 464 */ 465 r1 = (u32)(r + gamma2 - 1) / (2 * gamma2); 466 467 /* 468 * Return the HighBits r1: 469 * + 0 if the hint is 0; 470 * + 1 (mod m) if the hint is 1 and the LowBits are positive; 471 * - 1 (mod m) if the hint is 1 and the LowBits are negative or 0. 472 * 473 * r1 is in (and remains in) [0, m - 1]. Note that when 'm' is a 474 * compile-time constant, compilers optimize the '% m' accordingly. 475 */ 476 if (h == 0) 477 return r1; 478 if (r > r1 * (2 * gamma2)) 479 return (u32)(r1 + 1) % m; 480 return (u32)(r1 + m - 1) % m; 481 } 482 483 static __always_inline void use_hint_elem(struct mldsa_ring_elem *w, 484 const u8 h[N], const s32 gamma2) 485 { 486 for (int j = 0; j < N; j++) 487 w->x[j] = use_hint(h[j], w->x[j], gamma2); 488 } 489 490 #if IS_ENABLED(CONFIG_CRYPTO_LIB_MLDSA_KUNIT_TEST) 491 /* Allow the __always_inline function use_hint() to be unit-tested. */ 492 s32 mldsa_use_hint(u8 h, s32 r, s32 gamma2) 493 { 494 return use_hint(h, r, gamma2); 495 } 496 EXPORT_SYMBOL_IF_KUNIT(mldsa_use_hint); 497 #endif 498 499 /* 500 * Encode one element of the commitment vector w'_1 into a byte string. 501 * Reference: FIPS 204 Algorithm 28, w1Encode. 502 * Return the number of bytes used: 192 for ML-DSA-44 and 128 for the others. 503 */ 504 static size_t encode_w1(u8 out[MAX_W1_ENCODED_LEN], 505 const struct mldsa_ring_elem *w1, int k) 506 { 507 size_t pos = 0; 508 509 static_assert(N * 6 / 8 == MAX_W1_ENCODED_LEN); 510 if (k == 4) { /* ML-DSA-44? */ 511 /* 6 bits per coefficient. Pack 4 at a time. */ 512 for (int j = 0; j < N; j += 4) { 513 u32 v = (w1->x[j + 0] << 0) | (w1->x[j + 1] << 6) | 514 (w1->x[j + 2] << 12) | (w1->x[j + 3] << 18); 515 out[pos++] = v >> 0; 516 out[pos++] = v >> 8; 517 out[pos++] = v >> 16; 518 } 519 } else { 520 /* 4 bits per coefficient. Pack 2 at a time. */ 521 for (int j = 0; j < N; j += 2) 522 out[pos++] = w1->x[j] | (w1->x[j + 1] << 4); 523 } 524 return pos; 525 } 526 527 /* Reference: FIPS 204 Section 6.3 "ML-DSA Verifying (Internal)" */ 528 int mldsa_verify(enum mldsa_alg alg, const u8 *sig, size_t sig_len, 529 const u8 *msg, size_t msg_len, const u8 *pk, size_t pk_len) 530 { 531 const struct mldsa_parameter_set *params = &mldsa_parameter_sets[alg]; 532 const int k = params->k, l = params->l; 533 /* For now this just does pure ML-DSA with an empty context string. */ 534 static const u8 msg_prefix[2] = { /* dom_sep= */ 0, /* ctx_len= */ 0 }; 535 const u8 *ctilde; /* The signer's commitment hash */ 536 const u8 *t1_encoded = &pk[RHO_LEN]; /* Next encoded element of t_1 */ 537 u8 *h; /* The signer's hint vector, length k * N */ 538 size_t w1_enc_len; 539 540 /* Validate the public key and signature lengths. */ 541 if (pk_len != params->pk_len || sig_len != params->sig_len) 542 return -EBADMSG; 543 544 /* 545 * Allocate the workspace, including variable-length fields. Its size 546 * depends only on the ML-DSA parameter set, not the other inputs. 547 * 548 * For freeing it, use kfree_sensitive() rather than kfree(). This is 549 * mainly to comply with FIPS 204 Section 3.6.3 "Intermediate Values". 550 * In reality it's a bit gratuitous, as this is a public key operation. 551 */ 552 struct mldsa_verification_workspace *ws __free(kfree_sensitive) = 553 kmalloc(sizeof(*ws) + (l * sizeof(ws->z[0])) + (k * N), 554 GFP_KERNEL); 555 if (!ws) 556 return -ENOMEM; 557 h = (u8 *)&ws->z[l]; 558 559 /* Decode the signature. Reference: FIPS 204 Algorithm 27, sigDecode */ 560 ctilde = sig; 561 sig += params->ctilde_len; 562 if (!decode_z(ws->z, l, params->gamma1, params->beta, &sig)) 563 return -EBADMSG; 564 if (!decode_hint_vector(h, k, params->omega, sig)) 565 return -EBADMSG; 566 567 /* Recreate the challenge c from the signer's commitment hash. */ 568 sample_in_ball(&ws->c, ctilde, params->ctilde_len, params->tau, 569 &ws->shake); 570 ntt(&ws->c); 571 572 /* Compute the message representative mu. */ 573 shake256(pk, pk_len, ws->tr, sizeof(ws->tr)); 574 shake256_init(&ws->shake); 575 shake_update(&ws->shake, ws->tr, sizeof(ws->tr)); 576 shake_update(&ws->shake, msg_prefix, sizeof(msg_prefix)); 577 shake_update(&ws->shake, msg, msg_len); 578 shake_squeeze(&ws->shake, ws->mu, sizeof(ws->mu)); 579 580 /* Start computing ctildeprime = H(mu || w1Encode(w'_1)). */ 581 shake256_init(&ws->shake); 582 shake_update(&ws->shake, ws->mu, sizeof(ws->mu)); 583 584 /* 585 * Compute the commitment w'_1 from A, z, c, t_1, and h. 586 * 587 * The computation is the same for each of the k rows. Just do each row 588 * before moving on to the next, resulting in only one loop over k. 589 */ 590 for (int i = 0; i < k; i++) { 591 /* 592 * tmp = NTT(A) * NTT(z) * 2^-32 593 * To reduce memory use, generate each element of NTT(A) 594 * on-demand. Note that each element is used only once. 595 */ 596 ws->tmp = (struct mldsa_ring_elem){}; 597 for (int j = 0; j < l; j++) { 598 rej_ntt_poly(&ws->a, pk /* rho is first field of pk */, 599 cpu_to_le16((i << 8) | j), &ws->a_shake, 600 ws->block); 601 for (int n = 0; n < N; n++) 602 ws->tmp.x[n] += 603 Zq_mult(ws->a.x[n], ws->z[j].x[n]); 604 } 605 /* All components of tmp now have abs value < l*q. */ 606 607 /* Decode the next element of t_1. */ 608 t1_encoded = decode_t1_elem(&ws->t1_scaled, t1_encoded); 609 610 /* 611 * tmp -= NTT(c) * NTT(t_1 * 2^d) * 2^-32 612 * 613 * Taking a conservative bound for the output of ntt(), the 614 * multiplicands can have absolute value up to 9*q. That 615 * corresponds to a product with absolute value 81*q^2. That is 616 * within the limits of Zq_mult() which needs < ~256*q^2. 617 */ 618 for (int j = 0; j < N; j++) 619 ws->tmp.x[j] -= Zq_mult(ws->c.x[j], ws->t1_scaled.x[j]); 620 /* All components of tmp now have abs value < (l+1)*q. */ 621 622 /* tmp = w'_Approx = NTT^-1(tmp) * 2^32 */ 623 invntt_and_mul_2_32(&ws->tmp); 624 /* All coefficients of tmp are now in [0, q - 1]. */ 625 626 /* 627 * tmp = w'_1 = UseHint(h, w'_Approx) 628 * For efficiency, set gamma2 to a compile-time constant. 629 */ 630 if (k == 4) 631 use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 88); 632 else 633 use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 32); 634 635 /* Encode and hash the next element of w'_1. */ 636 w1_enc_len = encode_w1(ws->w1_encoded, &ws->tmp, k); 637 shake_update(&ws->shake, ws->w1_encoded, w1_enc_len); 638 } 639 640 /* Finish computing ctildeprime. */ 641 shake_squeeze(&ws->shake, ws->ctildeprime, params->ctilde_len); 642 643 /* Verify that ctilde == ctildeprime. */ 644 if (memcmp(ws->ctildeprime, ctilde, params->ctilde_len) != 0) 645 return -EKEYREJECTED; 646 /* ||z||_infinity < gamma1 - beta was already checked in decode_z(). */ 647 return 0; 648 } 649 EXPORT_SYMBOL_GPL(mldsa_verify); 650 651 MODULE_DESCRIPTION("ML-DSA signature verification"); 652 MODULE_LICENSE("GPL"); 653