xref: /linux/lib/crypto/tests/mldsa_kunit.c (revision 13d83ea9d81ddcb08b46377dcc9de6e5df1248d1)
1*ed894facSEric Biggers // SPDX-License-Identifier: GPL-2.0-or-later
2*ed894facSEric Biggers /*
3*ed894facSEric Biggers  * KUnit tests and benchmark for ML-DSA
4*ed894facSEric Biggers  *
5*ed894facSEric Biggers  * Copyright 2025 Google LLC
6*ed894facSEric Biggers  */
7*ed894facSEric Biggers #include <crypto/mldsa.h>
8*ed894facSEric Biggers #include <kunit/test.h>
9*ed894facSEric Biggers #include <linux/random.h>
10*ed894facSEric Biggers #include <linux/unaligned.h>
11*ed894facSEric Biggers 
/* The ML-DSA modulus: the prime q = 2^23 - 2^13 + 1 = 8380417 */
#define Q ((1 << 23) - (1 << 13) + 1)
13*ed894facSEric Biggers 
14*ed894facSEric Biggers /* ML-DSA parameters that the tests use */
15*ed894facSEric Biggers static const struct {
16*ed894facSEric Biggers 	int sig_len;
17*ed894facSEric Biggers 	int pk_len;
18*ed894facSEric Biggers 	int k;
19*ed894facSEric Biggers 	int lambda;
20*ed894facSEric Biggers 	int gamma1;
21*ed894facSEric Biggers 	int beta;
22*ed894facSEric Biggers 	int omega;
23*ed894facSEric Biggers } params[] = {
24*ed894facSEric Biggers 	[MLDSA44] = {
25*ed894facSEric Biggers 		.sig_len = MLDSA44_SIGNATURE_SIZE,
26*ed894facSEric Biggers 		.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
27*ed894facSEric Biggers 		.k = 4,
28*ed894facSEric Biggers 		.lambda = 128,
29*ed894facSEric Biggers 		.gamma1 = 1 << 17,
30*ed894facSEric Biggers 		.beta = 78,
31*ed894facSEric Biggers 		.omega = 80,
32*ed894facSEric Biggers 	},
33*ed894facSEric Biggers 	[MLDSA65] = {
34*ed894facSEric Biggers 		.sig_len = MLDSA65_SIGNATURE_SIZE,
35*ed894facSEric Biggers 		.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
36*ed894facSEric Biggers 		.k = 6,
37*ed894facSEric Biggers 		.lambda = 192,
38*ed894facSEric Biggers 		.gamma1 = 1 << 19,
39*ed894facSEric Biggers 		.beta = 196,
40*ed894facSEric Biggers 		.omega = 55,
41*ed894facSEric Biggers 	},
42*ed894facSEric Biggers 	[MLDSA87] = {
43*ed894facSEric Biggers 		.sig_len = MLDSA87_SIGNATURE_SIZE,
44*ed894facSEric Biggers 		.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
45*ed894facSEric Biggers 		.k = 8,
46*ed894facSEric Biggers 		.lambda = 256,
47*ed894facSEric Biggers 		.gamma1 = 1 << 19,
48*ed894facSEric Biggers 		.beta = 120,
49*ed894facSEric Biggers 		.omega = 75,
50*ed894facSEric Biggers 	},
51*ed894facSEric Biggers };
52*ed894facSEric Biggers 
53*ed894facSEric Biggers #include "mldsa-testvecs.h"
54*ed894facSEric Biggers 
55*ed894facSEric Biggers static void do_mldsa_and_assert_success(struct kunit *test,
56*ed894facSEric Biggers 					const struct mldsa_testvector *tv)
57*ed894facSEric Biggers {
58*ed894facSEric Biggers 	int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
59*ed894facSEric Biggers 			       tv->msg_len, tv->pk, tv->pk_len);
60*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, err, 0);
61*ed894facSEric Biggers }
62*ed894facSEric Biggers 
63*ed894facSEric Biggers static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
64*ed894facSEric Biggers {
65*ed894facSEric Biggers 	u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL);
66*ed894facSEric Biggers 
67*ed894facSEric Biggers 	KUNIT_ASSERT_NOT_NULL(test, dst);
68*ed894facSEric Biggers 	return memcpy(dst, src, len);
69*ed894facSEric Biggers }
70*ed894facSEric Biggers 
71*ed894facSEric Biggers /*
72*ed894facSEric Biggers  * Test that changing coefficients in a valid signature's z vector results in
73*ed894facSEric Biggers  * the following behavior from mldsa_verify():
74*ed894facSEric Biggers  *
75*ed894facSEric Biggers  *  * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
76*ed894facSEric Biggers  *    absolute value >= gamma1 - beta, corresponding to the verifier detecting
77*ed894facSEric Biggers  *    the out-of-range coefficient and rejecting the signature as malformed
78*ed894facSEric Biggers  *
79*ed894facSEric Biggers  *  * -EKEYREJECTED if a coefficient is changed to a different in-range value,
80*ed894facSEric Biggers  *    i.e. absolute value < gamma1 - beta, corresponding to the verifier
81*ed894facSEric Biggers  *    continuing to the "real" signature check and that check failing
82*ed894facSEric Biggers  */
83*ed894facSEric Biggers static void test_mldsa_z_range(struct kunit *test,
84*ed894facSEric Biggers 			       const struct mldsa_testvector *tv)
85*ed894facSEric Biggers {
86*ed894facSEric Biggers 	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
87*ed894facSEric Biggers 	const int lambda = params[tv->alg].lambda;
88*ed894facSEric Biggers 	const s32 gamma1 = params[tv->alg].gamma1;
89*ed894facSEric Biggers 	const int beta = params[tv->alg].beta;
90*ed894facSEric Biggers 	/*
91*ed894facSEric Biggers 	 * We just modify the first coefficient.  The coefficient is gamma1
92*ed894facSEric Biggers 	 * minus either the first 18 or 20 bits of the u32, depending on gamma1.
93*ed894facSEric Biggers 	 *
94*ed894facSEric Biggers 	 * The layout of ML-DSA signatures is ctilde || z || h.  ctilde is
95*ed894facSEric Biggers 	 * lambda / 4 bytes, so z starts at &sig[lambda / 4].
96*ed894facSEric Biggers 	 */
97*ed894facSEric Biggers 	u8 *z_ptr = &sig[lambda / 4];
98*ed894facSEric Biggers 	const u32 z_data = get_unaligned_le32(z_ptr);
99*ed894facSEric Biggers 	const u32 mask = (gamma1 << 1) - 1;
100*ed894facSEric Biggers 	/* These are the four boundaries of the out-of-range values. */
101*ed894facSEric Biggers 	const s32 out_of_range_coeffs[] = {
102*ed894facSEric Biggers 		-gamma1 + 1,
103*ed894facSEric Biggers 		-(gamma1 - beta),
104*ed894facSEric Biggers 		gamma1,
105*ed894facSEric Biggers 		gamma1 - beta,
106*ed894facSEric Biggers 	};
107*ed894facSEric Biggers 	/*
108*ed894facSEric Biggers 	 * These are the two boundaries of the valid range, along with 0.  We
109*ed894facSEric Biggers 	 * assume that none of these matches the original coefficient.
110*ed894facSEric Biggers 	 */
111*ed894facSEric Biggers 	const s32 in_range_coeffs[] = {
112*ed894facSEric Biggers 		-(gamma1 - beta - 1),
113*ed894facSEric Biggers 		0,
114*ed894facSEric Biggers 		gamma1 - beta - 1,
115*ed894facSEric Biggers 	};
116*ed894facSEric Biggers 
117*ed894facSEric Biggers 	/* Initially the signature is valid. */
118*ed894facSEric Biggers 	do_mldsa_and_assert_success(test, tv);
119*ed894facSEric Biggers 
120*ed894facSEric Biggers 	/* Test some out-of-range coefficients. */
121*ed894facSEric Biggers 	for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
122*ed894facSEric Biggers 		const s32 c = out_of_range_coeffs[i];
123*ed894facSEric Biggers 
124*ed894facSEric Biggers 		put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
125*ed894facSEric Biggers 				   z_ptr);
126*ed894facSEric Biggers 		KUNIT_ASSERT_EQ(test, -EBADMSG,
127*ed894facSEric Biggers 				mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
128*ed894facSEric Biggers 					     tv->msg_len, tv->pk, tv->pk_len));
129*ed894facSEric Biggers 	}
130*ed894facSEric Biggers 
131*ed894facSEric Biggers 	/* Test some in-range coefficients. */
132*ed894facSEric Biggers 	for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
133*ed894facSEric Biggers 		const s32 c = in_range_coeffs[i];
134*ed894facSEric Biggers 
135*ed894facSEric Biggers 		put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
136*ed894facSEric Biggers 				   z_ptr);
137*ed894facSEric Biggers 		KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
138*ed894facSEric Biggers 				mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
139*ed894facSEric Biggers 					     tv->msg_len, tv->pk, tv->pk_len));
140*ed894facSEric Biggers 	}
141*ed894facSEric Biggers }
142*ed894facSEric Biggers 
143*ed894facSEric Biggers /* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */
144*ed894facSEric Biggers static void test_mldsa_bad_hints(struct kunit *test,
145*ed894facSEric Biggers 				 const struct mldsa_testvector *tv)
146*ed894facSEric Biggers {
147*ed894facSEric Biggers 	const int omega = params[tv->alg].omega;
148*ed894facSEric Biggers 	const int k = params[tv->alg].k;
149*ed894facSEric Biggers 	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
150*ed894facSEric Biggers 	/* Pointer to the encoded hint vector in the signature */
151*ed894facSEric Biggers 	u8 *hintvec = &sig[tv->sig_len - omega - k];
152*ed894facSEric Biggers 	u8 h;
153*ed894facSEric Biggers 
154*ed894facSEric Biggers 	/* Initially the signature is valid. */
155*ed894facSEric Biggers 	do_mldsa_and_assert_success(test, tv);
156*ed894facSEric Biggers 
157*ed894facSEric Biggers 	/* Cumulative hint count exceeds omega */
158*ed894facSEric Biggers 	memcpy(sig, tv->sig, tv->sig_len);
159*ed894facSEric Biggers 	hintvec[omega + k - 1] = omega + 1;
160*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
161*ed894facSEric Biggers 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
162*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len));
163*ed894facSEric Biggers 
164*ed894facSEric Biggers 	/* Cumulative hint count decreases */
165*ed894facSEric Biggers 	memcpy(sig, tv->sig, tv->sig_len);
166*ed894facSEric Biggers 	KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
167*ed894facSEric Biggers 	hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
168*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
169*ed894facSEric Biggers 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
170*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len));
171*ed894facSEric Biggers 
172*ed894facSEric Biggers 	/*
173*ed894facSEric Biggers 	 * Hint indices out of order.  To test this, swap hintvec[0] and
174*ed894facSEric Biggers 	 * hintvec[1].  This assumes that the original valid signature had at
175*ed894facSEric Biggers 	 * least two nonzero hints in the first element (asserted below).
176*ed894facSEric Biggers 	 */
177*ed894facSEric Biggers 	memcpy(sig, tv->sig, tv->sig_len);
178*ed894facSEric Biggers 	KUNIT_ASSERT_GE(test, hintvec[omega], 2);
179*ed894facSEric Biggers 	h = hintvec[0];
180*ed894facSEric Biggers 	hintvec[0] = hintvec[1];
181*ed894facSEric Biggers 	hintvec[1] = h;
182*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
183*ed894facSEric Biggers 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
184*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len));
185*ed894facSEric Biggers 
186*ed894facSEric Biggers 	/*
187*ed894facSEric Biggers 	 * Extra hint indices given.  For this test to work, the original valid
188*ed894facSEric Biggers 	 * signature must have fewer than omega nonzero hints (asserted below).
189*ed894facSEric Biggers 	 */
190*ed894facSEric Biggers 	memcpy(sig, tv->sig, tv->sig_len);
191*ed894facSEric Biggers 	KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
192*ed894facSEric Biggers 	hintvec[omega - 1] = 0xff;
193*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
194*ed894facSEric Biggers 			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
195*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len));
196*ed894facSEric Biggers }
197*ed894facSEric Biggers 
198*ed894facSEric Biggers static void test_mldsa_mutation(struct kunit *test,
199*ed894facSEric Biggers 				const struct mldsa_testvector *tv)
200*ed894facSEric Biggers {
201*ed894facSEric Biggers 	const int sig_len = tv->sig_len;
202*ed894facSEric Biggers 	const int msg_len = tv->msg_len;
203*ed894facSEric Biggers 	const int pk_len = tv->pk_len;
204*ed894facSEric Biggers 	const int num_iter = 200;
205*ed894facSEric Biggers 	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
206*ed894facSEric Biggers 	u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
207*ed894facSEric Biggers 	u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
208*ed894facSEric Biggers 
209*ed894facSEric Biggers 	/* Initially the signature is valid. */
210*ed894facSEric Biggers 	do_mldsa_and_assert_success(test, tv);
211*ed894facSEric Biggers 
212*ed894facSEric Biggers 	/* Changing any bit in the signature should invalidate the signature */
213*ed894facSEric Biggers 	for (int i = 0; i < num_iter; i++) {
214*ed894facSEric Biggers 		size_t pos = get_random_u32_below(sig_len);
215*ed894facSEric Biggers 		u8 b = 1 << get_random_u32_below(8);
216*ed894facSEric Biggers 
217*ed894facSEric Biggers 		sig[pos] ^= b;
218*ed894facSEric Biggers 		KUNIT_ASSERT_NE(test, 0,
219*ed894facSEric Biggers 				mldsa_verify(tv->alg, sig, sig_len, msg,
220*ed894facSEric Biggers 					     msg_len, pk, pk_len));
221*ed894facSEric Biggers 		sig[pos] ^= b;
222*ed894facSEric Biggers 	}
223*ed894facSEric Biggers 
224*ed894facSEric Biggers 	/* Changing any bit in the message should invalidate the signature */
225*ed894facSEric Biggers 	for (int i = 0; i < num_iter; i++) {
226*ed894facSEric Biggers 		size_t pos = get_random_u32_below(msg_len);
227*ed894facSEric Biggers 		u8 b = 1 << get_random_u32_below(8);
228*ed894facSEric Biggers 
229*ed894facSEric Biggers 		msg[pos] ^= b;
230*ed894facSEric Biggers 		KUNIT_ASSERT_NE(test, 0,
231*ed894facSEric Biggers 				mldsa_verify(tv->alg, sig, sig_len, msg,
232*ed894facSEric Biggers 					     msg_len, pk, pk_len));
233*ed894facSEric Biggers 		msg[pos] ^= b;
234*ed894facSEric Biggers 	}
235*ed894facSEric Biggers 
236*ed894facSEric Biggers 	/* Changing any bit in the public key should invalidate the signature */
237*ed894facSEric Biggers 	for (int i = 0; i < num_iter; i++) {
238*ed894facSEric Biggers 		size_t pos = get_random_u32_below(pk_len);
239*ed894facSEric Biggers 		u8 b = 1 << get_random_u32_below(8);
240*ed894facSEric Biggers 
241*ed894facSEric Biggers 		pk[pos] ^= b;
242*ed894facSEric Biggers 		KUNIT_ASSERT_NE(test, 0,
243*ed894facSEric Biggers 				mldsa_verify(tv->alg, sig, sig_len, msg,
244*ed894facSEric Biggers 					     msg_len, pk, pk_len));
245*ed894facSEric Biggers 		pk[pos] ^= b;
246*ed894facSEric Biggers 	}
247*ed894facSEric Biggers 
248*ed894facSEric Biggers 	/* All changes should have been undone. */
249*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, 0,
250*ed894facSEric Biggers 			mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
251*ed894facSEric Biggers 				     pk_len));
252*ed894facSEric Biggers }
253*ed894facSEric Biggers 
254*ed894facSEric Biggers static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
255*ed894facSEric Biggers {
256*ed894facSEric Biggers 	/* Valid signature */
257*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
258*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
259*ed894facSEric Biggers 	do_mldsa_and_assert_success(test, tv);
260*ed894facSEric Biggers 
261*ed894facSEric Biggers 	/* Signature too short */
262*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
263*ed894facSEric Biggers 			mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
264*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len));
265*ed894facSEric Biggers 
266*ed894facSEric Biggers 	/* Signature too long */
267*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
268*ed894facSEric Biggers 			mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
269*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len));
270*ed894facSEric Biggers 
271*ed894facSEric Biggers 	/* Public key too short */
272*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
273*ed894facSEric Biggers 			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
274*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len - 1));
275*ed894facSEric Biggers 
276*ed894facSEric Biggers 	/* Public key too long */
277*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EBADMSG,
278*ed894facSEric Biggers 			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
279*ed894facSEric Biggers 				     tv->msg_len, tv->pk, tv->pk_len + 1));
280*ed894facSEric Biggers 
281*ed894facSEric Biggers 	/*
282*ed894facSEric Biggers 	 * Message too short.  Error is EKEYREJECTED because it gets rejected by
283*ed894facSEric Biggers 	 * the "real" signature check rather than the well-formedness checks.
284*ed894facSEric Biggers 	 */
285*ed894facSEric Biggers 	KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
286*ed894facSEric Biggers 			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
287*ed894facSEric Biggers 				     tv->msg_len - 1, tv->pk, tv->pk_len));
288*ed894facSEric Biggers 	/*
289*ed894facSEric Biggers 	 * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
290*ed894facSEric Biggers 	 * accessed out of bounds.  However, ML-DSA just hashes the message and
291*ed894facSEric Biggers 	 * doesn't handle different message lengths differently anyway.
292*ed894facSEric Biggers 	 */
293*ed894facSEric Biggers 
294*ed894facSEric Biggers 	/* Test the validity checks on the z vector. */
295*ed894facSEric Biggers 	test_mldsa_z_range(test, tv);
296*ed894facSEric Biggers 
297*ed894facSEric Biggers 	/* Test the validity checks on the hint vector. */
298*ed894facSEric Biggers 	test_mldsa_bad_hints(test, tv);
299*ed894facSEric Biggers 
300*ed894facSEric Biggers 	/* Test randomly mutating the inputs. */
301*ed894facSEric Biggers 	test_mldsa_mutation(test, tv);
302*ed894facSEric Biggers }
303*ed894facSEric Biggers 
304*ed894facSEric Biggers static void test_mldsa44(struct kunit *test)
305*ed894facSEric Biggers {
306*ed894facSEric Biggers 	test_mldsa(test, &mldsa44_testvector);
307*ed894facSEric Biggers }
308*ed894facSEric Biggers 
309*ed894facSEric Biggers static void test_mldsa65(struct kunit *test)
310*ed894facSEric Biggers {
311*ed894facSEric Biggers 	test_mldsa(test, &mldsa65_testvector);
312*ed894facSEric Biggers }
313*ed894facSEric Biggers 
314*ed894facSEric Biggers static void test_mldsa87(struct kunit *test)
315*ed894facSEric Biggers {
316*ed894facSEric Biggers 	test_mldsa(test, &mldsa87_testvector);
317*ed894facSEric Biggers }
318*ed894facSEric Biggers 
319*ed894facSEric Biggers static s32 mod(s32 a, s32 m)
320*ed894facSEric Biggers {
321*ed894facSEric Biggers 	a %= m;
322*ed894facSEric Biggers 	if (a < 0)
323*ed894facSEric Biggers 		a += m;
324*ed894facSEric Biggers 	return a;
325*ed894facSEric Biggers }
326*ed894facSEric Biggers 
327*ed894facSEric Biggers static s32 symmetric_mod(s32 a, s32 m)
328*ed894facSEric Biggers {
329*ed894facSEric Biggers 	a = mod(a, m);
330*ed894facSEric Biggers 	if (a > m / 2)
331*ed894facSEric Biggers 		a -= m;
332*ed894facSEric Biggers 	return a;
333*ed894facSEric Biggers }
334*ed894facSEric Biggers 
335*ed894facSEric Biggers /* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */
336*ed894facSEric Biggers static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
337*ed894facSEric Biggers {
338*ed894facSEric Biggers 	s32 rplus = mod(r, Q);
339*ed894facSEric Biggers 
340*ed894facSEric Biggers 	*r0 = symmetric_mod(rplus, 2 * gamma2);
341*ed894facSEric Biggers 	if (rplus - *r0 == Q - 1) {
342*ed894facSEric Biggers 		*r1 = 0;
343*ed894facSEric Biggers 		*r0 = *r0 - 1;
344*ed894facSEric Biggers 	} else {
345*ed894facSEric Biggers 		*r1 = (rplus - *r0) / (2 * gamma2);
346*ed894facSEric Biggers 	}
347*ed894facSEric Biggers }
348*ed894facSEric Biggers 
349*ed894facSEric Biggers /* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */
350*ed894facSEric Biggers static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
351*ed894facSEric Biggers {
352*ed894facSEric Biggers 	s32 m = (Q - 1) / (2 * gamma2);
353*ed894facSEric Biggers 	s32 r0, r1;
354*ed894facSEric Biggers 
355*ed894facSEric Biggers 	decompose_ref(r, gamma2, &r0, &r1);
356*ed894facSEric Biggers 	if (h == 1 && r0 > 0)
357*ed894facSEric Biggers 		return mod(r1 + 1, m);
358*ed894facSEric Biggers 	if (h == 1 && r0 <= 0)
359*ed894facSEric Biggers 		return mod(r1 - 1, m);
360*ed894facSEric Biggers 	return r1;
361*ed894facSEric Biggers }
362*ed894facSEric Biggers 
363*ed894facSEric Biggers /*
364*ed894facSEric Biggers  * Test that for all possible inputs, mldsa_use_hint() gives the same output as
365*ed894facSEric Biggers  * a mechanical translation of the pseudocode from FIPS 204.
366*ed894facSEric Biggers  */
367*ed894facSEric Biggers static void test_mldsa_use_hint(struct kunit *test)
368*ed894facSEric Biggers {
369*ed894facSEric Biggers 	for (int i = 0; i < 2; i++) {
370*ed894facSEric Biggers 		const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);
371*ed894facSEric Biggers 
372*ed894facSEric Biggers 		for (u8 h = 0; h < 2; h++) {
373*ed894facSEric Biggers 			for (s32 r = 0; r < Q; r++) {
374*ed894facSEric Biggers 				KUNIT_ASSERT_EQ(test,
375*ed894facSEric Biggers 						mldsa_use_hint(h, r, gamma2),
376*ed894facSEric Biggers 						use_hint_ref(h, r, gamma2));
377*ed894facSEric Biggers 			}
378*ed894facSEric Biggers 		}
379*ed894facSEric Biggers 	}
380*ed894facSEric Biggers }
381*ed894facSEric Biggers 
382*ed894facSEric Biggers static void benchmark_mldsa(struct kunit *test,
383*ed894facSEric Biggers 			    const struct mldsa_testvector *tv)
384*ed894facSEric Biggers {
385*ed894facSEric Biggers 	const int warmup_niter = 200;
386*ed894facSEric Biggers 	const int benchmark_niter = 200;
387*ed894facSEric Biggers 	u64 t0, t1;
388*ed894facSEric Biggers 
389*ed894facSEric Biggers 	if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
390*ed894facSEric Biggers 		kunit_skip(test, "not enabled");
391*ed894facSEric Biggers 
392*ed894facSEric Biggers 	for (int i = 0; i < warmup_niter; i++)
393*ed894facSEric Biggers 		do_mldsa_and_assert_success(test, tv);
394*ed894facSEric Biggers 
395*ed894facSEric Biggers 	t0 = ktime_get_ns();
396*ed894facSEric Biggers 	for (int i = 0; i < benchmark_niter; i++)
397*ed894facSEric Biggers 		do_mldsa_and_assert_success(test, tv);
398*ed894facSEric Biggers 	t1 = ktime_get_ns();
399*ed894facSEric Biggers 	kunit_info(test, "%llu ops/s",
400*ed894facSEric Biggers 		   div64_u64((u64)benchmark_niter * NSEC_PER_SEC,
401*ed894facSEric Biggers 			     t1 - t0 ?: 1));
402*ed894facSEric Biggers }
403*ed894facSEric Biggers 
404*ed894facSEric Biggers static void benchmark_mldsa44(struct kunit *test)
405*ed894facSEric Biggers {
406*ed894facSEric Biggers 	benchmark_mldsa(test, &mldsa44_testvector);
407*ed894facSEric Biggers }
408*ed894facSEric Biggers 
409*ed894facSEric Biggers static void benchmark_mldsa65(struct kunit *test)
410*ed894facSEric Biggers {
411*ed894facSEric Biggers 	benchmark_mldsa(test, &mldsa65_testvector);
412*ed894facSEric Biggers }
413*ed894facSEric Biggers 
414*ed894facSEric Biggers static void benchmark_mldsa87(struct kunit *test)
415*ed894facSEric Biggers {
416*ed894facSEric Biggers 	benchmark_mldsa(test, &mldsa87_testvector);
417*ed894facSEric Biggers }
418*ed894facSEric Biggers 
419*ed894facSEric Biggers static struct kunit_case mldsa_kunit_cases[] = {
420*ed894facSEric Biggers 	KUNIT_CASE(test_mldsa44),
421*ed894facSEric Biggers 	KUNIT_CASE(test_mldsa65),
422*ed894facSEric Biggers 	KUNIT_CASE(test_mldsa87),
423*ed894facSEric Biggers 	KUNIT_CASE(test_mldsa_use_hint),
424*ed894facSEric Biggers 	KUNIT_CASE(benchmark_mldsa44),
425*ed894facSEric Biggers 	KUNIT_CASE(benchmark_mldsa65),
426*ed894facSEric Biggers 	KUNIT_CASE(benchmark_mldsa87),
427*ed894facSEric Biggers 	{},
428*ed894facSEric Biggers };
429*ed894facSEric Biggers 
430*ed894facSEric Biggers static struct kunit_suite mldsa_kunit_suite = {
431*ed894facSEric Biggers 	.name = "mldsa",
432*ed894facSEric Biggers 	.test_cases = mldsa_kunit_cases,
433*ed894facSEric Biggers };
434*ed894facSEric Biggers kunit_test_suite(mldsa_kunit_suite);
435*ed894facSEric Biggers 
436*ed894facSEric Biggers MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
437*ed894facSEric Biggers MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
438*ed894facSEric Biggers MODULE_LICENSE("GPL");
439