xref: /linux/lib/crypto/tests/mldsa_kunit.c (revision 13d83ea9d81ddcb08b46377dcc9de6e5df1248d1)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * KUnit tests and benchmark for ML-DSA
4  *
5  * Copyright 2025 Google LLC
6  */
7 #include <crypto/mldsa.h>
8 #include <kunit/test.h>
9 #include <linux/random.h>
10 #include <linux/unaligned.h>
11 
/* The prime q = 2^23 - 2^13 + 1 = 8380417 */
#define Q ((1 << 23) - (1 << 13) + 1)
13 
14 /* ML-DSA parameters that the tests use */
15 static const struct {
16 	int sig_len;
17 	int pk_len;
18 	int k;
19 	int lambda;
20 	int gamma1;
21 	int beta;
22 	int omega;
23 } params[] = {
24 	[MLDSA44] = {
25 		.sig_len = MLDSA44_SIGNATURE_SIZE,
26 		.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
27 		.k = 4,
28 		.lambda = 128,
29 		.gamma1 = 1 << 17,
30 		.beta = 78,
31 		.omega = 80,
32 	},
33 	[MLDSA65] = {
34 		.sig_len = MLDSA65_SIGNATURE_SIZE,
35 		.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
36 		.k = 6,
37 		.lambda = 192,
38 		.gamma1 = 1 << 19,
39 		.beta = 196,
40 		.omega = 55,
41 	},
42 	[MLDSA87] = {
43 		.sig_len = MLDSA87_SIGNATURE_SIZE,
44 		.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
45 		.k = 8,
46 		.lambda = 256,
47 		.gamma1 = 1 << 19,
48 		.beta = 120,
49 		.omega = 75,
50 	},
51 };
52 
53 #include "mldsa-testvecs.h"
54 
55 static void do_mldsa_and_assert_success(struct kunit *test,
56 					const struct mldsa_testvector *tv)
57 {
58 	int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
59 			       tv->msg_len, tv->pk, tv->pk_len);
60 	KUNIT_ASSERT_EQ(test, err, 0);
61 }
62 
63 static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
64 {
65 	u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL);
66 
67 	KUNIT_ASSERT_NOT_NULL(test, dst);
68 	return memcpy(dst, src, len);
69 }
70 
71 /*
72  * Test that changing coefficients in a valid signature's z vector results in
73  * the following behavior from mldsa_verify():
74  *
75  *  * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
76  *    absolute value >= gamma1 - beta, corresponding to the verifier detecting
77  *    the out-of-range coefficient and rejecting the signature as malformed
78  *
79  *  * -EKEYREJECTED if a coefficient is changed to a different in-range value,
80  *    i.e. absolute value < gamma1 - beta, corresponding to the verifier
81  *    continuing to the "real" signature check and that check failing
82  */
83 static void test_mldsa_z_range(struct kunit *test,
84 			       const struct mldsa_testvector *tv)
85 {
86 	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
87 	const int lambda = params[tv->alg].lambda;
88 	const s32 gamma1 = params[tv->alg].gamma1;
89 	const int beta = params[tv->alg].beta;
90 	/*
91 	 * We just modify the first coefficient.  The coefficient is gamma1
92 	 * minus either the first 18 or 20 bits of the u32, depending on gamma1.
93 	 *
94 	 * The layout of ML-DSA signatures is ctilde || z || h.  ctilde is
95 	 * lambda / 4 bytes, so z starts at &sig[lambda / 4].
96 	 */
97 	u8 *z_ptr = &sig[lambda / 4];
98 	const u32 z_data = get_unaligned_le32(z_ptr);
99 	const u32 mask = (gamma1 << 1) - 1;
100 	/* These are the four boundaries of the out-of-range values. */
101 	const s32 out_of_range_coeffs[] = {
102 		-gamma1 + 1,
103 		-(gamma1 - beta),
104 		gamma1,
105 		gamma1 - beta,
106 	};
107 	/*
108 	 * These are the two boundaries of the valid range, along with 0.  We
109 	 * assume that none of these matches the original coefficient.
110 	 */
111 	const s32 in_range_coeffs[] = {
112 		-(gamma1 - beta - 1),
113 		0,
114 		gamma1 - beta - 1,
115 	};
116 
117 	/* Initially the signature is valid. */
118 	do_mldsa_and_assert_success(test, tv);
119 
120 	/* Test some out-of-range coefficients. */
121 	for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
122 		const s32 c = out_of_range_coeffs[i];
123 
124 		put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
125 				   z_ptr);
126 		KUNIT_ASSERT_EQ(test, -EBADMSG,
127 				mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
128 					     tv->msg_len, tv->pk, tv->pk_len));
129 	}
130 
131 	/* Test some in-range coefficients. */
132 	for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
133 		const s32 c = in_range_coeffs[i];
134 
135 		put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
136 				   z_ptr);
137 		KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
138 				mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
139 					     tv->msg_len, tv->pk, tv->pk_len));
140 	}
141 }
142 
143 /* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */
144 static void test_mldsa_bad_hints(struct kunit *test,
145 				 const struct mldsa_testvector *tv)
146 {
147 	const int omega = params[tv->alg].omega;
148 	const int k = params[tv->alg].k;
149 	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
150 	/* Pointer to the encoded hint vector in the signature */
151 	u8 *hintvec = &sig[tv->sig_len - omega - k];
152 	u8 h;
153 
154 	/* Initially the signature is valid. */
155 	do_mldsa_and_assert_success(test, tv);
156 
157 	/* Cumulative hint count exceeds omega */
158 	memcpy(sig, tv->sig, tv->sig_len);
159 	hintvec[omega + k - 1] = omega + 1;
160 	KUNIT_ASSERT_EQ(test, -EBADMSG,
161 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
162 				     tv->msg_len, tv->pk, tv->pk_len));
163 
164 	/* Cumulative hint count decreases */
165 	memcpy(sig, tv->sig, tv->sig_len);
166 	KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
167 	hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
168 	KUNIT_ASSERT_EQ(test, -EBADMSG,
169 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
170 				     tv->msg_len, tv->pk, tv->pk_len));
171 
172 	/*
173 	 * Hint indices out of order.  To test this, swap hintvec[0] and
174 	 * hintvec[1].  This assumes that the original valid signature had at
175 	 * least two nonzero hints in the first element (asserted below).
176 	 */
177 	memcpy(sig, tv->sig, tv->sig_len);
178 	KUNIT_ASSERT_GE(test, hintvec[omega], 2);
179 	h = hintvec[0];
180 	hintvec[0] = hintvec[1];
181 	hintvec[1] = h;
182 	KUNIT_ASSERT_EQ(test, -EBADMSG,
183 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
184 				     tv->msg_len, tv->pk, tv->pk_len));
185 
186 	/*
187 	 * Extra hint indices given.  For this test to work, the original valid
188 	 * signature must have fewer than omega nonzero hints (asserted below).
189 	 */
190 	memcpy(sig, tv->sig, tv->sig_len);
191 	KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
192 	hintvec[omega - 1] = 0xff;
193 	KUNIT_ASSERT_EQ(test, -EBADMSG,
194 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
195 				     tv->msg_len, tv->pk, tv->pk_len));
196 }
197 
198 static void test_mldsa_mutation(struct kunit *test,
199 				const struct mldsa_testvector *tv)
200 {
201 	const int sig_len = tv->sig_len;
202 	const int msg_len = tv->msg_len;
203 	const int pk_len = tv->pk_len;
204 	const int num_iter = 200;
205 	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
206 	u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
207 	u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
208 
209 	/* Initially the signature is valid. */
210 	do_mldsa_and_assert_success(test, tv);
211 
212 	/* Changing any bit in the signature should invalidate the signature */
213 	for (int i = 0; i < num_iter; i++) {
214 		size_t pos = get_random_u32_below(sig_len);
215 		u8 b = 1 << get_random_u32_below(8);
216 
217 		sig[pos] ^= b;
218 		KUNIT_ASSERT_NE(test, 0,
219 				mldsa_verify(tv->alg, sig, sig_len, msg,
220 					     msg_len, pk, pk_len));
221 		sig[pos] ^= b;
222 	}
223 
224 	/* Changing any bit in the message should invalidate the signature */
225 	for (int i = 0; i < num_iter; i++) {
226 		size_t pos = get_random_u32_below(msg_len);
227 		u8 b = 1 << get_random_u32_below(8);
228 
229 		msg[pos] ^= b;
230 		KUNIT_ASSERT_NE(test, 0,
231 				mldsa_verify(tv->alg, sig, sig_len, msg,
232 					     msg_len, pk, pk_len));
233 		msg[pos] ^= b;
234 	}
235 
236 	/* Changing any bit in the public key should invalidate the signature */
237 	for (int i = 0; i < num_iter; i++) {
238 		size_t pos = get_random_u32_below(pk_len);
239 		u8 b = 1 << get_random_u32_below(8);
240 
241 		pk[pos] ^= b;
242 		KUNIT_ASSERT_NE(test, 0,
243 				mldsa_verify(tv->alg, sig, sig_len, msg,
244 					     msg_len, pk, pk_len));
245 		pk[pos] ^= b;
246 	}
247 
248 	/* All changes should have been undone. */
249 	KUNIT_ASSERT_EQ(test, 0,
250 			mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
251 				     pk_len));
252 }
253 
254 static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
255 {
256 	/* Valid signature */
257 	KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
258 	KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
259 	do_mldsa_and_assert_success(test, tv);
260 
261 	/* Signature too short */
262 	KUNIT_ASSERT_EQ(test, -EBADMSG,
263 			mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
264 				     tv->msg_len, tv->pk, tv->pk_len));
265 
266 	/* Signature too long */
267 	KUNIT_ASSERT_EQ(test, -EBADMSG,
268 			mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
269 				     tv->msg_len, tv->pk, tv->pk_len));
270 
271 	/* Public key too short */
272 	KUNIT_ASSERT_EQ(test, -EBADMSG,
273 			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
274 				     tv->msg_len, tv->pk, tv->pk_len - 1));
275 
276 	/* Public key too long */
277 	KUNIT_ASSERT_EQ(test, -EBADMSG,
278 			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
279 				     tv->msg_len, tv->pk, tv->pk_len + 1));
280 
281 	/*
282 	 * Message too short.  Error is EKEYREJECTED because it gets rejected by
283 	 * the "real" signature check rather than the well-formedness checks.
284 	 */
285 	KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
286 			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
287 				     tv->msg_len - 1, tv->pk, tv->pk_len));
288 	/*
289 	 * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
290 	 * accessed out of bounds.  However, ML-DSA just hashes the message and
291 	 * doesn't handle different message lengths differently anyway.
292 	 */
293 
294 	/* Test the validity checks on the z vector. */
295 	test_mldsa_z_range(test, tv);
296 
297 	/* Test the validity checks on the hint vector. */
298 	test_mldsa_bad_hints(test, tv);
299 
300 	/* Test randomly mutating the inputs. */
301 	test_mldsa_mutation(test, tv);
302 }
303 
304 static void test_mldsa44(struct kunit *test)
305 {
306 	test_mldsa(test, &mldsa44_testvector);
307 }
308 
309 static void test_mldsa65(struct kunit *test)
310 {
311 	test_mldsa(test, &mldsa65_testvector);
312 }
313 
314 static void test_mldsa87(struct kunit *test)
315 {
316 	test_mldsa(test, &mldsa87_testvector);
317 }
318 
319 static s32 mod(s32 a, s32 m)
320 {
321 	a %= m;
322 	if (a < 0)
323 		a += m;
324 	return a;
325 }
326 
327 static s32 symmetric_mod(s32 a, s32 m)
328 {
329 	a = mod(a, m);
330 	if (a > m / 2)
331 		a -= m;
332 	return a;
333 }
334 
335 /* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */
336 static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
337 {
338 	s32 rplus = mod(r, Q);
339 
340 	*r0 = symmetric_mod(rplus, 2 * gamma2);
341 	if (rplus - *r0 == Q - 1) {
342 		*r1 = 0;
343 		*r0 = *r0 - 1;
344 	} else {
345 		*r1 = (rplus - *r0) / (2 * gamma2);
346 	}
347 }
348 
349 /* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */
350 static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
351 {
352 	s32 m = (Q - 1) / (2 * gamma2);
353 	s32 r0, r1;
354 
355 	decompose_ref(r, gamma2, &r0, &r1);
356 	if (h == 1 && r0 > 0)
357 		return mod(r1 + 1, m);
358 	if (h == 1 && r0 <= 0)
359 		return mod(r1 - 1, m);
360 	return r1;
361 }
362 
363 /*
364  * Test that for all possible inputs, mldsa_use_hint() gives the same output as
365  * a mechanical translation of the pseudocode from FIPS 204.
366  */
367 static void test_mldsa_use_hint(struct kunit *test)
368 {
369 	for (int i = 0; i < 2; i++) {
370 		const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);
371 
372 		for (u8 h = 0; h < 2; h++) {
373 			for (s32 r = 0; r < Q; r++) {
374 				KUNIT_ASSERT_EQ(test,
375 						mldsa_use_hint(h, r, gamma2),
376 						use_hint_ref(h, r, gamma2));
377 			}
378 		}
379 	}
380 }
381 
382 static void benchmark_mldsa(struct kunit *test,
383 			    const struct mldsa_testvector *tv)
384 {
385 	const int warmup_niter = 200;
386 	const int benchmark_niter = 200;
387 	u64 t0, t1;
388 
389 	if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
390 		kunit_skip(test, "not enabled");
391 
392 	for (int i = 0; i < warmup_niter; i++)
393 		do_mldsa_and_assert_success(test, tv);
394 
395 	t0 = ktime_get_ns();
396 	for (int i = 0; i < benchmark_niter; i++)
397 		do_mldsa_and_assert_success(test, tv);
398 	t1 = ktime_get_ns();
399 	kunit_info(test, "%llu ops/s",
400 		   div64_u64((u64)benchmark_niter * NSEC_PER_SEC,
401 			     t1 - t0 ?: 1));
402 }
403 
404 static void benchmark_mldsa44(struct kunit *test)
405 {
406 	benchmark_mldsa(test, &mldsa44_testvector);
407 }
408 
409 static void benchmark_mldsa65(struct kunit *test)
410 {
411 	benchmark_mldsa(test, &mldsa65_testvector);
412 }
413 
414 static void benchmark_mldsa87(struct kunit *test)
415 {
416 	benchmark_mldsa(test, &mldsa87_testvector);
417 }
418 
419 static struct kunit_case mldsa_kunit_cases[] = {
420 	KUNIT_CASE(test_mldsa44),
421 	KUNIT_CASE(test_mldsa65),
422 	KUNIT_CASE(test_mldsa87),
423 	KUNIT_CASE(test_mldsa_use_hint),
424 	KUNIT_CASE(benchmark_mldsa44),
425 	KUNIT_CASE(benchmark_mldsa65),
426 	KUNIT_CASE(benchmark_mldsa87),
427 	{},
428 };
429 
430 static struct kunit_suite mldsa_kunit_suite = {
431 	.name = "mldsa",
432 	.test_cases = mldsa_kunit_cases,
433 };
434 kunit_test_suite(mldsa_kunit_suite);
435 
436 MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
437 MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
438 MODULE_LICENSE("GPL");
439