xref: /linux/lib/tests/base64_kunit.c (revision 7b8e9264f55a9c320f398e337d215e68cca50131)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * base64_kunit.c - KUnit tests for base64 encoding and decoding functions
4  *
5  * Copyright (c) 2025, Guan-Chun Wu <409411716@gms.tku.edu.tw>
6  */
7 
8 #include <kunit/test.h>
9 #include <linux/base64.h>
10 
11 /* ---------- Benchmark helpers ---------- */
12 static u64 bench_encode_ns(const u8 *data, int len, char *dst, int reps,
13 			   enum base64_variant variant)
14 {
15 	u64 t0, t1;
16 
17 	t0 = ktime_get_ns();
18 	for (int i = 0; i < reps; i++)
19 		base64_encode(data, len, dst, true, variant);
20 	t1 = ktime_get_ns();
21 
22 	return div64_u64(t1 - t0, (u64)reps);
23 }
24 
25 static u64 bench_decode_ns(const char *data, int len, u8 *dst, int reps,
26 			   enum base64_variant variant)
27 {
28 	u64 t0, t1;
29 
30 	t0 = ktime_get_ns();
31 	for (int i = 0; i < reps; i++)
32 		base64_decode(data, len, dst, true, variant);
33 	t1 = ktime_get_ns();
34 
35 	return div64_u64(t1 - t0, (u64)reps);
36 }
37 
38 static void run_perf_and_check(struct kunit *test, const char *label, int size,
39 			       enum base64_variant variant)
40 {
41 	const int reps = 1000;
42 	size_t outlen = DIV_ROUND_UP(size, 3) * 4;
43 	u8 *in = kmalloc(size, GFP_KERNEL);
44 	char *enc = kmalloc(outlen, GFP_KERNEL);
45 	u8 *decoded = kmalloc(size, GFP_KERNEL);
46 
47 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, in);
48 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, enc);
49 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, decoded);
50 
51 	get_random_bytes(in, size);
52 	int enc_len = base64_encode(in, size, enc, true, variant);
53 	int dec_len = base64_decode(enc, enc_len, decoded, true, variant);
54 
55 	/* correctness sanity check */
56 	KUNIT_EXPECT_EQ(test, dec_len, size);
57 	KUNIT_EXPECT_MEMEQ(test, decoded, in, size);
58 
59 	/* benchmark encode */
60 
61 	u64 t1 = bench_encode_ns(in, size, enc, reps, variant);
62 
63 	kunit_info(test, "[%s] encode run : %lluns", label, t1);
64 
65 	u64 t2 = bench_decode_ns(enc, enc_len, decoded, reps, variant);
66 
67 	kunit_info(test, "[%s] decode run : %lluns", label, t2);
68 
69 	kfree(in);
70 	kfree(enc);
71 	kfree(decoded);
72 }
73 
74 static void base64_performance_tests(struct kunit *test)
75 {
76 	/* run on STD variant only */
77 	run_perf_and_check(test, "64B", 64, BASE64_STD);
78 	run_perf_and_check(test, "1KB", 1024, BASE64_STD);
79 }
80 
81 /* ---------- Helpers for encode ---------- */
82 static void expect_encode_ok(struct kunit *test, const u8 *src, int srclen,
83 			     const char *expected, bool padding,
84 			     enum base64_variant variant)
85 {
86 	char buf[128];
87 	int encoded_len = base64_encode(src, srclen, buf, padding, variant);
88 
89 	buf[encoded_len] = '\0';
90 
91 	KUNIT_EXPECT_EQ(test, encoded_len, strlen(expected));
92 	KUNIT_EXPECT_STREQ(test, buf, expected);
93 }
94 
95 /* ---------- Helpers for decode ---------- */
96 static void expect_decode_ok(struct kunit *test, const char *src,
97 			     const u8 *expected, int expected_len, bool padding,
98 			     enum base64_variant variant)
99 {
100 	u8 buf[128];
101 	int decoded_len = base64_decode(src, strlen(src), buf, padding, variant);
102 
103 	KUNIT_EXPECT_EQ(test, decoded_len, expected_len);
104 	KUNIT_EXPECT_MEMEQ(test, buf, expected, expected_len);
105 }
106 
107 static void expect_decode_err(struct kunit *test, const char *src,
108 			      int srclen, bool padding,
109 			      enum base64_variant variant)
110 {
111 	u8 buf[64];
112 	int decoded_len = base64_decode(src, srclen, buf, padding, variant);
113 
114 	KUNIT_EXPECT_EQ(test, decoded_len, -1);
115 }
116 
117 /* ---------- Encode Tests ---------- */
118 static void base64_std_encode_tests(struct kunit *test)
119 {
120 	/* With padding */
121 	expect_encode_ok(test, (const u8 *)"", 0, "", true, BASE64_STD);
122 	expect_encode_ok(test, (const u8 *)"f", 1, "Zg==", true, BASE64_STD);
123 	expect_encode_ok(test, (const u8 *)"fo", 2, "Zm8=", true, BASE64_STD);
124 	expect_encode_ok(test, (const u8 *)"foo", 3, "Zm9v", true, BASE64_STD);
125 	expect_encode_ok(test, (const u8 *)"foob", 4, "Zm9vYg==", true, BASE64_STD);
126 	expect_encode_ok(test, (const u8 *)"fooba", 5, "Zm9vYmE=", true, BASE64_STD);
127 	expect_encode_ok(test, (const u8 *)"foobar", 6, "Zm9vYmFy", true, BASE64_STD);
128 
129 	/* Extra cases with padding */
130 	expect_encode_ok(test, (const u8 *)"Hello, world!", 13, "SGVsbG8sIHdvcmxkIQ==",
131 			 true, BASE64_STD);
132 	expect_encode_ok(test, (const u8 *)"ABCDEFGHIJKLMNOPQRSTUVWXYZ", 26,
133 			 "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVo=", true, BASE64_STD);
134 	expect_encode_ok(test, (const u8 *)"abcdefghijklmnopqrstuvwxyz", 26,
135 			 "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo=", true, BASE64_STD);
136 	expect_encode_ok(test, (const u8 *)"0123456789+/", 12, "MDEyMzQ1Njc4OSsv",
137 			 true, BASE64_STD);
138 
139 	/* Without padding */
140 	expect_encode_ok(test, (const u8 *)"", 0, "", false, BASE64_STD);
141 	expect_encode_ok(test, (const u8 *)"f", 1, "Zg", false, BASE64_STD);
142 	expect_encode_ok(test, (const u8 *)"fo", 2, "Zm8", false, BASE64_STD);
143 	expect_encode_ok(test, (const u8 *)"foo", 3, "Zm9v", false, BASE64_STD);
144 	expect_encode_ok(test, (const u8 *)"foob", 4, "Zm9vYg", false, BASE64_STD);
145 	expect_encode_ok(test, (const u8 *)"fooba", 5, "Zm9vYmE", false, BASE64_STD);
146 	expect_encode_ok(test, (const u8 *)"foobar", 6, "Zm9vYmFy", false, BASE64_STD);
147 
148 	/* Extra cases without padding */
149 	expect_encode_ok(test, (const u8 *)"Hello, world!", 13, "SGVsbG8sIHdvcmxkIQ",
150 			 false, BASE64_STD);
151 	expect_encode_ok(test, (const u8 *)"ABCDEFGHIJKLMNOPQRSTUVWXYZ", 26,
152 			 "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVo", false, BASE64_STD);
153 	expect_encode_ok(test, (const u8 *)"abcdefghijklmnopqrstuvwxyz", 26,
154 			 "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo", false, BASE64_STD);
155 	expect_encode_ok(test, (const u8 *)"0123456789+/", 12, "MDEyMzQ1Njc4OSsv",
156 			 false, BASE64_STD);
157 }
158 
159 /* ---------- Decode Tests ---------- */
160 static void base64_std_decode_tests(struct kunit *test)
161 {
162 	/* -------- With padding --------*/
163 	expect_decode_ok(test, "", (const u8 *)"", 0, true, BASE64_STD);
164 	expect_decode_ok(test, "Zg==", (const u8 *)"f", 1, true, BASE64_STD);
165 	expect_decode_ok(test, "Zm8=", (const u8 *)"fo", 2, true, BASE64_STD);
166 	expect_decode_ok(test, "Zm9v", (const u8 *)"foo", 3, true, BASE64_STD);
167 	expect_decode_ok(test, "Zm9vYg==", (const u8 *)"foob", 4, true, BASE64_STD);
168 	expect_decode_ok(test, "Zm9vYmE=", (const u8 *)"fooba", 5, true, BASE64_STD);
169 	expect_decode_ok(test, "Zm9vYmFy", (const u8 *)"foobar", 6, true, BASE64_STD);
170 	expect_decode_ok(test, "SGVsbG8sIHdvcmxkIQ==", (const u8 *)"Hello, world!", 13,
171 			 true, BASE64_STD);
172 	expect_decode_ok(test, "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVo=",
173 			 (const u8 *)"ABCDEFGHIJKLMNOPQRSTUVWXYZ", 26, true, BASE64_STD);
174 	expect_decode_ok(test, "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo=",
175 			 (const u8 *)"abcdefghijklmnopqrstuvwxyz", 26, true, BASE64_STD);
176 
177 	/* Error cases */
178 	expect_decode_err(test, "Zg=!", 4, true, BASE64_STD);
179 	expect_decode_err(test, "Zm$=", 4, true, BASE64_STD);
180 	expect_decode_err(test, "Z===", 4, true, BASE64_STD);
181 	expect_decode_err(test, "Zg", 2, true, BASE64_STD);
182 	expect_decode_err(test, "Zm9v====", 8, true, BASE64_STD);
183 	expect_decode_err(test, "Zm==A", 5, true, BASE64_STD);
184 
185 	{
186 		char with_nul[4] = { 'Z', 'g', '\0', '=' };
187 
188 		expect_decode_err(test, with_nul, 4, true, BASE64_STD);
189 	}
190 
191 	/* -------- Without padding --------*/
192 	expect_decode_ok(test, "", (const u8 *)"", 0, false, BASE64_STD);
193 	expect_decode_ok(test, "Zg", (const u8 *)"f", 1, false, BASE64_STD);
194 	expect_decode_ok(test, "Zm8", (const u8 *)"fo", 2, false, BASE64_STD);
195 	expect_decode_ok(test, "Zm9v", (const u8 *)"foo", 3, false, BASE64_STD);
196 	expect_decode_ok(test, "Zm9vYg", (const u8 *)"foob", 4, false, BASE64_STD);
197 	expect_decode_ok(test, "Zm9vYmE", (const u8 *)"fooba", 5, false, BASE64_STD);
198 	expect_decode_ok(test, "Zm9vYmFy", (const u8 *)"foobar", 6, false, BASE64_STD);
199 	expect_decode_ok(test, "TWFu", (const u8 *)"Man", 3, false, BASE64_STD);
200 	expect_decode_ok(test, "SGVsbG8sIHdvcmxkIQ", (const u8 *)"Hello, world!", 13,
201 			 false, BASE64_STD);
202 	expect_decode_ok(test, "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVo",
203 			 (const u8 *)"ABCDEFGHIJKLMNOPQRSTUVWXYZ", 26, false, BASE64_STD);
204 	expect_decode_ok(test, "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo",
205 			 (const u8 *)"abcdefghijklmnopqrstuvwxyz", 26, false, BASE64_STD);
206 	expect_decode_ok(test, "MDEyMzQ1Njc4OSsv", (const u8 *)"0123456789+/", 12,
207 			 false, BASE64_STD);
208 
209 	/* Error cases */
210 	expect_decode_err(test, "Zg=!", 4, false, BASE64_STD);
211 	expect_decode_err(test, "Zm$=", 4, false, BASE64_STD);
212 	expect_decode_err(test, "Z===", 4, false, BASE64_STD);
213 	expect_decode_err(test, "Zg=", 3, false, BASE64_STD);
214 	expect_decode_err(test, "Zm9v====", 8, false, BASE64_STD);
215 	expect_decode_err(test, "Zm==v", 4, false, BASE64_STD);
216 
217 	{
218 		char with_nul[4] = { 'Z', 'g', '\0', '=' };
219 
220 		expect_decode_err(test, with_nul, 4, false, BASE64_STD);
221 	}
222 }
223 
224 /* ---------- Variant tests (URLSAFE / IMAP) ---------- */
225 static void base64_variant_tests(struct kunit *test)
226 {
227 	const u8 sample1[] = { 0x00, 0xfb, 0xff, 0x7f, 0x80 };
228 	char std_buf[128], url_buf[128], imap_buf[128];
229 	u8 back[128];
230 	int n_std, n_url, n_imap, m;
231 	int i;
232 
233 	n_std = base64_encode(sample1, sizeof(sample1), std_buf, false, BASE64_STD);
234 	n_url = base64_encode(sample1, sizeof(sample1), url_buf, false, BASE64_URLSAFE);
235 	std_buf[n_std] = '\0';
236 	url_buf[n_url] = '\0';
237 
238 	for (i = 0; i < n_std; i++) {
239 		if (std_buf[i] == '+')
240 			std_buf[i] = '-';
241 		else if (std_buf[i] == '/')
242 			std_buf[i] = '_';
243 	}
244 	KUNIT_EXPECT_STREQ(test, std_buf, url_buf);
245 
246 	m = base64_decode(url_buf, n_url, back, false, BASE64_URLSAFE);
247 	KUNIT_EXPECT_EQ(test, m, (int)sizeof(sample1));
248 	KUNIT_EXPECT_MEMEQ(test, back, sample1, sizeof(sample1));
249 
250 	n_std  = base64_encode(sample1, sizeof(sample1), std_buf, false, BASE64_STD);
251 	n_imap = base64_encode(sample1, sizeof(sample1), imap_buf, false, BASE64_IMAP);
252 	std_buf[n_std]   = '\0';
253 	imap_buf[n_imap] = '\0';
254 
255 	for (i = 0; i < n_std; i++)
256 		if (std_buf[i] == '/')
257 			std_buf[i] = ',';
258 	KUNIT_EXPECT_STREQ(test, std_buf, imap_buf);
259 
260 	m = base64_decode(imap_buf, n_imap, back, false, BASE64_IMAP);
261 	KUNIT_EXPECT_EQ(test, m, (int)sizeof(sample1));
262 	KUNIT_EXPECT_MEMEQ(test, back, sample1, sizeof(sample1));
263 
264 	{
265 		const char *bad = "Zg==";
266 		u8 tmp[8];
267 
268 		m = base64_decode(bad, strlen(bad), tmp, false, BASE64_URLSAFE);
269 		KUNIT_EXPECT_EQ(test, m, -1);
270 
271 		m = base64_decode(bad, strlen(bad), tmp, false, BASE64_IMAP);
272 		KUNIT_EXPECT_EQ(test, m, -1);
273 	}
274 }
275 
276 /* ---------- Test registration ---------- */
277 static struct kunit_case base64_test_cases[] = {
278 	KUNIT_CASE(base64_performance_tests),
279 	KUNIT_CASE(base64_std_encode_tests),
280 	KUNIT_CASE(base64_std_decode_tests),
281 	KUNIT_CASE(base64_variant_tests),
282 	{}
283 };
284 
285 static struct kunit_suite base64_test_suite = {
286 	.name = "base64",
287 	.test_cases = base64_test_cases,
288 };
289 
290 kunit_test_suite(base64_test_suite);
291 
292 MODULE_AUTHOR("Guan-Chun Wu <409411716@gms.tku.edu.tw>");
293 MODULE_DESCRIPTION("KUnit tests for Base64 encoding/decoding, including performance checks");
294 MODULE_LICENSE("GPL");
295