1*ed894facSEric Biggers // SPDX-License-Identifier: GPL-2.0-or-later 2*ed894facSEric Biggers /* 3*ed894facSEric Biggers * KUnit tests and benchmark for ML-DSA 4*ed894facSEric Biggers * 5*ed894facSEric Biggers * Copyright 2025 Google LLC 6*ed894facSEric Biggers */ 7*ed894facSEric Biggers #include <crypto/mldsa.h> 8*ed894facSEric Biggers #include <kunit/test.h> 9*ed894facSEric Biggers #include <linux/random.h> 10*ed894facSEric Biggers #include <linux/unaligned.h> 11*ed894facSEric Biggers 12*ed894facSEric Biggers #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */ 13*ed894facSEric Biggers 14*ed894facSEric Biggers /* ML-DSA parameters that the tests use */ 15*ed894facSEric Biggers static const struct { 16*ed894facSEric Biggers int sig_len; 17*ed894facSEric Biggers int pk_len; 18*ed894facSEric Biggers int k; 19*ed894facSEric Biggers int lambda; 20*ed894facSEric Biggers int gamma1; 21*ed894facSEric Biggers int beta; 22*ed894facSEric Biggers int omega; 23*ed894facSEric Biggers } params[] = { 24*ed894facSEric Biggers [MLDSA44] = { 25*ed894facSEric Biggers .sig_len = MLDSA44_SIGNATURE_SIZE, 26*ed894facSEric Biggers .pk_len = MLDSA44_PUBLIC_KEY_SIZE, 27*ed894facSEric Biggers .k = 4, 28*ed894facSEric Biggers .lambda = 128, 29*ed894facSEric Biggers .gamma1 = 1 << 17, 30*ed894facSEric Biggers .beta = 78, 31*ed894facSEric Biggers .omega = 80, 32*ed894facSEric Biggers }, 33*ed894facSEric Biggers [MLDSA65] = { 34*ed894facSEric Biggers .sig_len = MLDSA65_SIGNATURE_SIZE, 35*ed894facSEric Biggers .pk_len = MLDSA65_PUBLIC_KEY_SIZE, 36*ed894facSEric Biggers .k = 6, 37*ed894facSEric Biggers .lambda = 192, 38*ed894facSEric Biggers .gamma1 = 1 << 19, 39*ed894facSEric Biggers .beta = 196, 40*ed894facSEric Biggers .omega = 55, 41*ed894facSEric Biggers }, 42*ed894facSEric Biggers [MLDSA87] = { 43*ed894facSEric Biggers .sig_len = MLDSA87_SIGNATURE_SIZE, 44*ed894facSEric Biggers .pk_len = MLDSA87_PUBLIC_KEY_SIZE, 45*ed894facSEric Biggers .k = 8, 46*ed894facSEric Biggers .lambda = 256, 47*ed894facSEric Biggers .gamma1 = 1 << 19, 48*ed894facSEric Biggers .beta = 120, 49*ed894facSEric Biggers .omega = 75, 50*ed894facSEric Biggers }, 51*ed894facSEric Biggers }; 52*ed894facSEric Biggers 53*ed894facSEric Biggers #include "mldsa-testvecs.h" 54*ed894facSEric Biggers 55*ed894facSEric Biggers static void do_mldsa_and_assert_success(struct kunit *test, 56*ed894facSEric Biggers const struct mldsa_testvector *tv) 57*ed894facSEric Biggers { 58*ed894facSEric Biggers int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 59*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len); 60*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, err, 0); 61*ed894facSEric Biggers } 62*ed894facSEric Biggers 63*ed894facSEric Biggers static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len) 64*ed894facSEric Biggers { 65*ed894facSEric Biggers u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL); 66*ed894facSEric Biggers 67*ed894facSEric Biggers KUNIT_ASSERT_NOT_NULL(test, dst); 68*ed894facSEric Biggers return memcpy(dst, src, len); 69*ed894facSEric Biggers } 70*ed894facSEric Biggers 71*ed894facSEric Biggers /* 72*ed894facSEric Biggers * Test that changing coefficients in a valid signature's z vector results in 73*ed894facSEric Biggers * the following behavior from mldsa_verify(): 74*ed894facSEric Biggers * 75*ed894facSEric Biggers * * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e. 76*ed894facSEric Biggers * absolute value >= gamma1 - beta, corresponding to the verifier detecting 77*ed894facSEric Biggers * the out-of-range coefficient and rejecting the signature as malformed 78*ed894facSEric Biggers * 79*ed894facSEric Biggers * * -EKEYREJECTED if a coefficient is changed to a different in-range value, 80*ed894facSEric Biggers * i.e. absolute value < gamma1 - beta, corresponding to the verifier 81*ed894facSEric Biggers * continuing to the "real" signature check and that check failing 82*ed894facSEric Biggers */ 83*ed894facSEric Biggers static void test_mldsa_z_range(struct kunit *test, 84*ed894facSEric Biggers const struct mldsa_testvector *tv) 85*ed894facSEric Biggers { 86*ed894facSEric Biggers u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len); 87*ed894facSEric Biggers const int lambda = params[tv->alg].lambda; 88*ed894facSEric Biggers const s32 gamma1 = params[tv->alg].gamma1; 89*ed894facSEric Biggers const int beta = params[tv->alg].beta; 90*ed894facSEric Biggers /* 91*ed894facSEric Biggers * We just modify the first coefficient. The coefficient is gamma1 92*ed894facSEric Biggers * minus either the first 18 or 20 bits of the u32, depending on gamma1. 93*ed894facSEric Biggers * 94*ed894facSEric Biggers * The layout of ML-DSA signatures is ctilde || z || h. ctilde is 95*ed894facSEric Biggers * lambda / 4 bytes, so z starts at &sig[lambda / 4]. 96*ed894facSEric Biggers */ 97*ed894facSEric Biggers u8 *z_ptr = &sig[lambda / 4]; 98*ed894facSEric Biggers const u32 z_data = get_unaligned_le32(z_ptr); 99*ed894facSEric Biggers const u32 mask = (gamma1 << 1) - 1; 100*ed894facSEric Biggers /* These are the four boundaries of the out-of-range values. */ 101*ed894facSEric Biggers const s32 out_of_range_coeffs[] = { 102*ed894facSEric Biggers -gamma1 + 1, 103*ed894facSEric Biggers -(gamma1 - beta), 104*ed894facSEric Biggers gamma1, 105*ed894facSEric Biggers gamma1 - beta, 106*ed894facSEric Biggers }; 107*ed894facSEric Biggers /* 108*ed894facSEric Biggers * These are the two boundaries of the valid range, along with 0. We 109*ed894facSEric Biggers * assume that none of these matches the original coefficient. 110*ed894facSEric Biggers */ 111*ed894facSEric Biggers const s32 in_range_coeffs[] = { 112*ed894facSEric Biggers -(gamma1 - beta - 1), 113*ed894facSEric Biggers 0, 114*ed894facSEric Biggers gamma1 - beta - 1, 115*ed894facSEric Biggers }; 116*ed894facSEric Biggers 117*ed894facSEric Biggers /* Initially the signature is valid. */ 118*ed894facSEric Biggers do_mldsa_and_assert_success(test, tv); 119*ed894facSEric Biggers 120*ed894facSEric Biggers /* Test some out-of-range coefficients. */ 121*ed894facSEric Biggers for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) { 122*ed894facSEric Biggers const s32 c = out_of_range_coeffs[i]; 123*ed894facSEric Biggers 124*ed894facSEric Biggers put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)), 125*ed894facSEric Biggers z_ptr); 126*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 127*ed894facSEric Biggers mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 128*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 129*ed894facSEric Biggers } 130*ed894facSEric Biggers 131*ed894facSEric Biggers /* Test some in-range coefficients. */ 132*ed894facSEric Biggers for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) { 133*ed894facSEric Biggers const s32 c = in_range_coeffs[i]; 134*ed894facSEric Biggers 135*ed894facSEric Biggers put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)), 136*ed894facSEric Biggers z_ptr); 137*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EKEYREJECTED, 138*ed894facSEric Biggers mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 139*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 140*ed894facSEric Biggers } 141*ed894facSEric Biggers } 142*ed894facSEric Biggers 143*ed894facSEric Biggers /* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */ 144*ed894facSEric Biggers static void test_mldsa_bad_hints(struct kunit *test, 145*ed894facSEric Biggers const struct mldsa_testvector *tv) 146*ed894facSEric Biggers { 147*ed894facSEric Biggers const int omega = params[tv->alg].omega; 148*ed894facSEric Biggers const int k = params[tv->alg].k; 149*ed894facSEric Biggers u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len); 150*ed894facSEric Biggers /* Pointer to the encoded hint vector in the signature */ 151*ed894facSEric Biggers u8 *hintvec = &sig[tv->sig_len - omega - k]; 152*ed894facSEric Biggers u8 h; 153*ed894facSEric Biggers 154*ed894facSEric Biggers /* Initially the signature is valid. */ 155*ed894facSEric Biggers do_mldsa_and_assert_success(test, tv); 156*ed894facSEric Biggers 157*ed894facSEric Biggers /* Cumulative hint count exceeds omega */ 158*ed894facSEric Biggers memcpy(sig, tv->sig, tv->sig_len); 159*ed894facSEric Biggers hintvec[omega + k - 1] = omega + 1; 160*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 161*ed894facSEric Biggers mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 162*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 163*ed894facSEric Biggers 164*ed894facSEric Biggers /* Cumulative hint count decreases */ 165*ed894facSEric Biggers memcpy(sig, tv->sig, tv->sig_len); 166*ed894facSEric Biggers KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1); 167*ed894facSEric Biggers hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1; 168*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 169*ed894facSEric Biggers mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 170*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 171*ed894facSEric Biggers 172*ed894facSEric Biggers /* 173*ed894facSEric Biggers * Hint indices out of order. To test this, swap hintvec[0] and 174*ed894facSEric Biggers * hintvec[1]. This assumes that the original valid signature had at 175*ed894facSEric Biggers * least two nonzero hints in the first element (asserted below). 176*ed894facSEric Biggers */ 177*ed894facSEric Biggers memcpy(sig, tv->sig, tv->sig_len); 178*ed894facSEric Biggers KUNIT_ASSERT_GE(test, hintvec[omega], 2); 179*ed894facSEric Biggers h = hintvec[0]; 180*ed894facSEric Biggers hintvec[0] = hintvec[1]; 181*ed894facSEric Biggers hintvec[1] = h; 182*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 183*ed894facSEric Biggers mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 184*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 185*ed894facSEric Biggers 186*ed894facSEric Biggers /* 187*ed894facSEric Biggers * Extra hint indices given. For this test to work, the original valid 188*ed894facSEric Biggers * signature must have fewer than omega nonzero hints (asserted below). 189*ed894facSEric Biggers */ 190*ed894facSEric Biggers memcpy(sig, tv->sig, tv->sig_len); 191*ed894facSEric Biggers KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega); 192*ed894facSEric Biggers hintvec[omega - 1] = 0xff; 193*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 194*ed894facSEric Biggers mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 195*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 196*ed894facSEric Biggers } 197*ed894facSEric Biggers 198*ed894facSEric Biggers static void test_mldsa_mutation(struct kunit *test, 199*ed894facSEric Biggers const struct mldsa_testvector *tv) 200*ed894facSEric Biggers { 201*ed894facSEric Biggers const int sig_len = tv->sig_len; 202*ed894facSEric Biggers const int msg_len = tv->msg_len; 203*ed894facSEric Biggers const int pk_len = tv->pk_len; 204*ed894facSEric Biggers const int num_iter = 200; 205*ed894facSEric Biggers u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len); 206*ed894facSEric Biggers u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len); 207*ed894facSEric Biggers u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len); 208*ed894facSEric Biggers 209*ed894facSEric Biggers /* Initially the signature is valid. */ 210*ed894facSEric Biggers do_mldsa_and_assert_success(test, tv); 211*ed894facSEric Biggers 212*ed894facSEric Biggers /* Changing any bit in the signature should invalidate the signature */ 213*ed894facSEric Biggers for (int i = 0; i < num_iter; i++) { 214*ed894facSEric Biggers size_t pos = get_random_u32_below(sig_len); 215*ed894facSEric Biggers u8 b = 1 << get_random_u32_below(8); 216*ed894facSEric Biggers 217*ed894facSEric Biggers sig[pos] ^= b; 218*ed894facSEric Biggers KUNIT_ASSERT_NE(test, 0, 219*ed894facSEric Biggers mldsa_verify(tv->alg, sig, sig_len, msg, 220*ed894facSEric Biggers msg_len, pk, pk_len)); 221*ed894facSEric Biggers sig[pos] ^= b; 222*ed894facSEric Biggers } 223*ed894facSEric Biggers 224*ed894facSEric Biggers /* Changing any bit in the message should invalidate the signature */ 225*ed894facSEric Biggers for (int i = 0; i < num_iter; i++) { 226*ed894facSEric Biggers size_t pos = get_random_u32_below(msg_len); 227*ed894facSEric Biggers u8 b = 1 << get_random_u32_below(8); 228*ed894facSEric Biggers 229*ed894facSEric Biggers msg[pos] ^= b; 230*ed894facSEric Biggers KUNIT_ASSERT_NE(test, 0, 231*ed894facSEric Biggers mldsa_verify(tv->alg, sig, sig_len, msg, 232*ed894facSEric Biggers msg_len, pk, pk_len)); 233*ed894facSEric Biggers msg[pos] ^= b; 234*ed894facSEric Biggers } 235*ed894facSEric Biggers 236*ed894facSEric Biggers /* Changing any bit in the public key should invalidate the signature */ 237*ed894facSEric Biggers for (int i = 0; i < num_iter; i++) { 238*ed894facSEric Biggers size_t pos = get_random_u32_below(pk_len); 239*ed894facSEric Biggers u8 b = 1 << get_random_u32_below(8); 240*ed894facSEric Biggers 241*ed894facSEric Biggers pk[pos] ^= b; 242*ed894facSEric Biggers KUNIT_ASSERT_NE(test, 0, 243*ed894facSEric Biggers mldsa_verify(tv->alg, sig, sig_len, msg, 244*ed894facSEric Biggers msg_len, pk, pk_len)); 245*ed894facSEric Biggers pk[pos] ^= b; 246*ed894facSEric Biggers } 247*ed894facSEric Biggers 248*ed894facSEric Biggers /* All changes should have been undone. */ 249*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, 0, 250*ed894facSEric Biggers mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk, 251*ed894facSEric Biggers pk_len)); 252*ed894facSEric Biggers } 253*ed894facSEric Biggers 254*ed894facSEric Biggers static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv) 255*ed894facSEric Biggers { 256*ed894facSEric Biggers /* Valid signature */ 257*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len); 258*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len); 259*ed894facSEric Biggers do_mldsa_and_assert_success(test, tv); 260*ed894facSEric Biggers 261*ed894facSEric Biggers /* Signature too short */ 262*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 263*ed894facSEric Biggers mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg, 264*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 265*ed894facSEric Biggers 266*ed894facSEric Biggers /* Signature too long */ 267*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 268*ed894facSEric Biggers mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg, 269*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len)); 270*ed894facSEric Biggers 271*ed894facSEric Biggers /* Public key too short */ 272*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 273*ed894facSEric Biggers mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 274*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len - 1)); 275*ed894facSEric Biggers 276*ed894facSEric Biggers /* Public key too long */ 277*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EBADMSG, 278*ed894facSEric Biggers mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 279*ed894facSEric Biggers tv->msg_len, tv->pk, tv->pk_len + 1)); 280*ed894facSEric Biggers 281*ed894facSEric Biggers /* 282*ed894facSEric Biggers * Message too short. Error is EKEYREJECTED because it gets rejected by 283*ed894facSEric Biggers * the "real" signature check rather than the well-formedness checks. 284*ed894facSEric Biggers */ 285*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, -EKEYREJECTED, 286*ed894facSEric Biggers mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 287*ed894facSEric Biggers tv->msg_len - 1, tv->pk, tv->pk_len)); 288*ed894facSEric Biggers /* 289*ed894facSEric Biggers * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be 290*ed894facSEric Biggers * accessed out of bounds. However, ML-DSA just hashes the message and 291*ed894facSEric Biggers * doesn't handle different message lengths differently anyway. 292*ed894facSEric Biggers */ 293*ed894facSEric Biggers 294*ed894facSEric Biggers /* Test the validity checks on the z vector. */ 295*ed894facSEric Biggers test_mldsa_z_range(test, tv); 296*ed894facSEric Biggers 297*ed894facSEric Biggers /* Test the validity checks on the hint vector. */ 298*ed894facSEric Biggers test_mldsa_bad_hints(test, tv); 299*ed894facSEric Biggers 300*ed894facSEric Biggers /* Test randomly mutating the inputs. */ 301*ed894facSEric Biggers test_mldsa_mutation(test, tv); 302*ed894facSEric Biggers } 303*ed894facSEric Biggers 304*ed894facSEric Biggers static void test_mldsa44(struct kunit *test) 305*ed894facSEric Biggers { 306*ed894facSEric Biggers test_mldsa(test, &mldsa44_testvector); 307*ed894facSEric Biggers } 308*ed894facSEric Biggers 309*ed894facSEric Biggers static void test_mldsa65(struct kunit *test) 310*ed894facSEric Biggers { 311*ed894facSEric Biggers test_mldsa(test, &mldsa65_testvector); 312*ed894facSEric Biggers } 313*ed894facSEric Biggers 314*ed894facSEric Biggers static void test_mldsa87(struct kunit *test) 315*ed894facSEric Biggers { 316*ed894facSEric Biggers test_mldsa(test, &mldsa87_testvector); 317*ed894facSEric Biggers } 318*ed894facSEric Biggers 319*ed894facSEric Biggers static s32 mod(s32 a, s32 m) 320*ed894facSEric Biggers { 321*ed894facSEric Biggers a %= m; 322*ed894facSEric Biggers if (a < 0) 323*ed894facSEric Biggers a += m; 324*ed894facSEric Biggers return a; 325*ed894facSEric Biggers } 326*ed894facSEric Biggers 327*ed894facSEric Biggers static s32 symmetric_mod(s32 a, s32 m) 328*ed894facSEric Biggers { 329*ed894facSEric Biggers a = mod(a, m); 330*ed894facSEric Biggers if (a > m / 2) 331*ed894facSEric Biggers a -= m; 332*ed894facSEric Biggers return a; 333*ed894facSEric Biggers } 334*ed894facSEric Biggers 335*ed894facSEric Biggers /* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */ 336*ed894facSEric Biggers static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1) 337*ed894facSEric Biggers { 338*ed894facSEric Biggers s32 rplus = mod(r, Q); 339*ed894facSEric Biggers 340*ed894facSEric Biggers *r0 = symmetric_mod(rplus, 2 * gamma2); 341*ed894facSEric Biggers if (rplus - *r0 == Q - 1) { 342*ed894facSEric Biggers *r1 = 0; 343*ed894facSEric Biggers *r0 = *r0 - 1; 344*ed894facSEric Biggers } else { 345*ed894facSEric Biggers *r1 = (rplus - *r0) / (2 * gamma2); 346*ed894facSEric Biggers } 347*ed894facSEric Biggers } 348*ed894facSEric Biggers 349*ed894facSEric Biggers /* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */ 350*ed894facSEric Biggers static s32 use_hint_ref(u8 h, s32 r, s32 gamma2) 351*ed894facSEric Biggers { 352*ed894facSEric Biggers s32 m = (Q - 1) / (2 * gamma2); 353*ed894facSEric Biggers s32 r0, r1; 354*ed894facSEric Biggers 355*ed894facSEric Biggers decompose_ref(r, gamma2, &r0, &r1); 356*ed894facSEric Biggers if (h == 1 && r0 > 0) 357*ed894facSEric Biggers return mod(r1 + 1, m); 358*ed894facSEric Biggers if (h == 1 && r0 <= 0) 359*ed894facSEric Biggers return mod(r1 - 1, m); 360*ed894facSEric Biggers return r1; 361*ed894facSEric Biggers } 362*ed894facSEric Biggers 363*ed894facSEric Biggers /* 364*ed894facSEric Biggers * Test that for all possible inputs, mldsa_use_hint() gives the same output as 365*ed894facSEric Biggers * a mechanical translation of the pseudocode from FIPS 204. 366*ed894facSEric Biggers */ 367*ed894facSEric Biggers static void test_mldsa_use_hint(struct kunit *test) 368*ed894facSEric Biggers { 369*ed894facSEric Biggers for (int i = 0; i < 2; i++) { 370*ed894facSEric Biggers const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32); 371*ed894facSEric Biggers 372*ed894facSEric Biggers for (u8 h = 0; h < 2; h++) { 373*ed894facSEric Biggers for (s32 r = 0; r < Q; r++) { 374*ed894facSEric Biggers KUNIT_ASSERT_EQ(test, 375*ed894facSEric Biggers mldsa_use_hint(h, r, gamma2), 376*ed894facSEric Biggers use_hint_ref(h, r, gamma2)); 377*ed894facSEric Biggers } 378*ed894facSEric Biggers } 379*ed894facSEric Biggers } 380*ed894facSEric Biggers } 381*ed894facSEric Biggers 382*ed894facSEric Biggers static void benchmark_mldsa(struct kunit *test, 383*ed894facSEric Biggers const struct mldsa_testvector *tv) 384*ed894facSEric Biggers { 385*ed894facSEric Biggers const int warmup_niter = 200; 386*ed894facSEric Biggers const int benchmark_niter = 200; 387*ed894facSEric Biggers u64 t0, t1; 388*ed894facSEric Biggers 389*ed894facSEric Biggers if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK)) 390*ed894facSEric Biggers kunit_skip(test, "not enabled"); 391*ed894facSEric Biggers 392*ed894facSEric Biggers for (int i = 0; i < warmup_niter; i++) 393*ed894facSEric Biggers do_mldsa_and_assert_success(test, tv); 394*ed894facSEric Biggers 395*ed894facSEric Biggers t0 = ktime_get_ns(); 396*ed894facSEric Biggers for (int i = 0; i < benchmark_niter; i++) 397*ed894facSEric Biggers do_mldsa_and_assert_success(test, tv); 398*ed894facSEric Biggers t1 = ktime_get_ns(); 399*ed894facSEric Biggers kunit_info(test, "%llu ops/s", 400*ed894facSEric Biggers div64_u64((u64)benchmark_niter * NSEC_PER_SEC, 401*ed894facSEric Biggers t1 - t0 ?: 1)); 402*ed894facSEric Biggers } 403*ed894facSEric Biggers 404*ed894facSEric Biggers static void benchmark_mldsa44(struct kunit *test) 405*ed894facSEric Biggers { 406*ed894facSEric Biggers benchmark_mldsa(test, &mldsa44_testvector); 407*ed894facSEric Biggers } 408*ed894facSEric Biggers 409*ed894facSEric Biggers static void benchmark_mldsa65(struct kunit *test) 410*ed894facSEric Biggers { 411*ed894facSEric Biggers benchmark_mldsa(test, &mldsa65_testvector); 412*ed894facSEric Biggers } 413*ed894facSEric Biggers 414*ed894facSEric Biggers static void benchmark_mldsa87(struct kunit *test) 415*ed894facSEric Biggers { 416*ed894facSEric Biggers benchmark_mldsa(test, &mldsa87_testvector); 417*ed894facSEric Biggers } 418*ed894facSEric Biggers 419*ed894facSEric Biggers static struct kunit_case mldsa_kunit_cases[] = { 420*ed894facSEric Biggers KUNIT_CASE(test_mldsa44), 421*ed894facSEric Biggers KUNIT_CASE(test_mldsa65), 422*ed894facSEric Biggers KUNIT_CASE(test_mldsa87), 423*ed894facSEric Biggers KUNIT_CASE(test_mldsa_use_hint), 424*ed894facSEric Biggers KUNIT_CASE(benchmark_mldsa44), 425*ed894facSEric Biggers KUNIT_CASE(benchmark_mldsa65), 426*ed894facSEric Biggers KUNIT_CASE(benchmark_mldsa87), 427*ed894facSEric Biggers {}, 428*ed894facSEric Biggers }; 429*ed894facSEric Biggers 430*ed894facSEric Biggers static struct kunit_suite mldsa_kunit_suite = { 431*ed894facSEric Biggers .name = "mldsa", 432*ed894facSEric Biggers .test_cases = mldsa_kunit_cases, 433*ed894facSEric Biggers }; 434*ed894facSEric Biggers kunit_test_suite(mldsa_kunit_suite); 435*ed894facSEric Biggers 436*ed894facSEric Biggers MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA"); 437*ed894facSEric Biggers MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING"); 438*ed894facSEric Biggers MODULE_LICENSE("GPL"); 439