xref: /linux/drivers/iommu/generic_pt/kunit_generic_pt.h (revision fbf5df34a4dbcd09d433dd4f0916bf9b2ddb16de)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  *
5  * Test the format API directly.
6  *
7  */
8 #include "kunit_iommu.h"
9 #include "pt_iter.h"
10 
11 static void do_map(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa,
12 		   pt_vaddr_t len)
13 {
14 	struct kunit_iommu_priv *priv = test->priv;
15 	int ret;
16 
17 	KUNIT_ASSERT_EQ(test, len, (size_t)len);
18 
19 	ret = iommu_map(&priv->domain, va, pa, len, IOMMU_READ | IOMMU_WRITE,
20 			GFP_KERNEL);
21 	KUNIT_ASSERT_NO_ERRNO_FN(test, "map_pages", ret);
22 }
23 
/*
 * Re-read the entry pts currently indexes with pt_load_entry() and assert
 * that it decodes to the given PT_ENTRY_* type. Note that pts is evaluated
 * more than once.
 */
#define KUNIT_ASSERT_PT_LOAD(test, pts, entry)             \
	({                                                 \
		pt_load_entry(pts);                        \
		KUNIT_ASSERT_EQ(test, (pts)->type, entry); \
	})
29 
/* Argument threaded through pt_walk_range() into __check_all_levels() */
struct check_levels_arg {
	/* Test being run, used for assertions at every level */
	struct kunit *test;
	/* Opaque argument forwarded unchanged to fn */
	void *fn_arg;
	/* Per-level callback supplied to check_all_levels() */
	void (*fn)(struct kunit *test, struct pt_state *pts, void *arg);
};
35 
36 static int __check_all_levels(struct pt_range *range, void *arg,
37 			      unsigned int level, struct pt_table_p *table)
38 {
39 	struct pt_state pts = pt_init(range, level, table);
40 	struct check_levels_arg *chk = arg;
41 	struct kunit *test = chk->test;
42 	int ret;
43 
44 	_pt_iter_first(&pts);
45 
46 
47 	/*
48 	 * If we were able to use the full VA space this should always be the
49 	 * last index in each table.
50 	 */
51 	if (!(IS_32BIT && range->max_vasz_lg2 > 32)) {
52 		if (pt_feature(range->common, PT_FEAT_SIGN_EXTEND) &&
53 		    pts.level == pts.range->top_level)
54 			KUNIT_ASSERT_EQ(test, pts.index,
55 					log2_to_int(range->max_vasz_lg2 - 1 -
56 						    pt_table_item_lg2sz(&pts)) -
57 						1);
58 		else
59 			KUNIT_ASSERT_EQ(test, pts.index,
60 					log2_to_int(pt_table_oa_lg2sz(&pts) -
61 						    pt_table_item_lg2sz(&pts)) -
62 						1);
63 	}
64 
65 	if (pt_can_have_table(&pts)) {
66 		pt_load_single_entry(&pts);
67 		KUNIT_ASSERT_EQ(test, pts.type, PT_ENTRY_TABLE);
68 		ret = pt_descend(&pts, arg, __check_all_levels);
69 		KUNIT_ASSERT_EQ(test, ret, 0);
70 
71 		/* Index 0 is used by the test */
72 		if (IS_32BIT && !pts.index)
73 			return 0;
74 		KUNIT_ASSERT_NE(chk->test, pts.index, 0);
75 	}
76 
77 	/*
78 	 * A format should not create a table with only one entry, at least this
79 	 * test approach won't work.
80 	 */
81 	KUNIT_ASSERT_GT(chk->test, pts.end_index, 1);
82 
83 	/*
84 	 * For increase top we end up using index 0 for the original top's tree,
85 	 * so use index 1 for testing instead.
86 	 */
87 	pts.index = 0;
88 	pt_index_to_va(&pts);
89 	pt_load_single_entry(&pts);
90 	if (pts.type == PT_ENTRY_TABLE && pts.end_index > 2) {
91 		pts.index = 1;
92 		pt_index_to_va(&pts);
93 	}
94 	(*chk->fn)(chk->test, &pts, chk->fn_arg);
95 	return 0;
96 }
97 
98 /*
99  * Call fn for each level in the table with a pts setup to index 0 in a table
100  * for that level. This allows writing tests that run on every level.
101  * The test can use every index in the table except the last one.
102  */
103 static void check_all_levels(struct kunit *test,
104 			     void (*fn)(struct kunit *test,
105 					struct pt_state *pts, void *arg),
106 			     void *fn_arg)
107 {
108 	struct kunit_iommu_priv *priv = test->priv;
109 	struct pt_range range = pt_top_range(priv->common);
110 	struct check_levels_arg chk = {
111 		.test = test,
112 		.fn = fn,
113 		.fn_arg = fn_arg,
114 	};
115 	int ret;
116 
117 	if (pt_feature(priv->common, PT_FEAT_DYNAMIC_TOP) &&
118 	    priv->common->max_vasz_lg2 > range.max_vasz_lg2)
119 		range.last_va = fvalog2_set_mod_max(range.va,
120 						    priv->common->max_vasz_lg2);
121 
122 	/*
123 	 * Map a page at the highest VA, this will populate all the levels so we
124 	 * can then iterate over them. Index 0 will be used for testing.
125 	 */
126 	if (IS_32BIT && range.max_vasz_lg2 > 32)
127 		range.last_va = (u32)range.last_va;
128 	range.va = range.last_va - (priv->smallest_pgsz - 1);
129 	do_map(test, range.va, 0, priv->smallest_pgsz);
130 
131 	range = pt_make_range(priv->common, range.va, range.last_va);
132 	ret = pt_walk_range(&range, __check_all_levels, &chk);
133 	KUNIT_ASSERT_EQ(test, ret, 0);
134 }
135 
136 static void test_init(struct kunit *test)
137 {
138 	struct kunit_iommu_priv *priv = test->priv;
139 
140 	/* Fixture does the setup */
141 	KUNIT_ASSERT_NE(test, priv->info.pgsize_bitmap, 0);
142 }
143 
144 /*
145  * Basic check that the log2_* functions are working, especially at the integer
146  * limits.
147  */
148 static void test_bitops(struct kunit *test)
149 {
150 	int i;
151 
152 	KUNIT_ASSERT_EQ(test, fls_t(u32, 0), 0);
153 	KUNIT_ASSERT_EQ(test, fls_t(u32, 1), 1);
154 	KUNIT_ASSERT_EQ(test, fls_t(u32, BIT(2)), 3);
155 	KUNIT_ASSERT_EQ(test, fls_t(u32, U32_MAX), 32);
156 
157 	KUNIT_ASSERT_EQ(test, fls_t(u64, 0), 0);
158 	KUNIT_ASSERT_EQ(test, fls_t(u64, 1), 1);
159 	KUNIT_ASSERT_EQ(test, fls_t(u64, BIT(2)), 3);
160 	KUNIT_ASSERT_EQ(test, fls_t(u64, U64_MAX), 64);
161 
162 	KUNIT_ASSERT_EQ(test, ffs_t(u32, 1), 0);
163 	KUNIT_ASSERT_EQ(test, ffs_t(u32, BIT(2)), 2);
164 	KUNIT_ASSERT_EQ(test, ffs_t(u32, BIT(31)), 31);
165 
166 	KUNIT_ASSERT_EQ(test, ffs_t(u64, 1), 0);
167 	KUNIT_ASSERT_EQ(test, ffs_t(u64, BIT(2)), 2);
168 	KUNIT_ASSERT_EQ(test, ffs_t(u64, BIT_ULL(63)), 63);
169 
170 	for (i = 0; i != 31; i++)
171 		KUNIT_ASSERT_EQ(test, ffz_t(u64, BIT_ULL(i) - 1), i);
172 
173 	for (i = 0; i != 63; i++)
174 		KUNIT_ASSERT_EQ(test, ffz_t(u64, BIT_ULL(i) - 1), i);
175 
176 	for (i = 0; i != 32; i++) {
177 		u64 val = get_random_u64();
178 
179 		KUNIT_ASSERT_EQ(test, log2_mod_t(u32, val, ffs_t(u32, val)), 0);
180 		KUNIT_ASSERT_EQ(test, log2_mod_t(u64, val, ffs_t(u64, val)), 0);
181 
182 		KUNIT_ASSERT_EQ(test, log2_mod_t(u32, val, ffz_t(u32, val)),
183 				log2_to_max_int_t(u32, ffz_t(u32, val)));
184 		KUNIT_ASSERT_EQ(test, log2_mod_t(u64, val, ffz_t(u64, val)),
185 				log2_to_max_int_t(u64, ffz_t(u64, val)));
186 	}
187 }
188 
189 static unsigned int ref_best_pgsize(pt_vaddr_t pgsz_bitmap, pt_vaddr_t va,
190 				    pt_vaddr_t last_va, pt_oaddr_t oa)
191 {
192 	pt_vaddr_t pgsz_lg2;
193 
194 	/* Brute force the constraints described in pt_compute_best_pgsize() */
195 	for (pgsz_lg2 = PT_VADDR_MAX_LG2 - 1; pgsz_lg2 != 0; pgsz_lg2--) {
196 		if ((pgsz_bitmap & log2_to_int(pgsz_lg2)) &&
197 		    log2_mod(va, pgsz_lg2) == 0 &&
198 		    oalog2_mod(oa, pgsz_lg2) == 0 &&
199 		    va + log2_to_int(pgsz_lg2) - 1 <= last_va &&
200 		    log2_div_eq(va, va + log2_to_int(pgsz_lg2) - 1, pgsz_lg2) &&
201 		    oalog2_div_eq(oa, oa + log2_to_int(pgsz_lg2) - 1, pgsz_lg2))
202 			return pgsz_lg2;
203 	}
204 	return 0;
205 }
206 
207 /* Check that the bit logic in pt_compute_best_pgsize() works. */
208 static void test_best_pgsize(struct kunit *test)
209 {
210 	unsigned int a_lg2;
211 	unsigned int b_lg2;
212 	unsigned int c_lg2;
213 
214 	/* Try random prefixes with every suffix combination */
215 	for (a_lg2 = 1; a_lg2 != 10; a_lg2++) {
216 		for (b_lg2 = 1; b_lg2 != 10; b_lg2++) {
217 			for (c_lg2 = 1; c_lg2 != 10; c_lg2++) {
218 				pt_vaddr_t pgsz_bitmap = get_random_u64();
219 				pt_vaddr_t va = get_random_u64() << a_lg2;
220 				pt_oaddr_t oa = get_random_u64() << b_lg2;
221 				pt_vaddr_t last_va = log2_set_mod_max(
222 					get_random_u64(), c_lg2);
223 
224 				if (va > last_va)
225 					swap(va, last_va);
226 				KUNIT_ASSERT_EQ(
227 					test,
228 					pt_compute_best_pgsize(pgsz_bitmap, va,
229 							       last_va, oa),
230 					ref_best_pgsize(pgsz_bitmap, va,
231 							last_va, oa));
232 			}
233 		}
234 	}
235 
236 	/* 0 prefix, every suffix */
237 	for (c_lg2 = 1; c_lg2 != PT_VADDR_MAX_LG2 - 1; c_lg2++) {
238 		pt_vaddr_t pgsz_bitmap = get_random_u64();
239 		pt_vaddr_t va = 0;
240 		pt_oaddr_t oa = 0;
241 		pt_vaddr_t last_va = log2_set_mod_max(0, c_lg2);
242 
243 		KUNIT_ASSERT_EQ(test,
244 				pt_compute_best_pgsize(pgsz_bitmap, va, last_va,
245 						       oa),
246 				ref_best_pgsize(pgsz_bitmap, va, last_va, oa));
247 	}
248 
249 	/* 1's prefix, every suffix */
250 	for (a_lg2 = 1; a_lg2 != 10; a_lg2++) {
251 		for (b_lg2 = 1; b_lg2 != 10; b_lg2++) {
252 			for (c_lg2 = 1; c_lg2 != 10; c_lg2++) {
253 				pt_vaddr_t pgsz_bitmap = get_random_u64();
254 				pt_vaddr_t va = PT_VADDR_MAX << a_lg2;
255 				pt_oaddr_t oa = PT_VADDR_MAX << b_lg2;
256 				pt_vaddr_t last_va = PT_VADDR_MAX;
257 
258 				KUNIT_ASSERT_EQ(
259 					test,
260 					pt_compute_best_pgsize(pgsz_bitmap, va,
261 							       last_va, oa),
262 					ref_best_pgsize(pgsz_bitmap, va,
263 							last_va, oa));
264 			}
265 		}
266 	}
267 
268 	/* pgsize_bitmap is always 0 */
269 	for (a_lg2 = 1; a_lg2 != 10; a_lg2++) {
270 		for (b_lg2 = 1; b_lg2 != 10; b_lg2++) {
271 			for (c_lg2 = 1; c_lg2 != 10; c_lg2++) {
272 				pt_vaddr_t pgsz_bitmap = 0;
273 				pt_vaddr_t va = get_random_u64() << a_lg2;
274 				pt_oaddr_t oa = get_random_u64() << b_lg2;
275 				pt_vaddr_t last_va = log2_set_mod_max(
276 					get_random_u64(), c_lg2);
277 
278 				if (va > last_va)
279 					swap(va, last_va);
280 				KUNIT_ASSERT_EQ(
281 					test,
282 					pt_compute_best_pgsize(pgsz_bitmap, va,
283 							       last_va, oa),
284 					0);
285 			}
286 		}
287 	}
288 
289 	if (sizeof(pt_vaddr_t) <= 4)
290 		return;
291 
292 	/* over 32 bit page sizes */
293 	for (a_lg2 = 32; a_lg2 != 42; a_lg2++) {
294 		for (b_lg2 = 32; b_lg2 != 42; b_lg2++) {
295 			for (c_lg2 = 32; c_lg2 != 42; c_lg2++) {
296 				pt_vaddr_t pgsz_bitmap = get_random_u64();
297 				pt_vaddr_t va = get_random_u64() << a_lg2;
298 				pt_oaddr_t oa = get_random_u64() << b_lg2;
299 				pt_vaddr_t last_va = log2_set_mod_max(
300 					get_random_u64(), c_lg2);
301 
302 				if (va > last_va)
303 					swap(va, last_va);
304 				KUNIT_ASSERT_EQ(
305 					test,
306 					pt_compute_best_pgsize(pgsz_bitmap, va,
307 							       last_va, oa),
308 					ref_best_pgsize(pgsz_bitmap, va,
309 							last_va, oa));
310 			}
311 		}
312 	}
313 }
314 
315 static void test_pgsz_count(struct kunit *test)
316 {
317 	KUNIT_EXPECT_EQ(test,
318 			pt_pgsz_count(SZ_4K, 0, SZ_1G - 1, 0, ilog2(SZ_4K)),
319 			SZ_1G / SZ_4K);
320 	KUNIT_EXPECT_EQ(test,
321 			pt_pgsz_count(SZ_2M | SZ_4K, SZ_4K, SZ_1G - 1, SZ_4K,
322 				      ilog2(SZ_4K)),
323 			(SZ_2M - SZ_4K) / SZ_4K);
324 }
325 
326 /*
327  * Check that pt_install_table() and pt_table_pa() match
328  */
329 static void test_lvl_table_ptr(struct kunit *test, struct pt_state *pts,
330 			       void *arg)
331 {
332 	struct kunit_iommu_priv *priv = test->priv;
333 	pt_oaddr_t paddr =
334 		log2_set_mod(priv->test_oa, 0, priv->smallest_pgsz_lg2);
335 	struct pt_write_attrs attrs = {};
336 
337 	if (!pt_can_have_table(pts))
338 		return;
339 
340 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
341 				 pt_iommu_set_prot(pts->range->common, &attrs,
342 						   IOMMU_READ));
343 
344 	pt_load_single_entry(pts);
345 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
346 
347 	KUNIT_ASSERT_TRUE(test, pt_install_table(pts, paddr, &attrs));
348 
349 	/* A second install should pass because install updates pts->entry. */
350 	KUNIT_ASSERT_EQ(test, pt_install_table(pts, paddr, &attrs), true);
351 
352 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_TABLE);
353 	KUNIT_ASSERT_EQ(test, pt_table_pa(pts), paddr);
354 
355 	pt_clear_entries(pts, ilog2(1));
356 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
357 }
358 
359 static void test_table_ptr(struct kunit *test)
360 {
361 	check_all_levels(test, test_lvl_table_ptr, NULL);
362 }
363 
364 struct lvl_radix_arg {
365 	pt_vaddr_t vbits;
366 };
367 
368 /*
369  * Check pt_table_oa_lg2sz() and pt_table_item_lg2sz() they need to decode a
370  * continuous list of VA across all the levels that covers the entire advertised
371  * VA space.
372  */
373 static void test_lvl_radix(struct kunit *test, struct pt_state *pts, void *arg)
374 {
375 	unsigned int table_lg2sz = pt_table_oa_lg2sz(pts);
376 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
377 	struct lvl_radix_arg *radix = arg;
378 
379 	/* Every bit below us is decoded */
380 	KUNIT_ASSERT_EQ(test, log2_set_mod_max(0, isz_lg2), radix->vbits);
381 
382 	/* We are not decoding bits someone else is */
383 	KUNIT_ASSERT_EQ(test, log2_div(radix->vbits, isz_lg2), 0);
384 
385 	/* Can't decode past the pt_vaddr_t size */
386 	KUNIT_ASSERT_LE(test, table_lg2sz, PT_VADDR_MAX_LG2);
387 	KUNIT_ASSERT_EQ(test, fvalog2_div(table_lg2sz, PT_MAX_VA_ADDRESS_LG2),
388 			0);
389 
390 	radix->vbits = fvalog2_set_mod_max(0, table_lg2sz);
391 }
392 
393 static void test_max_va(struct kunit *test)
394 {
395 	struct kunit_iommu_priv *priv = test->priv;
396 	struct pt_range range = pt_top_range(priv->common);
397 
398 	KUNIT_ASSERT_GE(test, priv->common->max_vasz_lg2, range.max_vasz_lg2);
399 }
400 
401 static void test_table_radix(struct kunit *test)
402 {
403 	struct kunit_iommu_priv *priv = test->priv;
404 	struct lvl_radix_arg radix = { .vbits = priv->smallest_pgsz - 1 };
405 	struct pt_range range;
406 
407 	check_all_levels(test, test_lvl_radix, &radix);
408 
409 	range = pt_top_range(priv->common);
410 	if (range.max_vasz_lg2 == PT_VADDR_MAX_LG2) {
411 		KUNIT_ASSERT_EQ(test, radix.vbits, PT_VADDR_MAX);
412 	} else {
413 		if (!IS_32BIT)
414 			KUNIT_ASSERT_EQ(test,
415 					log2_set_mod_max(0, range.max_vasz_lg2),
416 					radix.vbits);
417 		KUNIT_ASSERT_EQ(test, log2_div(radix.vbits, range.max_vasz_lg2),
418 				0);
419 	}
420 }
421 
422 static unsigned int safe_pt_num_items_lg2(const struct pt_state *pts)
423 {
424 	struct pt_range top_range = pt_top_range(pts->range->common);
425 	struct pt_state top_pts = pt_init_top(&top_range);
426 
427 	/*
428 	 * Avoid calling pt_num_items_lg2() on the top, instead we can derive
429 	 * the size of the top table from the top range.
430 	 */
431 	if (pts->level == top_range.top_level)
432 		return ilog2(pt_range_to_end_index(&top_pts));
433 	return pt_num_items_lg2(pts);
434 }
435 
436 static void test_lvl_possible_sizes(struct kunit *test, struct pt_state *pts,
437 				    void *arg)
438 {
439 	unsigned int num_items_lg2 = safe_pt_num_items_lg2(pts);
440 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
441 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
442 
443 	if (!pt_can_have_leaf(pts)) {
444 		KUNIT_ASSERT_EQ(test, pgsize_bitmap, 0);
445 		return;
446 	}
447 
448 	/* No bits for sizes that would be outside this table */
449 	KUNIT_ASSERT_EQ(test, log2_mod(pgsize_bitmap, isz_lg2), 0);
450 	KUNIT_ASSERT_EQ(
451 		test, fvalog2_div(pgsize_bitmap, num_items_lg2 + isz_lg2), 0);
452 
453 	/*
454 	 * Non contiguous must be supported. AMDv1 has a HW bug where it does
455 	 * not support it on one of the levels.
456 	 */
457 	if ((u64)pgsize_bitmap != 0xff0000000000ULL ||
458 	    strcmp(__stringify(PTPFX_RAW), "amdv1") != 0)
459 		KUNIT_ASSERT_TRUE(test, pgsize_bitmap & log2_to_int(isz_lg2));
460 	else
461 		KUNIT_ASSERT_NE(test, pgsize_bitmap, 0);
462 
463 	/* A contiguous entry should not span the whole table */
464 	if (num_items_lg2 + isz_lg2 != PT_VADDR_MAX_LG2)
465 		KUNIT_ASSERT_FALSE(
466 			test,
467 			pgsize_bitmap & log2_to_int(num_items_lg2 + isz_lg2));
468 }
469 
470 static void test_entry_possible_sizes(struct kunit *test)
471 {
472 	check_all_levels(test, test_lvl_possible_sizes, NULL);
473 }
474 
475 static void sweep_all_pgsizes(struct kunit *test, struct pt_state *pts,
476 			      struct pt_write_attrs *attrs,
477 			      pt_oaddr_t test_oaddr)
478 {
479 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
480 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
481 	unsigned int len_lg2;
482 
483 	if (pts->index != 0)
484 		return;
485 
486 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2 - 1; len_lg2++) {
487 		struct pt_state sub_pts = *pts;
488 		pt_oaddr_t oaddr;
489 
490 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
491 			continue;
492 
493 		oaddr = log2_set_mod(test_oaddr, 0, len_lg2);
494 		pt_install_leaf_entry(pts, oaddr, len_lg2, attrs);
495 		/* Verify that every contiguous item translates correctly */
496 		for (sub_pts.index = 0;
497 		     sub_pts.index != log2_to_int(len_lg2 - isz_lg2);
498 		     sub_pts.index++) {
499 			KUNIT_ASSERT_PT_LOAD(test, &sub_pts, PT_ENTRY_OA);
500 			KUNIT_ASSERT_EQ(test, pt_item_oa(&sub_pts),
501 					oaddr + sub_pts.index *
502 							oalog2_mul(1, isz_lg2));
503 			KUNIT_ASSERT_EQ(test, pt_entry_oa(&sub_pts), oaddr);
504 			KUNIT_ASSERT_EQ(test, pt_entry_num_contig_lg2(&sub_pts),
505 					len_lg2 - isz_lg2);
506 		}
507 
508 		pt_clear_entries(pts, len_lg2 - isz_lg2);
509 		KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
510 	}
511 }
512 
513 /*
514  * Check that pt_install_leaf_entry() and pt_entry_oa() match.
515  * Check that pt_clear_entries() works.
516  */
517 static void test_lvl_entry_oa(struct kunit *test, struct pt_state *pts,
518 			      void *arg)
519 {
520 	unsigned int max_oa_lg2 = pts->range->common->max_oasz_lg2;
521 	struct kunit_iommu_priv *priv = test->priv;
522 	struct pt_write_attrs attrs = {};
523 
524 	if (!pt_can_have_leaf(pts))
525 		return;
526 
527 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
528 				 pt_iommu_set_prot(pts->range->common, &attrs,
529 						   IOMMU_READ));
530 
531 	sweep_all_pgsizes(test, pts, &attrs, priv->test_oa);
532 
533 	/* Check that the table can store the boundary OAs */
534 	sweep_all_pgsizes(test, pts, &attrs, 0);
535 	if (max_oa_lg2 == PT_OADDR_MAX_LG2)
536 		sweep_all_pgsizes(test, pts, &attrs, PT_OADDR_MAX);
537 	else
538 		sweep_all_pgsizes(test, pts, &attrs,
539 				  oalog2_to_max_int(max_oa_lg2));
540 }
541 
542 static void test_entry_oa(struct kunit *test)
543 {
544 	check_all_levels(test, test_lvl_entry_oa, NULL);
545 }
546 
547 /* Test pt_attr_from_entry() */
548 static void test_lvl_attr_from_entry(struct kunit *test, struct pt_state *pts,
549 				     void *arg)
550 {
551 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
552 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
553 	struct kunit_iommu_priv *priv = test->priv;
554 	unsigned int len_lg2;
555 	unsigned int prot;
556 
557 	if (!pt_can_have_leaf(pts))
558 		return;
559 
560 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2; len_lg2++) {
561 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
562 			continue;
563 		for (prot = 0; prot <= (IOMMU_READ | IOMMU_WRITE | IOMMU_CACHE |
564 					IOMMU_NOEXEC | IOMMU_MMIO);
565 		     prot++) {
566 			pt_oaddr_t oaddr;
567 			struct pt_write_attrs attrs = {};
568 			u64 good_entry;
569 
570 			/*
571 			 * If the format doesn't support this combination of
572 			 * prot bits skip it
573 			 */
574 			if (pt_iommu_set_prot(pts->range->common, &attrs,
575 					      prot)) {
576 				/* But RW has to be supported */
577 				KUNIT_ASSERT_NE(test, prot,
578 						IOMMU_READ | IOMMU_WRITE);
579 				continue;
580 			}
581 
582 			oaddr = log2_set_mod(priv->test_oa, 0, len_lg2);
583 			pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs);
584 			KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA);
585 
586 			good_entry = pts->entry;
587 
588 			memset(&attrs, 0, sizeof(attrs));
589 			pt_attr_from_entry(pts, &attrs);
590 
591 			pt_clear_entries(pts, len_lg2 - isz_lg2);
592 			KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
593 
594 			pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs);
595 			KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA);
596 
597 			/*
598 			 * The descriptor produced by pt_attr_from_entry()
599 			 * produce an identical entry value when re-written
600 			 */
601 			KUNIT_ASSERT_EQ(test, good_entry, pts->entry);
602 
603 			pt_clear_entries(pts, len_lg2 - isz_lg2);
604 		}
605 	}
606 }
607 
608 static void test_attr_from_entry(struct kunit *test)
609 {
610 	check_all_levels(test, test_lvl_attr_from_entry, NULL);
611 }
612 
613 static void test_lvl_dirty(struct kunit *test, struct pt_state *pts, void *arg)
614 {
615 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
616 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
617 	struct kunit_iommu_priv *priv = test->priv;
618 	unsigned int start_idx = pts->index;
619 	struct pt_write_attrs attrs = {};
620 	unsigned int len_lg2;
621 
622 	if (!pt_can_have_leaf(pts))
623 		return;
624 
625 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
626 				 pt_iommu_set_prot(pts->range->common, &attrs,
627 						   IOMMU_READ | IOMMU_WRITE));
628 
629 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2; len_lg2++) {
630 		pt_oaddr_t oaddr;
631 		unsigned int i;
632 
633 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
634 			continue;
635 
636 		oaddr = log2_set_mod(priv->test_oa, 0, len_lg2);
637 		pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs);
638 		KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA);
639 
640 		pt_load_entry(pts);
641 		pt_entry_make_write_clean(pts);
642 		pt_load_entry(pts);
643 		KUNIT_ASSERT_FALSE(test, pt_entry_is_write_dirty(pts));
644 
645 		for (i = 0; i != log2_to_int(len_lg2 - isz_lg2); i++) {
646 			/* dirty every contiguous entry */
647 			pts->index = start_idx + i;
648 			pt_load_entry(pts);
649 			KUNIT_ASSERT_TRUE(test, pt_entry_make_write_dirty(pts));
650 			pts->index = start_idx;
651 			pt_load_entry(pts);
652 			KUNIT_ASSERT_TRUE(test, pt_entry_is_write_dirty(pts));
653 
654 			pt_entry_make_write_clean(pts);
655 			pt_load_entry(pts);
656 			KUNIT_ASSERT_FALSE(test, pt_entry_is_write_dirty(pts));
657 		}
658 
659 		pt_clear_entries(pts, len_lg2 - isz_lg2);
660 	}
661 }
662 
663 static __maybe_unused void test_dirty(struct kunit *test)
664 {
665 	struct kunit_iommu_priv *priv = test->priv;
666 
667 	if (!pt_dirty_supported(priv->common))
668 		kunit_skip(test,
669 			   "Page table features do not support dirty tracking");
670 
671 	check_all_levels(test, test_lvl_dirty, NULL);
672 }
673 
674 static void test_lvl_sw_bit_leaf(struct kunit *test, struct pt_state *pts,
675 				 void *arg)
676 {
677 	struct kunit_iommu_priv *priv = test->priv;
678 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
679 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
680 	struct pt_write_attrs attrs = {};
681 	unsigned int len_lg2;
682 
683 	if (!pt_can_have_leaf(pts))
684 		return;
685 	if (pts->index != 0)
686 		return;
687 
688 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
689 				 pt_iommu_set_prot(pts->range->common, &attrs,
690 						   IOMMU_READ));
691 
692 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2 - 1; len_lg2++) {
693 		pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, len_lg2);
694 		struct pt_write_attrs new_attrs = {};
695 		unsigned int bitnr;
696 
697 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
698 			continue;
699 
700 		pt_install_leaf_entry(pts, paddr, len_lg2, &attrs);
701 
702 		for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common);
703 		     bitnr++)
704 			KUNIT_ASSERT_FALSE(test,
705 					   pt_test_sw_bit_acquire(pts, bitnr));
706 
707 		for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common);
708 		     bitnr++) {
709 			KUNIT_ASSERT_FALSE(test,
710 					   pt_test_sw_bit_acquire(pts, bitnr));
711 			pt_set_sw_bit_release(pts, bitnr);
712 			KUNIT_ASSERT_TRUE(test,
713 					  pt_test_sw_bit_acquire(pts, bitnr));
714 		}
715 
716 		for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common);
717 		     bitnr++)
718 			KUNIT_ASSERT_TRUE(test,
719 					  pt_test_sw_bit_acquire(pts, bitnr));
720 
721 		KUNIT_ASSERT_EQ(test, pt_item_oa(pts), paddr);
722 
723 		/* SW bits didn't leak into the attrs */
724 		pt_attr_from_entry(pts, &new_attrs);
725 		KUNIT_ASSERT_MEMEQ(test, &new_attrs, &attrs, sizeof(attrs));
726 
727 		pt_clear_entries(pts, len_lg2 - isz_lg2);
728 		KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
729 	}
730 }
731 
732 static __maybe_unused void test_sw_bit_leaf(struct kunit *test)
733 {
734 	check_all_levels(test, test_lvl_sw_bit_leaf, NULL);
735 }
736 
737 static void test_lvl_sw_bit_table(struct kunit *test, struct pt_state *pts,
738 				  void *arg)
739 {
740 	struct kunit_iommu_priv *priv = test->priv;
741 	struct pt_write_attrs attrs = {};
742 	pt_oaddr_t paddr =
743 		log2_set_mod(priv->test_oa, 0, priv->smallest_pgsz_lg2);
744 	unsigned int bitnr;
745 
746 	if (!pt_can_have_leaf(pts))
747 		return;
748 	if (pts->index != 0)
749 		return;
750 
751 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
752 				 pt_iommu_set_prot(pts->range->common, &attrs,
753 						   IOMMU_READ));
754 
755 	KUNIT_ASSERT_TRUE(test, pt_install_table(pts, paddr, &attrs));
756 
757 	for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++)
758 		KUNIT_ASSERT_FALSE(test, pt_test_sw_bit_acquire(pts, bitnr));
759 
760 	for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++) {
761 		KUNIT_ASSERT_FALSE(test, pt_test_sw_bit_acquire(pts, bitnr));
762 		pt_set_sw_bit_release(pts, bitnr);
763 		KUNIT_ASSERT_TRUE(test, pt_test_sw_bit_acquire(pts, bitnr));
764 	}
765 
766 	for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++)
767 		KUNIT_ASSERT_TRUE(test, pt_test_sw_bit_acquire(pts, bitnr));
768 
769 	KUNIT_ASSERT_EQ(test, pt_table_pa(pts), paddr);
770 
771 	pt_clear_entries(pts, ilog2(1));
772 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
773 }
774 
775 static __maybe_unused void test_sw_bit_table(struct kunit *test)
776 {
777 	check_all_levels(test, test_lvl_sw_bit_table, NULL);
778 }
779 
/*
 * Test cases for the format API. test_dirty is only registered when the
 * format defines pt_entry_is_write_dirty. The SW bit tests are only
 * registered when it defines pt_sw_bit.
 */
