1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * KUnit tests and benchmark for ML-DSA 4 * 5 * Copyright 2025 Google LLC 6 */ 7 #include <crypto/mldsa.h> 8 #include <kunit/test.h> 9 #include <linux/random.h> 10 #include <linux/unaligned.h> 11 12 #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */ 13 14 /* ML-DSA parameters that the tests use */ 15 static const struct { 16 int sig_len; 17 int pk_len; 18 int k; 19 int lambda; 20 int gamma1; 21 int beta; 22 int omega; 23 } params[] = { 24 [MLDSA44] = { 25 .sig_len = MLDSA44_SIGNATURE_SIZE, 26 .pk_len = MLDSA44_PUBLIC_KEY_SIZE, 27 .k = 4, 28 .lambda = 128, 29 .gamma1 = 1 << 17, 30 .beta = 78, 31 .omega = 80, 32 }, 33 [MLDSA65] = { 34 .sig_len = MLDSA65_SIGNATURE_SIZE, 35 .pk_len = MLDSA65_PUBLIC_KEY_SIZE, 36 .k = 6, 37 .lambda = 192, 38 .gamma1 = 1 << 19, 39 .beta = 196, 40 .omega = 55, 41 }, 42 [MLDSA87] = { 43 .sig_len = MLDSA87_SIGNATURE_SIZE, 44 .pk_len = MLDSA87_PUBLIC_KEY_SIZE, 45 .k = 8, 46 .lambda = 256, 47 .gamma1 = 1 << 19, 48 .beta = 120, 49 .omega = 75, 50 }, 51 }; 52 53 #include "mldsa-testvecs.h" 54 55 static void do_mldsa_and_assert_success(struct kunit *test, 56 const struct mldsa_testvector *tv) 57 { 58 int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 59 tv->msg_len, tv->pk, tv->pk_len); 60 KUNIT_ASSERT_EQ(test, err, 0); 61 } 62 63 static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len) 64 { 65 u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL); 66 67 KUNIT_ASSERT_NOT_NULL(test, dst); 68 return memcpy(dst, src, len); 69 } 70 71 /* 72 * Test that changing coefficients in a valid signature's z vector results in 73 * the following behavior from mldsa_verify(): 74 * 75 * * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e. 76 * absolute value >= gamma1 - beta, corresponding to the verifier detecting 77 * the out-of-range coefficient and rejecting the signature as malformed 78 * 79 * * -EKEYREJECTED if a coefficient is changed to a different in-range value, 80 * i.e. absolute value < gamma1 - beta, corresponding to the verifier 81 * continuing to the "real" signature check and that check failing 82 */ 83 static void test_mldsa_z_range(struct kunit *test, 84 const struct mldsa_testvector *tv) 85 { 86 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len); 87 const int lambda = params[tv->alg].lambda; 88 const s32 gamma1 = params[tv->alg].gamma1; 89 const int beta = params[tv->alg].beta; 90 /* 91 * We just modify the first coefficient. The coefficient is gamma1 92 * minus either the first 18 or 20 bits of the u32, depending on gamma1. 93 * 94 * The layout of ML-DSA signatures is ctilde || z || h. ctilde is 95 * lambda / 4 bytes, so z starts at &sig[lambda / 4]. 96 */ 97 u8 *z_ptr = &sig[lambda / 4]; 98 const u32 z_data = get_unaligned_le32(z_ptr); 99 const u32 mask = (gamma1 << 1) - 1; 100 /* These are the four boundaries of the out-of-range values. */ 101 const s32 out_of_range_coeffs[] = { 102 -gamma1 + 1, 103 -(gamma1 - beta), 104 gamma1, 105 gamma1 - beta, 106 }; 107 /* 108 * These are the two boundaries of the valid range, along with 0. We 109 * assume that none of these matches the original coefficient. 110 */ 111 const s32 in_range_coeffs[] = { 112 -(gamma1 - beta - 1), 113 0, 114 gamma1 - beta - 1, 115 }; 116 117 /* Initially the signature is valid. */ 118 do_mldsa_and_assert_success(test, tv); 119 120 /* Test some out-of-range coefficients. */ 121 for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) { 122 const s32 c = out_of_range_coeffs[i]; 123 124 put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)), 125 z_ptr); 126 KUNIT_ASSERT_EQ(test, -EBADMSG, 127 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 128 tv->msg_len, tv->pk, tv->pk_len)); 129 } 130 131 /* Test some in-range coefficients. */ 132 for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) { 133 const s32 c = in_range_coeffs[i]; 134 135 put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)), 136 z_ptr); 137 KUNIT_ASSERT_EQ(test, -EKEYREJECTED, 138 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 139 tv->msg_len, tv->pk, tv->pk_len)); 140 } 141 } 142 143 /* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */ 144 static void test_mldsa_bad_hints(struct kunit *test, 145 const struct mldsa_testvector *tv) 146 { 147 const int omega = params[tv->alg].omega; 148 const int k = params[tv->alg].k; 149 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len); 150 /* Pointer to the encoded hint vector in the signature */ 151 u8 *hintvec = &sig[tv->sig_len - omega - k]; 152 u8 h; 153 154 /* Initially the signature is valid. */ 155 do_mldsa_and_assert_success(test, tv); 156 157 /* Cumulative hint count exceeds omega */ 158 memcpy(sig, tv->sig, tv->sig_len); 159 hintvec[omega + k - 1] = omega + 1; 160 KUNIT_ASSERT_EQ(test, -EBADMSG, 161 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 162 tv->msg_len, tv->pk, tv->pk_len)); 163 164 /* Cumulative hint count decreases */ 165 memcpy(sig, tv->sig, tv->sig_len); 166 KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1); 167 hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1; 168 KUNIT_ASSERT_EQ(test, -EBADMSG, 169 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 170 tv->msg_len, tv->pk, tv->pk_len)); 171 172 /* 173 * Hint indices out of order. To test this, swap hintvec[0] and 174 * hintvec[1]. This assumes that the original valid signature had at 175 * least two nonzero hints in the first element (asserted below). 176 */ 177 memcpy(sig, tv->sig, tv->sig_len); 178 KUNIT_ASSERT_GE(test, hintvec[omega], 2); 179 h = hintvec[0]; 180 hintvec[0] = hintvec[1]; 181 hintvec[1] = h; 182 KUNIT_ASSERT_EQ(test, -EBADMSG, 183 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 184 tv->msg_len, tv->pk, tv->pk_len)); 185 186 /* 187 * Extra hint indices given. For this test to work, the original valid 188 * signature must have fewer than omega nonzero hints (asserted below). 189 */ 190 memcpy(sig, tv->sig, tv->sig_len); 191 KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega); 192 hintvec[omega - 1] = 0xff; 193 KUNIT_ASSERT_EQ(test, -EBADMSG, 194 mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg, 195 tv->msg_len, tv->pk, tv->pk_len)); 196 } 197 198 static void test_mldsa_mutation(struct kunit *test, 199 const struct mldsa_testvector *tv) 200 { 201 const int sig_len = tv->sig_len; 202 const int msg_len = tv->msg_len; 203 const int pk_len = tv->pk_len; 204 const int num_iter = 200; 205 u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len); 206 u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len); 207 u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len); 208 209 /* Initially the signature is valid. */ 210 do_mldsa_and_assert_success(test, tv); 211 212 /* Changing any bit in the signature should invalidate the signature */ 213 for (int i = 0; i < num_iter; i++) { 214 size_t pos = get_random_u32_below(sig_len); 215 u8 b = 1 << get_random_u32_below(8); 216 217 sig[pos] ^= b; 218 KUNIT_ASSERT_NE(test, 0, 219 mldsa_verify(tv->alg, sig, sig_len, msg, 220 msg_len, pk, pk_len)); 221 sig[pos] ^= b; 222 } 223 224 /* Changing any bit in the message should invalidate the signature */ 225 for (int i = 0; i < num_iter; i++) { 226 size_t pos = get_random_u32_below(msg_len); 227 u8 b = 1 << get_random_u32_below(8); 228 229 msg[pos] ^= b; 230 KUNIT_ASSERT_NE(test, 0, 231 mldsa_verify(tv->alg, sig, sig_len, msg, 232 msg_len, pk, pk_len)); 233 msg[pos] ^= b; 234 } 235 236 /* Changing any bit in the public key should invalidate the signature */ 237 for (int i = 0; i < num_iter; i++) { 238 size_t pos = get_random_u32_below(pk_len); 239 u8 b = 1 << get_random_u32_below(8); 240 241 pk[pos] ^= b; 242 KUNIT_ASSERT_NE(test, 0, 243 mldsa_verify(tv->alg, sig, sig_len, msg, 244 msg_len, pk, pk_len)); 245 pk[pos] ^= b; 246 } 247 248 /* All changes should have been undone. */ 249 KUNIT_ASSERT_EQ(test, 0, 250 mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk, 251 pk_len)); 252 } 253 254 static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv) 255 { 256 /* Valid signature */ 257 KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len); 258 KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len); 259 do_mldsa_and_assert_success(test, tv); 260 261 /* Signature too short */ 262 KUNIT_ASSERT_EQ(test, -EBADMSG, 263 mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg, 264 tv->msg_len, tv->pk, tv->pk_len)); 265 266 /* Signature too long */ 267 KUNIT_ASSERT_EQ(test, -EBADMSG, 268 mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg, 269 tv->msg_len, tv->pk, tv->pk_len)); 270 271 /* Public key too short */ 272 KUNIT_ASSERT_EQ(test, -EBADMSG, 273 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 274 tv->msg_len, tv->pk, tv->pk_len - 1)); 275 276 /* Public key too long */ 277 KUNIT_ASSERT_EQ(test, -EBADMSG, 278 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 279 tv->msg_len, tv->pk, tv->pk_len + 1)); 280 281 /* 282 * Message too short. Error is EKEYREJECTED because it gets rejected by 283 * the "real" signature check rather than the well-formedness checks. 284 */ 285 KUNIT_ASSERT_EQ(test, -EKEYREJECTED, 286 mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, 287 tv->msg_len - 1, tv->pk, tv->pk_len)); 288 /* 289 * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be 290 * accessed out of bounds. However, ML-DSA just hashes the message and 291 * doesn't handle different message lengths differently anyway. 292 */ 293 294 /* Test the validity checks on the z vector. */ 295 test_mldsa_z_range(test, tv); 296 297 /* Test the validity checks on the hint vector. */ 298 test_mldsa_bad_hints(test, tv); 299 300 /* Test randomly mutating the inputs. */ 301 test_mldsa_mutation(test, tv); 302 } 303 304 static void test_mldsa44(struct kunit *test) 305 { 306 test_mldsa(test, &mldsa44_testvector); 307 } 308 309 static void test_mldsa65(struct kunit *test) 310 { 311 test_mldsa(test, &mldsa65_testvector); 312 } 313 314 static void test_mldsa87(struct kunit *test) 315 { 316 test_mldsa(test, &mldsa87_testvector); 317 } 318 319 static s32 mod(s32 a, s32 m) 320 { 321 a %= m; 322 if (a < 0) 323 a += m; 324 return a; 325 } 326 327 static s32 symmetric_mod(s32 a, s32 m) 328 { 329 a = mod(a, m); 330 if (a > m / 2) 331 a -= m; 332 return a; 333 } 334 335 /* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */ 336 static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1) 337 { 338 s32 rplus = mod(r, Q); 339 340 *r0 = symmetric_mod(rplus, 2 * gamma2); 341 if (rplus - *r0 == Q - 1) { 342 *r1 = 0; 343 *r0 = *r0 - 1; 344 } else { 345 *r1 = (rplus - *r0) / (2 * gamma2); 346 } 347 } 348 349 /* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */ 350 static s32 use_hint_ref(u8 h, s32 r, s32 gamma2) 351 { 352 s32 m = (Q - 1) / (2 * gamma2); 353 s32 r0, r1; 354 355 decompose_ref(r, gamma2, &r0, &r1); 356 if (h == 1 && r0 > 0) 357 return mod(r1 + 1, m); 358 if (h == 1 && r0 <= 0) 359 return mod(r1 - 1, m); 360 return r1; 361 } 362 363 /* 364 * Test that for all possible inputs, mldsa_use_hint() gives the same output as 365 * a mechanical translation of the pseudocode from FIPS 204. 366 */ 367 static void test_mldsa_use_hint(struct kunit *test) 368 { 369 for (int i = 0; i < 2; i++) { 370 const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32); 371 372 for (u8 h = 0; h < 2; h++) { 373 for (s32 r = 0; r < Q; r++) { 374 KUNIT_ASSERT_EQ(test, 375 mldsa_use_hint(h, r, gamma2), 376 use_hint_ref(h, r, gamma2)); 377 } 378 } 379 } 380 } 381 382 static void benchmark_mldsa(struct kunit *test, 383 const struct mldsa_testvector *tv) 384 { 385 const int warmup_niter = 200; 386 const int benchmark_niter = 200; 387 u64 t0, t1; 388 389 if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK)) 390 kunit_skip(test, "not enabled"); 391 392 for (int i = 0; i < warmup_niter; i++) 393 do_mldsa_and_assert_success(test, tv); 394 395 t0 = ktime_get_ns(); 396 for (int i = 0; i < benchmark_niter; i++) 397 do_mldsa_and_assert_success(test, tv); 398 t1 = ktime_get_ns(); 399 kunit_info(test, "%llu ops/s", 400 div64_u64((u64)benchmark_niter * NSEC_PER_SEC, 401 t1 - t0 ?: 1)); 402 } 403 404 static void benchmark_mldsa44(struct kunit *test) 405 { 406 benchmark_mldsa(test, &mldsa44_testvector); 407 } 408 409 static void benchmark_mldsa65(struct kunit *test) 410 { 411 benchmark_mldsa(test, &mldsa65_testvector); 412 } 413 414 static void benchmark_mldsa87(struct kunit *test) 415 { 416 benchmark_mldsa(test, &mldsa87_testvector); 417 } 418 419 static struct kunit_case mldsa_kunit_cases[] = { 420 KUNIT_CASE(test_mldsa44), 421 KUNIT_CASE(test_mldsa65), 422 KUNIT_CASE(test_mldsa87), 423 KUNIT_CASE(test_mldsa_use_hint), 424 KUNIT_CASE(benchmark_mldsa44), 425 KUNIT_CASE(benchmark_mldsa65), 426 KUNIT_CASE(benchmark_mldsa87), 427 {}, 428 }; 429 430 static struct kunit_suite mldsa_kunit_suite = { 431 .name = "mldsa", 432 .test_cases = mldsa_kunit_cases, 433 }; 434 kunit_test_suite(mldsa_kunit_suite); 435 436 MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA"); 437 MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING"); 438 MODULE_LICENSE("GPL"); 439