164edcceaSEric Biggers // SPDX-License-Identifier: GPL-2.0-or-later 264edcceaSEric Biggers /* 364edcceaSEric Biggers * Support for verifying ML-DSA signatures 464edcceaSEric Biggers * 564edcceaSEric Biggers * Copyright 2025 Google LLC 664edcceaSEric Biggers */ 764edcceaSEric Biggers 864edcceaSEric Biggers #include <crypto/mldsa.h> 964edcceaSEric Biggers #include <crypto/sha3.h> 1064edcceaSEric Biggers #include <kunit/visibility.h> 1164edcceaSEric Biggers #include <linux/export.h> 1264edcceaSEric Biggers #include <linux/module.h> 1364edcceaSEric Biggers #include <linux/slab.h> 1464edcceaSEric Biggers #include <linux/string.h> 1564edcceaSEric Biggers #include <linux/unaligned.h> 16*959a634eSEric Biggers #include "fips-mldsa.h" 1764edcceaSEric Biggers 1864edcceaSEric Biggers #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */ 1964edcceaSEric Biggers #define QINV_MOD_2_32 58728449 /* Multiplicative inverse of q mod 2^32 */ 2064edcceaSEric Biggers #define N 256 /* Number of components per ring element */ 2164edcceaSEric Biggers #define D 13 /* Number of bits dropped from the public key vector t */ 2264edcceaSEric Biggers #define RHO_LEN 32 /* Length of the public random seed in bytes */ 2364edcceaSEric Biggers #define MAX_W1_ENCODED_LEN 192 /* Max encoded length of one element of w'_1 */ 2464edcceaSEric Biggers 2564edcceaSEric Biggers /* 2664edcceaSEric Biggers * The zetas array in Montgomery form, i.e. with extra factor of 2^32. 2764edcceaSEric Biggers * Reference: FIPS 204 Section 7.5 "NTT and NTT^-1" 2864edcceaSEric Biggers * Generated by the following Python code: 2964edcceaSEric Biggers * q=8380417; [a%q - q*(a%q > q//2) for a in [1753**(int(f'{i:08b}'[::-1], 2)) << 32 for i in range(256)]] 3064edcceaSEric Biggers */ 3164edcceaSEric Biggers static const s32 zetas_times_2_32[N] = { 3264edcceaSEric Biggers -4186625, 25847, -2608894, -518909, 237124, -777960, -876248, 3364edcceaSEric Biggers 466468, 1826347, 2353451, -359251, -2091905, 3119733, -2884855, 3464edcceaSEric Biggers 3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488, 3564edcceaSEric Biggers -1119584, 2619752, -2108549, -2118186, -3859737, -1399561, -3277672, 3664edcceaSEric Biggers 1757237, -19422, 4010497, 280005, 2706023, 95776, 3077325, 3764edcceaSEric Biggers 3530437, -1661693, -3592148, -2537516, 3915439, -3861115, -3043716, 3864edcceaSEric Biggers 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267, 3964edcceaSEric Biggers -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, 4064edcceaSEric Biggers 811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, 4164edcceaSEric Biggers -2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950, 4264edcceaSEric Biggers 2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144, 4364edcceaSEric Biggers -3157330, -3632928, 126922, 3412210, -983419, 2147896, 2715295, 4464edcceaSEric Biggers -2967645, -3693493, -411027, -2477047, -671102, -1228525, -22981, 4564edcceaSEric Biggers -1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944, 4664edcceaSEric Biggers 508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342, 4764edcceaSEric Biggers -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, 4864edcceaSEric Biggers 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 4964edcceaSEric Biggers 1341330, 1285669, -1584928, -812732, -1439742, -3019102, -3881060, 5064edcceaSEric Biggers -3628969, 3839961, 2091667, 3407706, 2316500, 3817976, -3342478, 5164edcceaSEric Biggers 2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181, 5264edcceaSEric Biggers -3520352, -3759364, -1197226, -3193378, 900702, 1859098, 909542, 5364edcceaSEric Biggers 819034, 495491, -1613174, -43260, -522500, -655327, -3122442, 5464edcceaSEric Biggers 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297, 5564edcceaSEric Biggers 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 5664edcceaSEric Biggers 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 5764edcceaSEric Biggers 1595974, -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 5864edcceaSEric Biggers 1903435, -1050970, -1333058, 1237275, -3318210, -1430225, -451100, 5964edcceaSEric Biggers 1312455, 3306115, -1962642, -1279661, 1917081, -2546312, -1374803, 6064edcceaSEric Biggers 1500165, 777191, 2235880, 3406031, -542412, -2831860, -1671176, 6164edcceaSEric Biggers -1846953, -2584293, -3724270, 594136, -3776993, -2013608, 2432395, 6264edcceaSEric Biggers 2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426, 6364edcceaSEric Biggers 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, 6464edcceaSEric Biggers -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 6564edcceaSEric Biggers 472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333, 6664edcceaSEric Biggers -260646, -3833893, -2939036, -2235985, -420899, -2286327, 183443, 6764edcceaSEric Biggers -976891, 1612842, -3545687, -554416, 3919660, -48306, -1362209, 6864edcceaSEric Biggers 3937738, 1400424, -846154, 1976782 6964edcceaSEric Biggers }; 7064edcceaSEric Biggers 7164edcceaSEric Biggers /* Reference: FIPS 204 Section 4 "Parameter Sets" */ 7264edcceaSEric Biggers static const struct mldsa_parameter_set { 7364edcceaSEric Biggers u8 k; /* num rows in the matrix A */ 7464edcceaSEric Biggers u8 l; /* num columns in the matrix A */ 7564edcceaSEric Biggers u8 ctilde_len; /* length of commitment hash ctilde in bytes; lambda/4 */ 7664edcceaSEric Biggers u8 omega; /* max num of 1's in the hint vector h */ 7764edcceaSEric Biggers u8 tau; /* num of +-1's in challenge c */ 7864edcceaSEric Biggers u8 beta; /* tau times eta */ 7964edcceaSEric Biggers u16 pk_len; /* length of public keys in bytes */ 8064edcceaSEric Biggers u16 sig_len; /* length of signatures in bytes */ 8164edcceaSEric Biggers s32 gamma1; /* coefficient range of y */ 8264edcceaSEric Biggers } mldsa_parameter_sets[] = { 8364edcceaSEric Biggers [MLDSA44] = { 8464edcceaSEric Biggers .k = 4, 8564edcceaSEric Biggers .l = 4, 8664edcceaSEric Biggers .ctilde_len = 32, 8764edcceaSEric Biggers .omega = 80, 8864edcceaSEric Biggers .tau = 39, 8964edcceaSEric Biggers .beta = 78, 9064edcceaSEric Biggers .pk_len = MLDSA44_PUBLIC_KEY_SIZE, 9164edcceaSEric Biggers .sig_len = MLDSA44_SIGNATURE_SIZE, 9264edcceaSEric Biggers .gamma1 = 1 << 17, 9364edcceaSEric Biggers }, 9464edcceaSEric Biggers [MLDSA65] = { 9564edcceaSEric Biggers .k = 6, 9664edcceaSEric Biggers .l = 5, 9764edcceaSEric Biggers .ctilde_len = 48, 9864edcceaSEric Biggers .omega = 55, 9964edcceaSEric Biggers .tau = 49, 10064edcceaSEric Biggers .beta = 196, 10164edcceaSEric Biggers .pk_len = MLDSA65_PUBLIC_KEY_SIZE, 10264edcceaSEric Biggers .sig_len = MLDSA65_SIGNATURE_SIZE, 10364edcceaSEric Biggers .gamma1 = 1 << 19, 10464edcceaSEric Biggers }, 10564edcceaSEric Biggers [MLDSA87] = { 10664edcceaSEric Biggers .k = 8, 10764edcceaSEric Biggers .l = 7, 10864edcceaSEric Biggers .ctilde_len = 64, 10964edcceaSEric Biggers .omega = 75, 11064edcceaSEric Biggers .tau = 60, 11164edcceaSEric Biggers .beta = 120, 11264edcceaSEric Biggers .pk_len = MLDSA87_PUBLIC_KEY_SIZE, 11364edcceaSEric Biggers .sig_len = MLDSA87_SIGNATURE_SIZE, 11464edcceaSEric Biggers .gamma1 = 1 << 19, 11564edcceaSEric Biggers }, 11664edcceaSEric Biggers }; 11764edcceaSEric Biggers 11864edcceaSEric Biggers /* 11964edcceaSEric Biggers * An element of the ring R_q (normal form) or the ring T_q (NTT form). It 12064edcceaSEric Biggers * consists of N integers mod q: either the polynomial coefficients of the R_q 12164edcceaSEric Biggers * element or the components of the T_q element. In either case, whether they 12264edcceaSEric Biggers * are fully reduced to [0, q - 1] varies in the different parts of the code. 12364edcceaSEric Biggers */ 12464edcceaSEric Biggers struct mldsa_ring_elem { 12564edcceaSEric Biggers s32 x[N]; 12664edcceaSEric Biggers }; 12764edcceaSEric Biggers 12864edcceaSEric Biggers struct mldsa_verification_workspace { 12964edcceaSEric Biggers /* SHAKE context for computing c, mu, and ctildeprime */ 13064edcceaSEric Biggers struct shake_ctx shake; 13164edcceaSEric Biggers /* The fields in this union are used in their order of declaration. */ 13264edcceaSEric Biggers union { 13364edcceaSEric Biggers /* The hash of the public key */ 13464edcceaSEric Biggers u8 tr[64]; 13564edcceaSEric Biggers /* The message representative mu */ 13664edcceaSEric Biggers u8 mu[64]; 13764edcceaSEric Biggers /* Temporary space for rej_ntt_poly() */ 13864edcceaSEric Biggers u8 block[SHAKE128_BLOCK_SIZE + 1]; 13964edcceaSEric Biggers /* Encoded element of w'_1 */ 14064edcceaSEric Biggers u8 w1_encoded[MAX_W1_ENCODED_LEN]; 14164edcceaSEric Biggers /* The commitment hash. Real length is params->ctilde_len */ 14264edcceaSEric Biggers u8 ctildeprime[64]; 14364edcceaSEric Biggers }; 14464edcceaSEric Biggers /* SHAKE context for generating elements of the matrix A */ 14564edcceaSEric Biggers struct shake_ctx a_shake; 14664edcceaSEric Biggers /* 14764edcceaSEric Biggers * An element of the matrix A generated from the public seed, or an 14864edcceaSEric Biggers * element of the vector t_1 decoded from the public key and pre-scaled 14964edcceaSEric Biggers * by 2^d. Both are in NTT form. To reduce memory usage, we generate 15064edcceaSEric Biggers * or decode these elements only as needed. 15164edcceaSEric Biggers */ 15264edcceaSEric Biggers union { 15364edcceaSEric Biggers struct mldsa_ring_elem a; 15464edcceaSEric Biggers struct mldsa_ring_elem t1_scaled; 15564edcceaSEric Biggers }; 15664edcceaSEric Biggers /* The challenge c, generated from ctilde */ 15764edcceaSEric Biggers struct mldsa_ring_elem c; 15864edcceaSEric Biggers /* A temporary element used during calculations */ 15964edcceaSEric Biggers struct mldsa_ring_elem tmp; 16064edcceaSEric Biggers 16164edcceaSEric Biggers /* The following fields are variable-length: */ 16264edcceaSEric Biggers 16364edcceaSEric Biggers /* The signer's response vector */ 16464edcceaSEric Biggers struct mldsa_ring_elem z[/* l */]; 16564edcceaSEric Biggers 16664edcceaSEric Biggers /* The signer's hint vector */ 16764edcceaSEric Biggers /* u8 h[k * N]; */ 16864edcceaSEric Biggers }; 16964edcceaSEric Biggers 17064edcceaSEric Biggers /* 17164edcceaSEric Biggers * Compute a * b * 2^-32 mod q. a * b must be in the range [-2^31 * q, 2^31 * q 17264edcceaSEric Biggers * - 1] before reduction. The return value is in the range [-q + 1, q - 1]. 17364edcceaSEric Biggers * 17464edcceaSEric Biggers * To reduce mod q efficiently, this uses Montgomery reduction with R=2^32. 17564edcceaSEric Biggers * That's where the factor of 2^-32 comes from. The caller must include a 17664edcceaSEric Biggers * factor of 2^32 at some point to compensate for that. 17764edcceaSEric Biggers * 17864edcceaSEric Biggers * To keep the input and output ranges very close to symmetric, this 17964edcceaSEric Biggers * specifically does a "signed" Montgomery reduction. That is, when computing 18064edcceaSEric Biggers * d = c * q^-1 mod 2^32, this chooses a representative in [S32_MIN, S32_MAX] 18164edcceaSEric Biggers * rather than [0, U32_MAX], i.e. s32 rather than u32. This matters in the 18264edcceaSEric Biggers * wider multiplication d * Q when d keeps its value via sign extension. 18364edcceaSEric Biggers * 18464edcceaSEric Biggers * Reference: FIPS 204 Appendix A "Montgomery Multiplication". But, it doesn't 18564edcceaSEric Biggers * explain it properly: it has an off-by-one error in the upper end of the input 18664edcceaSEric Biggers * range, it doesn't clarify that the signed version should be used, and it 18764edcceaSEric Biggers * gives an unnecessarily large output range. A better citation is perhaps the 18864edcceaSEric Biggers * Dilithium reference code, which functionally matches the below code and 18964edcceaSEric Biggers * merely has the (benign) off-by-one error in its documentation. 19064edcceaSEric Biggers */ 19164edcceaSEric Biggers static inline s32 Zq_mult(s32 a, s32 b) 19264edcceaSEric Biggers { 19364edcceaSEric Biggers /* Compute the unreduced product c. */ 19464edcceaSEric Biggers s64 c = (s64)a * b; 19564edcceaSEric Biggers 19664edcceaSEric Biggers /* 19764edcceaSEric Biggers * Compute d = c * q^-1 mod 2^32. Generate a signed result, as 19864edcceaSEric Biggers * explained above, but do the actual multiplication using an unsigned 19964edcceaSEric Biggers * type to avoid signed integer overflow which is undefined behavior. 20064edcceaSEric Biggers */ 20164edcceaSEric Biggers s32 d = (u32)c * QINV_MOD_2_32; 20264edcceaSEric Biggers 20364edcceaSEric Biggers /* 20464edcceaSEric Biggers * Compute e = c - d * q. This makes the low 32 bits zero, since 20564edcceaSEric Biggers * c - (c * q^-1) * q mod 2^32 20664edcceaSEric Biggers * = c - c * (q^-1 * q) mod 2^32 20764edcceaSEric Biggers * = c - c * 1 mod 2^32 20864edcceaSEric Biggers * = c - c mod 2^32 20964edcceaSEric Biggers * = 0 mod 2^32 21064edcceaSEric Biggers */ 21164edcceaSEric Biggers s64 e = c - (s64)d * Q; 21264edcceaSEric Biggers 21364edcceaSEric Biggers /* Finally, return e * 2^-32. */ 21464edcceaSEric Biggers return e >> 32; 21564edcceaSEric Biggers } 21664edcceaSEric Biggers 21764edcceaSEric Biggers /* 21864edcceaSEric Biggers * Convert @w to its number-theoretically-transformed representation in-place. 21964edcceaSEric Biggers * Reference: FIPS 204 Algorithm 41, NTT 22064edcceaSEric Biggers * 22164edcceaSEric Biggers * To prevent intermediate overflows, all input coefficients must have absolute 22264edcceaSEric Biggers * value < q. All output components have absolute value < 9*q. 22364edcceaSEric Biggers */ 22464edcceaSEric Biggers static void ntt(struct mldsa_ring_elem *w) 22564edcceaSEric Biggers { 22664edcceaSEric Biggers int m = 0; /* index in zetas_times_2_32 */ 22764edcceaSEric Biggers 22864edcceaSEric Biggers for (int len = 128; len >= 1; len /= 2) { 22964edcceaSEric Biggers for (int start = 0; start < 256; start += 2 * len) { 23064edcceaSEric Biggers const s32 z = zetas_times_2_32[++m]; 23164edcceaSEric Biggers 23264edcceaSEric Biggers for (int j = start; j < start + len; j++) { 23364edcceaSEric Biggers s32 t = Zq_mult(z, w->x[j + len]); 23464edcceaSEric Biggers 23564edcceaSEric Biggers w->x[j + len] = w->x[j] - t; 23664edcceaSEric Biggers w->x[j] += t; 23764edcceaSEric Biggers } 23864edcceaSEric Biggers } 23964edcceaSEric Biggers } 24064edcceaSEric Biggers } 24164edcceaSEric Biggers 24264edcceaSEric Biggers /* 24364edcceaSEric Biggers * Convert @w from its number-theoretically-transformed representation in-place. 24464edcceaSEric Biggers * Reference: FIPS 204 Algorithm 42, NTT^-1 24564edcceaSEric Biggers * 24664edcceaSEric Biggers * This also multiplies the coefficients by 2^32, undoing an extra factor of 24764edcceaSEric Biggers * 2^-32 introduced earlier, and reduces the coefficients to [0, q - 1]. 24864edcceaSEric Biggers */ 24964edcceaSEric Biggers static void invntt_and_mul_2_32(struct mldsa_ring_elem *w) 25064edcceaSEric Biggers { 25164edcceaSEric Biggers int m = 256; /* index in zetas_times_2_32 */ 25264edcceaSEric Biggers 25364edcceaSEric Biggers /* Prevent intermediate overflows. */ 25464edcceaSEric Biggers for (int j = 0; j < 256; j++) 25564edcceaSEric Biggers w->x[j] %= Q; 25664edcceaSEric Biggers 25764edcceaSEric Biggers for (int len = 1; len < 256; len *= 2) { 25864edcceaSEric Biggers for (int start = 0; start < 256; start += 2 * len) { 25964edcceaSEric Biggers const s32 z = -zetas_times_2_32[--m]; 26064edcceaSEric Biggers 26164edcceaSEric Biggers for (int j = start; j < start + len; j++) { 26264edcceaSEric Biggers s32 t = w->x[j]; 26364edcceaSEric Biggers 26464edcceaSEric Biggers w->x[j] = t + w->x[j + len]; 26564edcceaSEric Biggers w->x[j + len] = Zq_mult(z, t - w->x[j + len]); 26664edcceaSEric Biggers } 26764edcceaSEric Biggers } 26864edcceaSEric Biggers } 26964edcceaSEric Biggers /* 27064edcceaSEric Biggers * Multiply by 2^32 * 256^-1. 2^32 cancels the factor of 2^-32 from 27164edcceaSEric Biggers * earlier Montgomery multiplications. 256^-1 is for NTT^-1. This 27264edcceaSEric Biggers * itself uses Montgomery multiplication, so *another* 2^32 is needed. 27364edcceaSEric Biggers * Thus the actual multiplicand is 2^32 * 2^32 * 256^-1 mod q = 41978. 27464edcceaSEric Biggers * 27564edcceaSEric Biggers * Finally, also reduce from [-q + 1, q - 1] to [0, q - 1]. 27664edcceaSEric Biggers */ 27764edcceaSEric Biggers for (int j = 0; j < 256; j++) { 27864edcceaSEric Biggers w->x[j] = Zq_mult(w->x[j], 41978); 27964edcceaSEric Biggers w->x[j] += (w->x[j] >> 31) & Q; 28064edcceaSEric Biggers } 28164edcceaSEric Biggers } 28264edcceaSEric Biggers 28364edcceaSEric Biggers /* 28464edcceaSEric Biggers * Decode an element of t_1, i.e. the high d bits of t = A*s_1 + s_2. 28564edcceaSEric Biggers * Reference: FIPS 204 Algorithm 23, pkDecode. 28664edcceaSEric Biggers * Also multiply it by 2^d and convert it to NTT form. 28764edcceaSEric Biggers */ 28864edcceaSEric Biggers static const u8 *decode_t1_elem(struct mldsa_ring_elem *out, 28964edcceaSEric Biggers const u8 *t1_encoded) 29064edcceaSEric Biggers { 29164edcceaSEric Biggers for (int j = 0; j < N; j += 4, t1_encoded += 5) { 29264edcceaSEric Biggers u32 v = get_unaligned_le32(t1_encoded); 29364edcceaSEric Biggers 29464edcceaSEric Biggers out->x[j + 0] = ((v >> 0) & 0x3ff) << D; 29564edcceaSEric Biggers out->x[j + 1] = ((v >> 10) & 0x3ff) << D; 29664edcceaSEric Biggers out->x[j + 2] = ((v >> 20) & 0x3ff) << D; 29764edcceaSEric Biggers out->x[j + 3] = ((v >> 30) | (t1_encoded[4] << 2)) << D; 29864edcceaSEric Biggers static_assert(0x3ff << D < Q); /* All coefficients < q. */ 29964edcceaSEric Biggers } 30064edcceaSEric Biggers ntt(out); 30164edcceaSEric Biggers return t1_encoded; /* Return updated pointer. */ 30264edcceaSEric Biggers } 30364edcceaSEric Biggers 30464edcceaSEric Biggers /* 30564edcceaSEric Biggers * Decode the signer's response vector 'z' from the signature. 30664edcceaSEric Biggers * Reference: FIPS 204 Algorithm 27, sigDecode. 30764edcceaSEric Biggers * 30864edcceaSEric Biggers * This also validates that the coefficients of z are in range, corresponding 30964edcceaSEric Biggers * the infinity norm check at the end of Algorithm 8, ML-DSA.Verify_internal. 31064edcceaSEric Biggers * 31164edcceaSEric Biggers * Finally, this also converts z to NTT form. 31264edcceaSEric Biggers */ 31364edcceaSEric Biggers static bool decode_z(struct mldsa_ring_elem z[/* l */], int l, s32 gamma1, 31464edcceaSEric Biggers int beta, const u8 **sig_ptr) 31564edcceaSEric Biggers { 31664edcceaSEric Biggers const u8 *sig = *sig_ptr; 31764edcceaSEric Biggers 31864edcceaSEric Biggers for (int i = 0; i < l; i++) { 31964edcceaSEric Biggers if (l == 4) { /* ML-DSA-44? */ 32064edcceaSEric Biggers /* 18-bit coefficients: decode 4 from 9 bytes. */ 32164edcceaSEric Biggers for (int j = 0; j < N; j += 4, sig += 9) { 32264edcceaSEric Biggers u64 v = get_unaligned_le64(sig); 32364edcceaSEric Biggers 32464edcceaSEric Biggers z[i].x[j + 0] = (v >> 0) & 0x3ffff; 32564edcceaSEric Biggers z[i].x[j + 1] = (v >> 18) & 0x3ffff; 32664edcceaSEric Biggers z[i].x[j + 2] = (v >> 36) & 0x3ffff; 32764edcceaSEric Biggers z[i].x[j + 3] = (v >> 54) | (sig[8] << 10); 32864edcceaSEric Biggers } 32964edcceaSEric Biggers } else { 33064edcceaSEric Biggers /* 20-bit coefficients: decode 4 from 10 bytes. */ 33164edcceaSEric Biggers for (int j = 0; j < N; j += 4, sig += 10) { 33264edcceaSEric Biggers u64 v = get_unaligned_le64(sig); 33364edcceaSEric Biggers 33464edcceaSEric Biggers z[i].x[j + 0] = (v >> 0) & 0xfffff; 33564edcceaSEric Biggers z[i].x[j + 1] = (v >> 20) & 0xfffff; 33664edcceaSEric Biggers z[i].x[j + 2] = (v >> 40) & 0xfffff; 33764edcceaSEric Biggers z[i].x[j + 3] = 33864edcceaSEric Biggers (v >> 60) | 33964edcceaSEric Biggers (get_unaligned_le16(&sig[8]) << 4); 34064edcceaSEric Biggers } 34164edcceaSEric Biggers } 34264edcceaSEric Biggers for (int j = 0; j < N; j++) { 34364edcceaSEric Biggers z[i].x[j] = gamma1 - z[i].x[j]; 34464edcceaSEric Biggers if (z[i].x[j] <= -(gamma1 - beta) || 34564edcceaSEric Biggers z[i].x[j] >= gamma1 - beta) 34664edcceaSEric Biggers return false; 34764edcceaSEric Biggers } 34864edcceaSEric Biggers ntt(&z[i]); 34964edcceaSEric Biggers } 35064edcceaSEric Biggers *sig_ptr = sig; /* Return updated pointer. */ 35164edcceaSEric Biggers return true; 35264edcceaSEric Biggers } 35364edcceaSEric Biggers 35464edcceaSEric Biggers /* 35564edcceaSEric Biggers * Decode the signer's hint vector 'h' from the signature. 35664edcceaSEric Biggers * Reference: FIPS 204 Algorithm 21, HintBitUnpack 35764edcceaSEric Biggers * 35864edcceaSEric Biggers * Note that there are several ways in which the hint vector can be malformed. 35964edcceaSEric Biggers */ 36064edcceaSEric Biggers static bool decode_hint_vector(u8 h[/* k * N */], int k, int omega, const u8 *y) 36164edcceaSEric Biggers { 36264edcceaSEric Biggers int index = 0; 36364edcceaSEric Biggers 36464edcceaSEric Biggers memset(h, 0, k * N); 36564edcceaSEric Biggers for (int i = 0; i < k; i++) { 36664edcceaSEric Biggers int count = y[omega + i]; /* num 1's in elems 0 through i */ 36764edcceaSEric Biggers int prev = -1; 36864edcceaSEric Biggers 36964edcceaSEric Biggers /* Cumulative count mustn't decrease or exceed omega. */ 37064edcceaSEric Biggers if (count < index || count > omega) 37164edcceaSEric Biggers return false; 37264edcceaSEric Biggers for (; index < count; index++) { 37364edcceaSEric Biggers if (prev >= y[index]) /* Coefficients out of order? */ 37464edcceaSEric Biggers return false; 37564edcceaSEric Biggers prev = y[index]; 37664edcceaSEric Biggers h[i * N + y[index]] = 1; 37764edcceaSEric Biggers } 37864edcceaSEric Biggers } 37964edcceaSEric Biggers return mem_is_zero(&y[index], omega - index); 38064edcceaSEric Biggers } 38164edcceaSEric Biggers 38264edcceaSEric Biggers /* 38364edcceaSEric Biggers * Expand @seed into an element of R_q @c with coefficients in {-1, 0, 1}, 38464edcceaSEric Biggers * exactly @tau of them nonzero. Reference: FIPS 204 Algorithm 29, SampleInBall 38564edcceaSEric Biggers */ 38664edcceaSEric Biggers static void sample_in_ball(struct mldsa_ring_elem *c, const u8 *seed, 38764edcceaSEric Biggers size_t seed_len, int tau, struct shake_ctx *shake) 38864edcceaSEric Biggers { 38964edcceaSEric Biggers u64 signs; 39064edcceaSEric Biggers u8 j; 39164edcceaSEric Biggers 39264edcceaSEric Biggers shake256_init(shake); 39364edcceaSEric Biggers shake_update(shake, seed, seed_len); 39464edcceaSEric Biggers shake_squeeze(shake, (u8 *)&signs, sizeof(signs)); 39564edcceaSEric Biggers le64_to_cpus(&signs); 39664edcceaSEric Biggers *c = (struct mldsa_ring_elem){}; 39764edcceaSEric Biggers for (int i = N - tau; i < N; i++, signs >>= 1) { 39864edcceaSEric Biggers do { 39964edcceaSEric Biggers shake_squeeze(shake, &j, 1); 40064edcceaSEric Biggers } while (j > i); 40164edcceaSEric Biggers c->x[i] = c->x[j]; 40264edcceaSEric Biggers c->x[j] = 1 - 2 * (s32)(signs & 1); 40364edcceaSEric Biggers } 40464edcceaSEric Biggers } 40564edcceaSEric Biggers 40664edcceaSEric Biggers /* 40764edcceaSEric Biggers * Expand the public seed @rho and @row_and_column into an element of T_q @out. 40864edcceaSEric Biggers * Reference: FIPS 204 Algorithm 30, RejNTTPoly 40964edcceaSEric Biggers * 41064edcceaSEric Biggers * @shake and @block are temporary space used by the expansion. @block has 41164edcceaSEric Biggers * space for one SHAKE128 block, plus an extra byte to allow reading a u32 from 41264edcceaSEric Biggers * the final 3-byte group without reading out-of-bounds. 41364edcceaSEric Biggers */ 41464edcceaSEric Biggers static void rej_ntt_poly(struct mldsa_ring_elem *out, const u8 rho[RHO_LEN], 41564edcceaSEric Biggers __le16 row_and_column, struct shake_ctx *shake, 41664edcceaSEric Biggers u8 block[SHAKE128_BLOCK_SIZE + 1]) 41764edcceaSEric Biggers { 41864edcceaSEric Biggers shake128_init(shake); 41964edcceaSEric Biggers shake_update(shake, rho, RHO_LEN); 42064edcceaSEric Biggers shake_update(shake, (u8 *)&row_and_column, sizeof(row_and_column)); 42164edcceaSEric Biggers for (int i = 0; i < N;) { 42264edcceaSEric Biggers shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE); 42364edcceaSEric Biggers block[SHAKE128_BLOCK_SIZE] = 0; /* for KMSAN */ 42464edcceaSEric Biggers static_assert(SHAKE128_BLOCK_SIZE % 3 == 0); 42564edcceaSEric Biggers for (int j = 0; j < SHAKE128_BLOCK_SIZE && i < N; j += 3) { 42664edcceaSEric Biggers u32 x = get_unaligned_le32(&block[j]) & 0x7fffff; 42764edcceaSEric Biggers 42864edcceaSEric Biggers if (x < Q) /* Ignore values >= q. */ 42964edcceaSEric Biggers out->x[i++] = x; 43064edcceaSEric Biggers } 43164edcceaSEric Biggers } 43264edcceaSEric Biggers } 43364edcceaSEric Biggers 43464edcceaSEric Biggers /* 43564edcceaSEric Biggers * Return the HighBits of r adjusted according to hint h 43664edcceaSEric Biggers * Reference: FIPS 204 Algorithm 40, UseHint 43764edcceaSEric Biggers * 43864edcceaSEric Biggers * This is needed because of the public key compression in ML-DSA. 43964edcceaSEric Biggers * 44064edcceaSEric Biggers * h is either 0 or 1, r is in [0, q - 1], and gamma2 is either (q - 1) / 88 or 44164edcceaSEric Biggers * (q - 1) / 32. Except when invoked via the unit test interface, gamma2 is a 44264edcceaSEric Biggers * compile-time constant, so compilers will optimize the code accordingly. 44364edcceaSEric Biggers */ 44464edcceaSEric Biggers static __always_inline s32 use_hint(u8 h, s32 r, const s32 gamma2) 44564edcceaSEric Biggers { 44664edcceaSEric Biggers const s32 m = (Q - 1) / (2 * gamma2); /* 44 or 16, compile-time const */ 44764edcceaSEric Biggers s32 r1; 44864edcceaSEric Biggers 44964edcceaSEric Biggers /* 45064edcceaSEric Biggers * Handle the special case where r - (r mod+- (2 * gamma2)) == q - 1, 45164edcceaSEric Biggers * i.e. r >= q - gamma2. This is also exactly where the computation of 45264edcceaSEric Biggers * r1 below would produce 'm' and would need a correction. 45364edcceaSEric Biggers */ 45464edcceaSEric Biggers if (r >= Q - gamma2) 45564edcceaSEric Biggers return h == 0 ? 0 : m - 1; 45664edcceaSEric Biggers 45764edcceaSEric Biggers /* 45864edcceaSEric Biggers * Compute the (non-hint-adjusted) HighBits r1 as: 45964edcceaSEric Biggers * 46064edcceaSEric Biggers * r1 = (r - (r mod+- (2 * gamma2))) / (2 * gamma2) 46164edcceaSEric Biggers * = floor((r + gamma2 - 1) / (2 * gamma2)) 46264edcceaSEric Biggers * 46364edcceaSEric Biggers * Note that when '2 * gamma2' is a compile-time constant, compilers 46464edcceaSEric Biggers * optimize the division to a reciprocal multiplication and shift. 46564edcceaSEric Biggers */ 46664edcceaSEric Biggers r1 = (u32)(r + gamma2 - 1) / (2 * gamma2); 46764edcceaSEric Biggers 46864edcceaSEric Biggers /* 46964edcceaSEric Biggers * Return the HighBits r1: 47064edcceaSEric Biggers * + 0 if the hint is 0; 47164edcceaSEric Biggers * + 1 (mod m) if the hint is 1 and the LowBits are positive; 47264edcceaSEric Biggers * - 1 (mod m) if the hint is 1 and the LowBits are negative or 0. 47364edcceaSEric Biggers * 47464edcceaSEric Biggers * r1 is in (and remains in) [0, m - 1]. Note that when 'm' is a 47564edcceaSEric Biggers * compile-time constant, compilers optimize the '% m' accordingly. 47664edcceaSEric Biggers */ 47764edcceaSEric Biggers if (h == 0) 47864edcceaSEric Biggers return r1; 47964edcceaSEric Biggers if (r > r1 * (2 * gamma2)) 48064edcceaSEric Biggers return (u32)(r1 + 1) % m; 48164edcceaSEric Biggers return (u32)(r1 + m - 1) % m; 48264edcceaSEric Biggers } 48364edcceaSEric Biggers 48464edcceaSEric Biggers static __always_inline void use_hint_elem(struct mldsa_ring_elem *w, 48564edcceaSEric Biggers const u8 h[N], const s32 gamma2) 48664edcceaSEric Biggers { 48764edcceaSEric Biggers for (int j = 0; j < N; j++) 48864edcceaSEric Biggers w->x[j] = use_hint(h[j], w->x[j], gamma2); 48964edcceaSEric Biggers } 49064edcceaSEric Biggers 49164edcceaSEric Biggers #if IS_ENABLED(CONFIG_CRYPTO_LIB_MLDSA_KUNIT_TEST) 49264edcceaSEric Biggers /* Allow the __always_inline function use_hint() to be unit-tested. */ 49364edcceaSEric Biggers s32 mldsa_use_hint(u8 h, s32 r, s32 gamma2) 49464edcceaSEric Biggers { 49564edcceaSEric Biggers return use_hint(h, r, gamma2); 49664edcceaSEric Biggers } 49764edcceaSEric Biggers EXPORT_SYMBOL_IF_KUNIT(mldsa_use_hint); 49864edcceaSEric Biggers #endif 49964edcceaSEric Biggers 50064edcceaSEric Biggers /* 50164edcceaSEric Biggers * Encode one element of the commitment vector w'_1 into a byte string. 50264edcceaSEric Biggers * Reference: FIPS 204 Algorithm 28, w1Encode. 50364edcceaSEric Biggers * Return the number of bytes used: 192 for ML-DSA-44 and 128 for the others. 50464edcceaSEric Biggers */ 50564edcceaSEric Biggers static size_t encode_w1(u8 out[MAX_W1_ENCODED_LEN], 50664edcceaSEric Biggers const struct mldsa_ring_elem *w1, int k) 50764edcceaSEric Biggers { 50864edcceaSEric Biggers size_t pos = 0; 50964edcceaSEric Biggers 51064edcceaSEric Biggers static_assert(N * 6 / 8 == MAX_W1_ENCODED_LEN); 51164edcceaSEric Biggers if (k == 4) { /* ML-DSA-44? */ 51264edcceaSEric Biggers /* 6 bits per coefficient. Pack 4 at a time. */ 51364edcceaSEric Biggers for (int j = 0; j < N; j += 4) { 51464edcceaSEric Biggers u32 v = (w1->x[j + 0] << 0) | (w1->x[j + 1] << 6) | 51564edcceaSEric Biggers (w1->x[j + 2] << 12) | (w1->x[j + 3] << 18); 51664edcceaSEric Biggers out[pos++] = v >> 0; 51764edcceaSEric Biggers out[pos++] = v >> 8; 51864edcceaSEric Biggers out[pos++] = v >> 16; 51964edcceaSEric Biggers } 52064edcceaSEric Biggers } else { 52164edcceaSEric Biggers /* 4 bits per coefficient. Pack 2 at a time. */ 52264edcceaSEric Biggers for (int j = 0; j < N; j += 2) 52364edcceaSEric Biggers out[pos++] = w1->x[j] | (w1->x[j + 1] << 4); 52464edcceaSEric Biggers } 52564edcceaSEric Biggers return pos; 52664edcceaSEric Biggers } 52764edcceaSEric Biggers 52864edcceaSEric Biggers int mldsa_verify(enum mldsa_alg alg, const u8 *sig, size_t sig_len, 52964edcceaSEric Biggers const u8 *msg, size_t msg_len, const u8 *pk, size_t pk_len) 53064edcceaSEric Biggers { 53164edcceaSEric Biggers const struct mldsa_parameter_set *params = &mldsa_parameter_sets[alg]; 53264edcceaSEric Biggers const int k = params->k, l = params->l; 53364edcceaSEric Biggers /* For now this just does pure ML-DSA with an empty context string. */ 53464edcceaSEric Biggers static const u8 msg_prefix[2] = { /* dom_sep= */ 0, /* ctx_len= */ 0 }; 53564edcceaSEric Biggers const u8 *ctilde; /* The signer's commitment hash */ 53664edcceaSEric Biggers const u8 *t1_encoded = &pk[RHO_LEN]; /* Next encoded element of t_1 */ 53764edcceaSEric Biggers u8 *h; /* The signer's hint vector, length k * N */ 53864edcceaSEric Biggers size_t w1_enc_len; 53964edcceaSEric Biggers 54064edcceaSEric Biggers /* Validate the public key and signature lengths. */ 54164edcceaSEric Biggers if (pk_len != params->pk_len || sig_len != params->sig_len) 54264edcceaSEric Biggers return -EBADMSG; 54364edcceaSEric Biggers 54464edcceaSEric Biggers /* 54564edcceaSEric Biggers * Allocate the workspace, including variable-length fields. Its size 54664edcceaSEric Biggers * depends only on the ML-DSA parameter set, not the other inputs. 54764edcceaSEric Biggers * 54864edcceaSEric Biggers * For freeing it, use kfree_sensitive() rather than kfree(). This is 54964edcceaSEric Biggers * mainly to comply with FIPS 204 Section 3.6.3 "Intermediate Values". 55064edcceaSEric Biggers * In reality it's a bit gratuitous, as this is a public key operation. 55164edcceaSEric Biggers */ 55264edcceaSEric Biggers struct mldsa_verification_workspace *ws __free(kfree_sensitive) = 55364edcceaSEric Biggers kmalloc(sizeof(*ws) + (l * sizeof(ws->z[0])) + (k * N), 55464edcceaSEric Biggers GFP_KERNEL); 55564edcceaSEric Biggers if (!ws) 55664edcceaSEric Biggers return -ENOMEM; 55764edcceaSEric Biggers h = (u8 *)&ws->z[l]; 55864edcceaSEric Biggers 55964edcceaSEric Biggers /* Decode the signature. Reference: FIPS 204 Algorithm 27, sigDecode */ 56064edcceaSEric Biggers ctilde = sig; 56164edcceaSEric Biggers sig += params->ctilde_len; 56264edcceaSEric Biggers if (!decode_z(ws->z, l, params->gamma1, params->beta, &sig)) 56364edcceaSEric Biggers return -EBADMSG; 56464edcceaSEric Biggers if (!decode_hint_vector(h, k, params->omega, sig)) 56564edcceaSEric Biggers return -EBADMSG; 56664edcceaSEric Biggers 56764edcceaSEric Biggers /* Recreate the challenge c from the signer's commitment hash. */ 56864edcceaSEric Biggers sample_in_ball(&ws->c, ctilde, params->ctilde_len, params->tau, 56964edcceaSEric Biggers &ws->shake); 57064edcceaSEric Biggers ntt(&ws->c); 57164edcceaSEric Biggers 57264edcceaSEric Biggers /* Compute the message representative mu. */ 57364edcceaSEric Biggers shake256(pk, pk_len, ws->tr, sizeof(ws->tr)); 57464edcceaSEric Biggers shake256_init(&ws->shake); 57564edcceaSEric Biggers shake_update(&ws->shake, ws->tr, sizeof(ws->tr)); 57664edcceaSEric Biggers shake_update(&ws->shake, msg_prefix, sizeof(msg_prefix)); 57764edcceaSEric Biggers shake_update(&ws->shake, msg, msg_len); 57864edcceaSEric Biggers shake_squeeze(&ws->shake, ws->mu, sizeof(ws->mu)); 57964edcceaSEric Biggers 58064edcceaSEric Biggers /* Start computing ctildeprime = H(mu || w1Encode(w'_1)). */ 58164edcceaSEric Biggers shake256_init(&ws->shake); 58264edcceaSEric Biggers shake_update(&ws->shake, ws->mu, sizeof(ws->mu)); 58364edcceaSEric Biggers 58464edcceaSEric Biggers /* 58564edcceaSEric Biggers * Compute the commitment w'_1 from A, z, c, t_1, and h. 58664edcceaSEric Biggers * 58764edcceaSEric Biggers * The computation is the same for each of the k rows. Just do each row 58864edcceaSEric Biggers * before moving on to the next, resulting in only one loop over k. 58964edcceaSEric Biggers */ 59064edcceaSEric Biggers for (int i = 0; i < k; i++) { 59164edcceaSEric Biggers /* 59264edcceaSEric Biggers * tmp = NTT(A) * NTT(z) * 2^-32 59364edcceaSEric Biggers * To reduce memory use, generate each element of NTT(A) 59464edcceaSEric Biggers * on-demand. Note that each element is used only once. 59564edcceaSEric Biggers */ 59664edcceaSEric Biggers ws->tmp = (struct mldsa_ring_elem){}; 59764edcceaSEric Biggers for (int j = 0; j < l; j++) { 59864edcceaSEric Biggers rej_ntt_poly(&ws->a, pk /* rho is first field of pk */, 59964edcceaSEric Biggers cpu_to_le16((i << 8) | j), &ws->a_shake, 60064edcceaSEric Biggers ws->block); 60164edcceaSEric Biggers for (int n = 0; n < N; n++) 60264edcceaSEric Biggers ws->tmp.x[n] += 60364edcceaSEric Biggers Zq_mult(ws->a.x[n], ws->z[j].x[n]); 60464edcceaSEric Biggers } 60564edcceaSEric Biggers /* All components of tmp now have abs value < l*q. */ 60664edcceaSEric Biggers 60764edcceaSEric Biggers /* Decode the next element of t_1. */ 60864edcceaSEric Biggers t1_encoded = decode_t1_elem(&ws->t1_scaled, t1_encoded); 60964edcceaSEric Biggers 61064edcceaSEric Biggers /* 61164edcceaSEric Biggers * tmp -= NTT(c) * NTT(t_1 * 2^d) * 2^-32 61264edcceaSEric Biggers * 61364edcceaSEric Biggers * Taking a conservative bound for the output of ntt(), the 61464edcceaSEric Biggers * multiplicands can have absolute value up to 9*q. That 61564edcceaSEric Biggers * corresponds to a product with absolute value 81*q^2. That is 61664edcceaSEric Biggers * within the limits of Zq_mult() which needs < ~256*q^2. 61764edcceaSEric Biggers */ 61864edcceaSEric Biggers for (int j = 0; j < N; j++) 61964edcceaSEric Biggers ws->tmp.x[j] -= Zq_mult(ws->c.x[j], ws->t1_scaled.x[j]); 62064edcceaSEric Biggers /* All components of tmp now have abs value < (l+1)*q. */ 62164edcceaSEric Biggers 62264edcceaSEric Biggers /* tmp = w'_Approx = NTT^-1(tmp) * 2^32 */ 62364edcceaSEric Biggers invntt_and_mul_2_32(&ws->tmp); 62464edcceaSEric Biggers /* All coefficients of tmp are now in [0, q - 1]. */ 62564edcceaSEric Biggers 62664edcceaSEric Biggers /* 62764edcceaSEric Biggers * tmp = w'_1 = UseHint(h, w'_Approx) 62864edcceaSEric Biggers * For efficiency, set gamma2 to a compile-time constant. 62964edcceaSEric Biggers */ 63064edcceaSEric Biggers if (k == 4) 63164edcceaSEric Biggers use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 88); 63264edcceaSEric Biggers else 63364edcceaSEric Biggers use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 32); 63464edcceaSEric Biggers 63564edcceaSEric Biggers /* Encode and hash the next element of w'_1. */ 63664edcceaSEric Biggers w1_enc_len = encode_w1(ws->w1_encoded, &ws->tmp, k); 63764edcceaSEric Biggers shake_update(&ws->shake, ws->w1_encoded, w1_enc_len); 63864edcceaSEric Biggers } 63964edcceaSEric Biggers 64064edcceaSEric Biggers /* Finish computing ctildeprime. */ 64164edcceaSEric Biggers shake_squeeze(&ws->shake, ws->ctildeprime, params->ctilde_len); 64264edcceaSEric Biggers 64364edcceaSEric Biggers /* Verify that ctilde == ctildeprime. */ 64464edcceaSEric Biggers if (memcmp(ws->ctildeprime, ctilde, params->ctilde_len) != 0) 64564edcceaSEric Biggers return -EKEYREJECTED; 64664edcceaSEric Biggers /* ||z||_infinity < gamma1 - beta was already checked in decode_z(). */ 64764edcceaSEric Biggers return 0; 64864edcceaSEric Biggers } 64964edcceaSEric Biggers EXPORT_SYMBOL_GPL(mldsa_verify); 65064edcceaSEric Biggers 651*959a634eSEric Biggers #ifdef CONFIG_CRYPTO_FIPS 652*959a634eSEric Biggers static int __init mldsa_mod_init(void) 653*959a634eSEric Biggers { 654*959a634eSEric Biggers if (fips_enabled) { 655*959a634eSEric Biggers /* 656*959a634eSEric Biggers * FIPS cryptographic algorithm self-test. As per the FIPS 657*959a634eSEric Biggers * Implementation Guidance, testing any ML-DSA parameter set 658*959a634eSEric Biggers * satisfies the test requirement for all of them, and only a 659*959a634eSEric Biggers * positive test is required. 660*959a634eSEric Biggers */ 661*959a634eSEric Biggers int err = mldsa_verify(MLDSA65, fips_test_mldsa65_signature, 662*959a634eSEric Biggers sizeof(fips_test_mldsa65_signature), 663*959a634eSEric Biggers fips_test_mldsa65_message, 664*959a634eSEric Biggers sizeof(fips_test_mldsa65_message), 665*959a634eSEric Biggers fips_test_mldsa65_public_key, 666*959a634eSEric Biggers sizeof(fips_test_mldsa65_public_key)); 667*959a634eSEric Biggers if (err) 668*959a634eSEric Biggers panic("mldsa: FIPS self-test failed; err=%pe\n", 669*959a634eSEric Biggers ERR_PTR(err)); 670*959a634eSEric Biggers } 671*959a634eSEric Biggers return 0; 672*959a634eSEric Biggers } 673*959a634eSEric Biggers subsys_initcall(mldsa_mod_init); 674*959a634eSEric Biggers 675*959a634eSEric Biggers static void __exit mldsa_mod_exit(void) 676*959a634eSEric Biggers { 677*959a634eSEric Biggers } 678*959a634eSEric Biggers module_exit(mldsa_mod_exit); 679*959a634eSEric Biggers #endif /* CONFIG_CRYPTO_FIPS */ 680*959a634eSEric Biggers 68164edcceaSEric Biggers MODULE_DESCRIPTION("ML-DSA signature verification"); 68264edcceaSEric Biggers MODULE_LICENSE("GPL"); 683