static struct kunit_case generic_pt_test_cases[] = {
	KUNIT_CASE_FMT(test_init),
	KUNIT_CASE_FMT(test_bitops),
	KUNIT_CASE_FMT(test_best_pgsize),
	KUNIT_CASE_FMT(test_pgsz_count),
	KUNIT_CASE_FMT(test_table_ptr),
	KUNIT_CASE_FMT(test_max_va),
	KUNIT_CASE_FMT(test_table_radix),
	KUNIT_CASE_FMT(test_entry_possible_sizes),
	KUNIT_CASE_FMT(test_entry_oa),
	KUNIT_CASE_FMT(test_attr_from_entry),
#ifdef pt_entry_is_write_dirty
	KUNIT_CASE_FMT(test_dirty),
#endif
#ifdef pt_sw_bit
	KUNIT_CASE_FMT(test_sw_bit_leaf),
	KUNIT_CASE_FMT(test_sw_bit_table),
#endif
	/* Terminator */
	{},
};
800 
801 static int pt_kunit_generic_pt_init(struct kunit *test)
802 {
803 	struct kunit_iommu_priv *priv;
804 	int ret;
805 
806 	priv = kunit_kzalloc(test, sizeof(*priv), GFP_KERNEL);
807 	if (!priv)
808 		return -ENOMEM;
809 	ret = pt_kunit_priv_init(test, priv);
810 	if (ret) {
811 		kunit_kfree(test, priv);
812 		return ret;
813 	}
814 	test->priv = priv;
815 	return 0;
816 }
817 
818 static void pt_kunit_generic_pt_exit(struct kunit *test)
819 {
820 	struct kunit_iommu_priv *priv = test->priv;
821 
822 	if (!test->priv)
823 		return;
824 
825 	pt_iommu_deinit(priv->iommu);
826 	kunit_kfree(test, test->priv);
827 }
828 
/*
 * Suite registration. The suite symbol and its name are wrapped in NS(),
 * presumably so each page table format including this header registers its
 * own distinct suite.
 */
static struct kunit_suite NS(generic_pt_suite) = {
	.name = __stringify(NS(fmt_test)),
	.init = pt_kunit_generic_pt_init,
	.exit = pt_kunit_generic_pt_exit,
	.test_cases = generic_pt_test_cases,
};
kunit_test_suites(&NS(generic_pt_suite));
836