xref: /linux/drivers/iommu/generic_pt/kunit_generic_pt.h (revision bba2c3615bd6cfee7456d1130f2e6b01b3f4e9ba)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  *
5  * Test the format API directly.
6  *
7  */
8 #include "kunit_iommu.h"
9 #include "pt_iter.h"
10 
11 static void do_map(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa,
12 		   pt_vaddr_t len)
13 {
14 	struct kunit_iommu_priv *priv = test->priv;
15 	int ret;
16 
17 	KUNIT_ASSERT_EQ(test, len, (size_t)len);
18 
19 	ret = iommu_map(&priv->domain, va, pa, len, IOMMU_READ | IOMMU_WRITE,
20 			GFP_KERNEL);
21 	KUNIT_ASSERT_NO_ERRNO_FN(test, "map_pages", ret);
22 }
23 
/*
 * Load the entry at the current index of @pts with pt_load_entry() and assert
 * that its decoded type is @entry.
 */
#define KUNIT_ASSERT_PT_LOAD(test, pts, entry)                     \
	do {                                                       \
		pt_load_entry(pts);                                \
		KUNIT_ASSERT_EQ(test, (pts)->type, entry);         \
	} while (0)
29 
/*
 * Context passed through pt_walk_range()/pt_descend() by check_all_levels()
 * to __check_all_levels().
 */
struct check_levels_arg {
	/* Test case that owns the walk; passed to fn */
	struct kunit *test;
	/* Opaque argument forwarded unchanged as fn's last parameter */
	void *fn_arg;
	/* Per-level test callback invoked by __check_all_levels() */
	void (*fn)(struct kunit *test, struct pt_state *pts, void *arg);
};
35 
/*
 * pt_walk_range()/pt_descend() callback, called once for each table level on
 * the path to the page that check_all_levels() mapped at the end of the VA
 * range.
 *
 * At each level it checks that the walk landed on the expected last index,
 * then (if the level can hold tables) recurses into the lower table first, so
 * chk->fn runs for the lowest level before the higher ones. Finally it calls
 * chk->fn with pts at index 0 of this table, or at index 1 when index 0
 * already holds a table and there are more than 2 entries.
 *
 * Always returns 0; failed checks abort the test case via KUNIT_ASSERT_*.
 */
static int __check_all_levels(struct pt_range *range, void *arg,
			      unsigned int level, struct pt_table_p *table)
{
	struct pt_state pts = pt_init(range, level, table);
	struct check_levels_arg *chk = arg;
	struct kunit *test = chk->test;
	int ret;

	/* Position pts at the start of range within this table */
	_pt_iter_first(&pts);

	/*
	 * If we were able to use the full VA space this should always be the
	 * last index in each table.
	 *
	 * On 32-bit builds check_all_levels() truncates last_va to 32 bits for
	 * formats wider than 32 bits, so the check is skipped there. With
	 * PT_FEAT_SIGN_EXTEND the top level is expected to stop at the last
	 * index covering max_vasz_lg2 - 1 bits. NOTE(review): presumably
	 * because the top range ends below the sign-extended upper half -
	 * confirm.
	 */
	if (!(IS_32BIT && range->max_vasz_lg2 > 32)) {
		if (pt_feature(range->common, PT_FEAT_SIGN_EXTEND) &&
		    pts.level == pts.range->top_level)
			KUNIT_ASSERT_EQ(test, pts.index,
					log2_to_int(range->max_vasz_lg2 - 1 -
						    pt_table_item_lg2sz(&pts)) -
						1);
		else
			KUNIT_ASSERT_EQ(test, pts.index,
					log2_to_int(pt_table_oa_lg2sz(&pts) -
						    pt_table_item_lg2sz(&pts)) -
						1);
	}

	if (pt_can_have_table(&pts)) {
		/* The mapping must have populated a lower table here; walk it */
		pt_load_single_entry(&pts);
		KUNIT_ASSERT_EQ(test, pts.type, PT_ENTRY_TABLE);
		ret = pt_descend(&pts, arg, __check_all_levels);
		KUNIT_ASSERT_EQ(test, ret, 0);

		/*
		 * Index 0 is used by the test. On 32-bit the mapping's path can
		 * run through index 0 itself; then chk->fn is not called for
		 * this level.
		 */
		if (IS_32BIT && !pts.index)
			return 0;
		KUNIT_ASSERT_NE(chk->test, pts.index, 0);
	}

	/*
	 * A format should not create a table with only one entry, at least this
	 * test approach won't work.
	 */
	KUNIT_ASSERT_GT(chk->test, pts.end_index, 1);

	/*
	 * For increase top we end up using index 0 for the original top's tree,
	 * so use index 1 for testing instead.
	 */
	pts.index = 0;
	pt_index_to_va(&pts);
	pt_load_single_entry(&pts);
	if (pts.type == PT_ENTRY_TABLE && pts.end_index > 2) {
		pts.index = 1;
		pt_index_to_va(&pts);
	}
	(*chk->fn)(chk->test, &pts, chk->fn_arg);
	return 0;
}
97 
98 /*
99  * Call fn for each level in the table with a pts setup to index 0 in a table
100  * for that level. This allows writing tests that run on every level.
101  * The test can use every index in the table except the last one.
102  */
103 static void check_all_levels(struct kunit *test,
104 			     void (*fn)(struct kunit *test,
105 					struct pt_state *pts, void *arg),
106 			     void *fn_arg)
107 {
108 	struct kunit_iommu_priv *priv = test->priv;
109 	struct pt_range range = pt_top_range(priv->common);
110 	struct check_levels_arg chk = {
111 		.test = test,
112 		.fn = fn,
113 		.fn_arg = fn_arg,
114 	};
115 	int ret;
116 
117 	if (pt_feature(priv->common, PT_FEAT_DYNAMIC_TOP) &&
118 	    priv->common->max_vasz_lg2 > range.max_vasz_lg2)
119 		range.last_va = fvalog2_set_mod_max(range.va,
120 						    priv->common->max_vasz_lg2);
121 
122 	/*
123 	 * Map a page at the highest VA, this will populate all the levels so we
124 	 * can then iterate over them. Index 0 will be used for testing.
125 	 */
126 	if (IS_32BIT && range.max_vasz_lg2 > 32)
127 		range.last_va = (u32)range.last_va;
128 	range.va = range.last_va - (priv->smallest_pgsz - 1);
129 	do_map(test, range.va, 0, priv->smallest_pgsz);
130 
131 	range = pt_make_range(priv->common, range.va, range.last_va);
132 	ret = pt_walk_range(&range, __check_all_levels, &chk);
133 	KUNIT_ASSERT_EQ(test, ret, 0);
134 }
135 
136 static void test_init(struct kunit *test)
137 {
138 	struct kunit_iommu_priv *priv = test->priv;
139 
140 	/* Fixture does the setup */
141 	KUNIT_ASSERT_NE(test, priv->info.pgsize_bitmap, 0);
142 }
143 
144 /*
145  * Basic check that the log2_* functions are working, especially at the integer
146  * limits.
147  */
148 static void test_bitops(struct kunit *test)
149 {
150 	int i;
151 
152 	KUNIT_ASSERT_EQ(test, fls_t(u32, 0), 0);
153 	KUNIT_ASSERT_EQ(test, fls_t(u32, 1), 1);
154 	KUNIT_ASSERT_EQ(test, fls_t(u32, BIT(2)), 3);
155 	KUNIT_ASSERT_EQ(test, fls_t(u32, U32_MAX), 32);
156 
157 	KUNIT_ASSERT_EQ(test, fls_t(u64, 0), 0);
158 	KUNIT_ASSERT_EQ(test, fls_t(u64, 1), 1);
159 	KUNIT_ASSERT_EQ(test, fls_t(u64, BIT(2)), 3);
160 	KUNIT_ASSERT_EQ(test, fls_t(u64, U64_MAX), 64);
161 
162 	KUNIT_ASSERT_EQ(test, ffs_t(u32, 1), 0);
163 	KUNIT_ASSERT_EQ(test, ffs_t(u32, BIT(2)), 2);
164 	KUNIT_ASSERT_EQ(test, ffs_t(u32, BIT(31)), 31);
165 
166 	KUNIT_ASSERT_EQ(test, ffs_t(u64, 1), 0);
167 	KUNIT_ASSERT_EQ(test, ffs_t(u64, BIT(2)), 2);
168 	KUNIT_ASSERT_EQ(test, ffs_t(u64, BIT_ULL(63)), 63);
169 
170 	for (i = 0; i != 31; i++)
171 		KUNIT_ASSERT_EQ(test, ffz_t(u64, BIT_ULL(i) - 1), i);
172 
173 	for (i = 0; i != 63; i++)
174 		KUNIT_ASSERT_EQ(test, ffz_t(u64, BIT_ULL(i) - 1), i);
175 
176 	for (i = 0; i != 32; i++) {
177 		u64 val = get_random_u64();
178 
179 		KUNIT_ASSERT_EQ(test, log2_mod_t(u32, val, ffs_t(u32, val)), 0);
180 		KUNIT_ASSERT_EQ(test, log2_mod_t(u64, val, ffs_t(u64, val)), 0);
181 
182 		KUNIT_ASSERT_EQ(test, log2_mod_t(u32, val, ffz_t(u32, val)),
183 				log2_to_max_int_t(u32, ffz_t(u32, val)));
184 		KUNIT_ASSERT_EQ(test, log2_mod_t(u64, val, ffz_t(u64, val)),
185 				log2_to_max_int_t(u64, ffz_t(u64, val)));
186 	}
187 }
188 
189 static unsigned int ref_best_pgsize(pt_vaddr_t pgsz_bitmap, pt_vaddr_t va,
190 				    pt_vaddr_t last_va, pt_oaddr_t oa)
191 {
192 	pt_vaddr_t pgsz_lg2;
193 
194 	/* Brute force the constraints described in pt_compute_best_pgsize() */
195 	for (pgsz_lg2 = PT_VADDR_MAX_LG2 - 1; pgsz_lg2 != 0; pgsz_lg2--) {
196 		if ((pgsz_bitmap & log2_to_int(pgsz_lg2)) &&
197 		    log2_mod(va, pgsz_lg2) == 0 &&
198 		    oalog2_mod(oa, pgsz_lg2) == 0 &&
199 		    va + log2_to_int(pgsz_lg2) - 1 <= last_va &&
200 		    log2_div_eq(va, va + log2_to_int(pgsz_lg2) - 1, pgsz_lg2) &&
201 		    oalog2_div_eq(oa, oa + log2_to_int(pgsz_lg2) - 1, pgsz_lg2))
202 			return pgsz_lg2;
203 	}
204 	return 0;
205 }
206 
207 /* Check that the bit logic in pt_compute_best_pgsize() works. */
208 static void test_best_pgsize(struct kunit *test)
209 {
210 	unsigned int a_lg2;
211 	unsigned int b_lg2;
212 	unsigned int c_lg2;
213 
214 	/* Try random prefixes with every suffix combination */
215 	for (a_lg2 = 1; a_lg2 != 10; a_lg2++) {
216 		for (b_lg2 = 1; b_lg2 != 10; b_lg2++) {
217 			for (c_lg2 = 1; c_lg2 != 10; c_lg2++) {
218 				pt_vaddr_t pgsz_bitmap = get_random_u64();
219 				pt_vaddr_t va = get_random_u64() << a_lg2;
220 				pt_oaddr_t oa = get_random_u64() << b_lg2;
221 				pt_vaddr_t last_va = log2_set_mod_max(
222 					get_random_u64(), c_lg2);
223 
224 				if (va > last_va)
225 					swap(va, last_va);
226 				KUNIT_ASSERT_EQ(
227 					test,
228 					pt_compute_best_pgsize(pgsz_bitmap, va,
229 							       last_va, oa),
230 					ref_best_pgsize(pgsz_bitmap, va,
231 							last_va, oa));
232 			}
233 		}
234 	}
235 
236 	/* 0 prefix, every suffix */
237 	for (c_lg2 = 1; c_lg2 != PT_VADDR_MAX_LG2 - 1; c_lg2++) {
238 		pt_vaddr_t pgsz_bitmap = get_random_u64();
239 		pt_vaddr_t va = 0;
240 		pt_oaddr_t oa = 0;
241 		pt_vaddr_t last_va = log2_set_mod_max(0, c_lg2);
242 
243 		KUNIT_ASSERT_EQ(test,
244 				pt_compute_best_pgsize(pgsz_bitmap, va, last_va,
245 						       oa),
246 				ref_best_pgsize(pgsz_bitmap, va, last_va, oa));
247 	}
248 
249 	/* 1's prefix, every suffix */
250 	for (a_lg2 = 1; a_lg2 != 10; a_lg2++) {
251 		for (b_lg2 = 1; b_lg2 != 10; b_lg2++) {
252 			for (c_lg2 = 1; c_lg2 != 10; c_lg2++) {
253 				pt_vaddr_t pgsz_bitmap = get_random_u64();
254 				pt_vaddr_t va = PT_VADDR_MAX << a_lg2;
255 				pt_oaddr_t oa = PT_VADDR_MAX << b_lg2;
256 				pt_vaddr_t last_va = PT_VADDR_MAX;
257 
258 				KUNIT_ASSERT_EQ(
259 					test,
260 					pt_compute_best_pgsize(pgsz_bitmap, va,
261 							       last_va, oa),
262 					ref_best_pgsize(pgsz_bitmap, va,
263 							last_va, oa));
264 			}
265 		}
266 	}
267 
268 	/* pgsize_bitmap is always 0 */
269 	for (a_lg2 = 1; a_lg2 != 10; a_lg2++) {
270 		for (b_lg2 = 1; b_lg2 != 10; b_lg2++) {
271 			for (c_lg2 = 1; c_lg2 != 10; c_lg2++) {
272 				pt_vaddr_t pgsz_bitmap = 0;
273 				pt_vaddr_t va = get_random_u64() << a_lg2;
274 				pt_oaddr_t oa = get_random_u64() << b_lg2;
275 				pt_vaddr_t last_va = log2_set_mod_max(
276 					get_random_u64(), c_lg2);
277 
278 				if (va > last_va)
279 					swap(va, last_va);
280 				KUNIT_ASSERT_EQ(
281 					test,
282 					pt_compute_best_pgsize(pgsz_bitmap, va,
283 							       last_va, oa),
284 					0);
285 			}
286 		}
287 	}
288 
289 	if (sizeof(pt_vaddr_t) <= 4)
290 		return;
291 
292 	/* over 32 bit page sizes */
293 	for (a_lg2 = 32; a_lg2 != 42; a_lg2++) {
294 		for (b_lg2 = 32; b_lg2 != 42; b_lg2++) {
295 			for (c_lg2 = 32; c_lg2 != 42; c_lg2++) {
296 				pt_vaddr_t pgsz_bitmap = get_random_u64();
297 				pt_vaddr_t va = get_random_u64() << a_lg2;
298 				pt_oaddr_t oa = get_random_u64() << b_lg2;
299 				pt_vaddr_t last_va = log2_set_mod_max(
300 					get_random_u64(), c_lg2);
301 
302 				if (va > last_va)
303 					swap(va, last_va);
304 				KUNIT_ASSERT_EQ(
305 					test,
306 					pt_compute_best_pgsize(pgsz_bitmap, va,
307 							       last_va, oa),
308 					ref_best_pgsize(pgsz_bitmap, va,
309 							last_va, oa));
310 			}
311 		}
312 	}
313 }
314 
315 static void test_pgsz_count(struct kunit *test)
316 {
317 	KUNIT_EXPECT_EQ(test,
318 			pt_pgsz_count(SZ_4K, 0, SZ_1G - 1, 0, ilog2(SZ_4K)),
319 			SZ_1G / SZ_4K);
320 	KUNIT_EXPECT_EQ(test,
321 			pt_pgsz_count(SZ_2M | SZ_4K, SZ_4K, SZ_1G - 1, SZ_4K,
322 				      ilog2(SZ_4K)),
323 			(SZ_2M - SZ_4K) / SZ_4K);
324 }
325 
326 /*
327  * Check that pt_install_table() and pt_table_pa() match
328  */
329 static void test_lvl_table_ptr(struct kunit *test, struct pt_state *pts,
330 			       void *arg)
331 {
332 	struct kunit_iommu_priv *priv = test->priv;
333 	pt_oaddr_t paddr =
334 		log2_set_mod(priv->test_oa, 0, priv->smallest_pgsz_lg2);
335 	struct pt_write_attrs attrs = {};
336 
337 	if (!pt_can_have_table(pts))
338 		return;
339 
340 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
341 				 pt_iommu_set_prot(pts->range->common, &attrs,
342 						   IOMMU_READ));
343 
344 	pt_load_single_entry(pts);
345 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
346 
347 	KUNIT_ASSERT_TRUE(test, pt_install_table(pts, paddr, &attrs));
348 
349 	/* A second install should pass because install updates pts->entry. */
350 	KUNIT_ASSERT_EQ(test, pt_install_table(pts, paddr, &attrs), true);
351 
352 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_TABLE);
353 	KUNIT_ASSERT_EQ(test, pt_table_pa(pts), paddr);
354 
355 	pt_clear_entries(pts, ilog2(1));
356 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
357 }
358 
359 static void test_table_ptr(struct kunit *test)
360 {
361 	check_all_levels(test, test_lvl_table_ptr, NULL);
362 }
363 
/* Accumulator shared between test_table_radix() and test_lvl_radix() */
struct lvl_radix_arg {
	/*
	 * Mask of the low VA bits decoded by the levels visited so far. It
	 * starts as smallest_pgsz - 1, and each level widens it to its table
	 * size.
	 */
	pt_vaddr_t vbits;
};
367 
368 /*
369  * Check pt_table_oa_lg2sz() and pt_table_item_lg2sz() they need to decode a
370  * continuous list of VA across all the levels that covers the entire advertised
371  * VA space.
372  */
373 static void test_lvl_radix(struct kunit *test, struct pt_state *pts, void *arg)
374 {
375 	unsigned int table_lg2sz = pt_table_oa_lg2sz(pts);
376 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
377 	struct lvl_radix_arg *radix = arg;
378 
379 	/* Every bit below us is decoded */
380 	KUNIT_ASSERT_EQ(test, log2_set_mod_max(0, isz_lg2), radix->vbits);
381 
382 	/* We are not decoding bits someone else is */
383 	KUNIT_ASSERT_EQ(test, log2_div(radix->vbits, isz_lg2), 0);
384 
385 	/* Can't decode past the pt_vaddr_t size */
386 	KUNIT_ASSERT_LE(test, table_lg2sz, PT_VADDR_MAX_LG2);
387 	KUNIT_ASSERT_EQ(test, fvalog2_div(table_lg2sz, PT_MAX_VA_ADDRESS_LG2),
388 			0);
389 
390 	radix->vbits = fvalog2_set_mod_max(0, table_lg2sz);
391 }
392 
393 static void test_max_va(struct kunit *test)
394 {
395 	struct kunit_iommu_priv *priv = test->priv;
396 	struct pt_range range = pt_top_range(priv->common);
397 
398 	KUNIT_ASSERT_GE(test, priv->common->max_vasz_lg2, range.max_vasz_lg2);
399 }
400 
401 static void test_table_radix(struct kunit *test)
402 {
403 	struct kunit_iommu_priv *priv = test->priv;
404 	struct lvl_radix_arg radix = { .vbits = priv->smallest_pgsz - 1 };
405 	struct pt_range range;
406 
407 	check_all_levels(test, test_lvl_radix, &radix);
408 
409 	range = pt_top_range(priv->common);
410 	if (range.max_vasz_lg2 == PT_VADDR_MAX_LG2) {
411 		KUNIT_ASSERT_EQ(test, radix.vbits, PT_VADDR_MAX);
412 	} else {
413 		if (!IS_32BIT)
414 			KUNIT_ASSERT_EQ(test,
415 					log2_set_mod_max(0, range.max_vasz_lg2),
416 					radix.vbits);
417 		KUNIT_ASSERT_EQ(test, log2_div(radix.vbits, range.max_vasz_lg2),
418 				0);
419 	}
420 }
421 
422 static unsigned int safe_pt_num_items_lg2(const struct pt_state *pts)
423 {
424 	struct pt_range top_range = pt_top_range(pts->range->common);
425 	struct pt_state top_pts = pt_init_top(&top_range);
426 
427 	/*
428 	 * Avoid calling pt_num_items_lg2() on the top, instead we can derive
429 	 * the size of the top table from the top range.
430 	 */
431 	if (pts->level == top_range.top_level)
432 		return ilog2(pt_range_to_end_index(&top_pts));
433 	return pt_num_items_lg2(pts);
434 }
435 
436 static void test_lvl_possible_sizes(struct kunit *test, struct pt_state *pts,
437 				    void *arg)
438 {
439 	unsigned int num_items_lg2 = safe_pt_num_items_lg2(pts);
440 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
441 	/* Matches get_info() */
442 	pt_vaddr_t limited_pgsize_bitmap =
443 		log2_mod(pgsize_bitmap, pts->range->common->max_vasz_lg2 - 1);
444 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
445 
446 	if (!pt_can_have_leaf(pts)) {
447 		KUNIT_ASSERT_EQ(test, pgsize_bitmap, 0);
448 		return;
449 	}
450 
451 	/* No bits for sizes that would be outside this table */
452 	KUNIT_ASSERT_EQ(test, log2_mod(pgsize_bitmap, isz_lg2), 0);
453 	KUNIT_ASSERT_EQ(
454 		test,
455 		fvalog2_div(limited_pgsize_bitmap, num_items_lg2 + isz_lg2), 0);
456 
457 	/*
458 	 * Non contiguous must be supported. AMDv1 has a HW bug where it does
459 	 * not support it on one of the levels.
460 	 */
461 	if ((u64)pgsize_bitmap != 0xff0000000000ULL ||
462 	    strcmp(__stringify(PTPFX_RAW), "amdv1") != 0)
463 		KUNIT_ASSERT_TRUE(test, pgsize_bitmap & log2_to_int(isz_lg2));
464 	else
465 		KUNIT_ASSERT_NE(test, pgsize_bitmap, 0);
466 
467 	/* A contiguous entry should not span the whole table */
468 	if (num_items_lg2 + isz_lg2 != PT_VADDR_MAX_LG2)
469 		KUNIT_ASSERT_FALSE(
470 			test, limited_pgsize_bitmap &
471 					log2_to_int(num_items_lg2 + isz_lg2));
472 }
473 
474 static void test_entry_possible_sizes(struct kunit *test)
475 {
476 	check_all_levels(test, test_lvl_possible_sizes, NULL);
477 }
478 
479 static void sweep_all_pgsizes(struct kunit *test, struct pt_state *pts,
480 			      struct pt_write_attrs *attrs,
481 			      pt_oaddr_t test_oaddr)
482 {
483 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
484 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
485 	unsigned int len_lg2;
486 
487 	if (pts->index != 0)
488 		return;
489 
490 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2 - 1; len_lg2++) {
491 		struct pt_state sub_pts = *pts;
492 		pt_oaddr_t oaddr;
493 
494 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
495 			continue;
496 
497 		oaddr = log2_set_mod(test_oaddr, 0, len_lg2);
498 		pt_install_leaf_entry(pts, oaddr, len_lg2, attrs);
499 		/* Verify that every contiguous item translates correctly */
500 		for (sub_pts.index = 0;
501 		     sub_pts.index != log2_to_int(len_lg2 - isz_lg2);
502 		     sub_pts.index++) {
503 			KUNIT_ASSERT_PT_LOAD(test, &sub_pts, PT_ENTRY_OA);
504 			KUNIT_ASSERT_EQ(test, pt_item_oa(&sub_pts),
505 					oaddr + sub_pts.index *
506 							oalog2_mul(1, isz_lg2));
507 			KUNIT_ASSERT_EQ(test, pt_entry_oa(&sub_pts), oaddr);
508 			KUNIT_ASSERT_EQ(test, pt_entry_num_contig_lg2(&sub_pts),
509 					len_lg2 - isz_lg2);
510 		}
511 
512 		pt_clear_entries(pts, len_lg2 - isz_lg2);
513 		KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
514 	}
515 }
516 
517 /*
518  * Check that pt_install_leaf_entry() and pt_entry_oa() match.
519  * Check that pt_clear_entries() works.
520  */
521 static void test_lvl_entry_oa(struct kunit *test, struct pt_state *pts,
522 			      void *arg)
523 {
524 	unsigned int max_oa_lg2 = pts->range->common->max_oasz_lg2;
525 	struct kunit_iommu_priv *priv = test->priv;
526 	struct pt_write_attrs attrs = {};
527 
528 	if (!pt_can_have_leaf(pts))
529 		return;
530 
531 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
532 				 pt_iommu_set_prot(pts->range->common, &attrs,
533 						   IOMMU_READ));
534 
535 	sweep_all_pgsizes(test, pts, &attrs, priv->test_oa);
536 
537 	/* Check that the table can store the boundary OAs */
538 	sweep_all_pgsizes(test, pts, &attrs, 0);
539 	if (max_oa_lg2 == PT_OADDR_MAX_LG2)
540 		sweep_all_pgsizes(test, pts, &attrs, PT_OADDR_MAX);
541 	else
542 		sweep_all_pgsizes(test, pts, &attrs,
543 				  oalog2_to_max_int(max_oa_lg2));
544 }
545 
546 static void test_entry_oa(struct kunit *test)
547 {
548 	check_all_levels(test, test_lvl_entry_oa, NULL);
549 }
550 
551 /* Test pt_attr_from_entry() */
552 static void test_lvl_attr_from_entry(struct kunit *test, struct pt_state *pts,
553 				     void *arg)
554 {
555 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
556 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
557 	struct kunit_iommu_priv *priv = test->priv;
558 	unsigned int len_lg2;
559 	unsigned int prot;
560 
561 	if (!pt_can_have_leaf(pts))
562 		return;
563 
564 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2; len_lg2++) {
565 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
566 			continue;
567 		for (prot = 0; prot <= (IOMMU_READ | IOMMU_WRITE | IOMMU_CACHE |
568 					IOMMU_NOEXEC | IOMMU_MMIO);
569 		     prot++) {
570 			pt_oaddr_t oaddr;
571 			struct pt_write_attrs attrs = {};
572 			u64 good_entry;
573 
574 			/*
575 			 * If the format doesn't support this combination of
576 			 * prot bits skip it
577 			 */
578 			if (pt_iommu_set_prot(pts->range->common, &attrs,
579 					      prot)) {
580 				/* But RW has to be supported */
581 				KUNIT_ASSERT_NE(test, prot,
582 						IOMMU_READ | IOMMU_WRITE);
583 				continue;
584 			}
585 
586 			oaddr = log2_set_mod(priv->test_oa, 0, len_lg2);
587 			pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs);
588 			KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA);
589 
590 			good_entry = pts->entry;
591 
592 			memset(&attrs, 0, sizeof(attrs));
593 			pt_attr_from_entry(pts, &attrs);
594 
595 			pt_clear_entries(pts, len_lg2 - isz_lg2);
596 			KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
597 
598 			pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs);
599 			KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA);
600 
601 			/*
602 			 * The descriptor produced by pt_attr_from_entry()
603 			 * produce an identical entry value when re-written
604 			 */
605 			KUNIT_ASSERT_EQ(test, good_entry, pts->entry);
606 
607 			pt_clear_entries(pts, len_lg2 - isz_lg2);
608 		}
609 	}
610 }
611 
612 static void test_attr_from_entry(struct kunit *test)
613 {
614 	check_all_levels(test, test_lvl_attr_from_entry, NULL);
615 }
616 
617 static void test_lvl_dirty(struct kunit *test, struct pt_state *pts, void *arg)
618 {
619 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
620 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
621 	struct kunit_iommu_priv *priv = test->priv;
622 	unsigned int start_idx = pts->index;
623 	struct pt_write_attrs attrs = {};
624 	unsigned int len_lg2;
625 
626 	if (!pt_can_have_leaf(pts))
627 		return;
628 
629 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
630 				 pt_iommu_set_prot(pts->range->common, &attrs,
631 						   IOMMU_READ | IOMMU_WRITE));
632 
633 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2; len_lg2++) {
634 		pt_oaddr_t oaddr;
635 		unsigned int i;
636 
637 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
638 			continue;
639 
640 		oaddr = log2_set_mod(priv->test_oa, 0, len_lg2);
641 		pt_install_leaf_entry(pts, oaddr, len_lg2, &attrs);
642 		KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_OA);
643 
644 		pt_load_entry(pts);
645 		pt_entry_make_write_clean(pts);
646 		pt_load_entry(pts);
647 		KUNIT_ASSERT_FALSE(test, pt_entry_is_write_dirty(pts));
648 
649 		for (i = 0; i != log2_to_int(len_lg2 - isz_lg2); i++) {
650 			/* dirty every contiguous entry */
651 			pts->index = start_idx + i;
652 			pt_load_entry(pts);
653 			KUNIT_ASSERT_TRUE(test, pt_entry_make_write_dirty(pts));
654 			pts->index = start_idx;
655 			pt_load_entry(pts);
656 			KUNIT_ASSERT_TRUE(test, pt_entry_is_write_dirty(pts));
657 
658 			pt_entry_make_write_clean(pts);
659 			pt_load_entry(pts);
660 			KUNIT_ASSERT_FALSE(test, pt_entry_is_write_dirty(pts));
661 		}
662 
663 		pt_clear_entries(pts, len_lg2 - isz_lg2);
664 	}
665 }
666 
667 static __maybe_unused void test_dirty(struct kunit *test)
668 {
669 	struct kunit_iommu_priv *priv = test->priv;
670 
671 	if (!pt_dirty_supported(priv->common))
672 		kunit_skip(test,
673 			   "Page table features do not support dirty tracking");
674 
675 	check_all_levels(test, test_lvl_dirty, NULL);
676 }
677 
678 static void test_lvl_sw_bit_leaf(struct kunit *test, struct pt_state *pts,
679 				 void *arg)
680 {
681 	struct kunit_iommu_priv *priv = test->priv;
682 	pt_vaddr_t pgsize_bitmap = pt_possible_sizes(pts);
683 	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
684 	struct pt_write_attrs attrs = {};
685 	unsigned int len_lg2;
686 
687 	if (!pt_can_have_leaf(pts))
688 		return;
689 	if (pts->index != 0)
690 		return;
691 
692 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
693 				 pt_iommu_set_prot(pts->range->common, &attrs,
694 						   IOMMU_READ));
695 
696 	for (len_lg2 = 0; len_lg2 < PT_VADDR_MAX_LG2 - 1; len_lg2++) {
697 		pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, len_lg2);
698 		struct pt_write_attrs new_attrs = {};
699 		unsigned int bitnr;
700 
701 		if (!(pgsize_bitmap & log2_to_int(len_lg2)))
702 			continue;
703 
704 		pt_install_leaf_entry(pts, paddr, len_lg2, &attrs);
705 
706 		for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common);
707 		     bitnr++)
708 			KUNIT_ASSERT_FALSE(test,
709 					   pt_test_sw_bit_acquire(pts, bitnr));
710 
711 		for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common);
712 		     bitnr++) {
713 			KUNIT_ASSERT_FALSE(test,
714 					   pt_test_sw_bit_acquire(pts, bitnr));
715 			pt_set_sw_bit_release(pts, bitnr);
716 			KUNIT_ASSERT_TRUE(test,
717 					  pt_test_sw_bit_acquire(pts, bitnr));
718 		}
719 
720 		for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common);
721 		     bitnr++)
722 			KUNIT_ASSERT_TRUE(test,
723 					  pt_test_sw_bit_acquire(pts, bitnr));
724 
725 		KUNIT_ASSERT_EQ(test, pt_item_oa(pts), paddr);
726 
727 		/* SW bits didn't leak into the attrs */
728 		pt_attr_from_entry(pts, &new_attrs);
729 		KUNIT_ASSERT_MEMEQ(test, &new_attrs, &attrs, sizeof(attrs));
730 
731 		pt_clear_entries(pts, len_lg2 - isz_lg2);
732 		KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
733 	}
734 }
735 
736 static __maybe_unused void test_sw_bit_leaf(struct kunit *test)
737 {
738 	check_all_levels(test, test_lvl_sw_bit_leaf, NULL);
739 }
740 
741 static void test_lvl_sw_bit_table(struct kunit *test, struct pt_state *pts,
742 				  void *arg)
743 {
744 	struct kunit_iommu_priv *priv = test->priv;
745 	struct pt_write_attrs attrs = {};
746 	pt_oaddr_t paddr =
747 		log2_set_mod(priv->test_oa, 0, priv->smallest_pgsz_lg2);
748 	unsigned int bitnr;
749 
750 	if (!pt_can_have_leaf(pts))
751 		return;
752 	if (pts->index != 0)
753 		return;
754 
755 	KUNIT_ASSERT_NO_ERRNO_FN(test, "pt_iommu_set_prot",
756 				 pt_iommu_set_prot(pts->range->common, &attrs,
757 						   IOMMU_READ));
758 
759 	KUNIT_ASSERT_TRUE(test, pt_install_table(pts, paddr, &attrs));
760 
761 	for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++)
762 		KUNIT_ASSERT_FALSE(test, pt_test_sw_bit_acquire(pts, bitnr));
763 
764 	for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++) {
765 		KUNIT_ASSERT_FALSE(test, pt_test_sw_bit_acquire(pts, bitnr));
766 		pt_set_sw_bit_release(pts, bitnr);
767 		KUNIT_ASSERT_TRUE(test, pt_test_sw_bit_acquire(pts, bitnr));
768 	}
769 
770 	for (bitnr = 0; bitnr <= pt_max_sw_bit(pts->range->common); bitnr++)
771 		KUNIT_ASSERT_TRUE(test, pt_test_sw_bit_acquire(pts, bitnr));
772 
773 	KUNIT_ASSERT_EQ(test, pt_table_pa(pts), paddr);
774 
775 	pt_clear_entries(pts, ilog2(1));
776 	KUNIT_ASSERT_PT_LOAD(test, pts, PT_ENTRY_EMPTY);
777 }
778 
779 static __maybe_unused void test_sw_bit_table(struct kunit *test)
780 {
781 	check_all_levels(test, test_lvl_sw_bit_table, NULL);
782 }
783 
/*
 * Format API test cases. test_dirty is only built when the format defines
 * pt_entry_is_write_dirty, and the sw_bit tests only when it defines
 * pt_sw_bit.
 */
static struct kunit_case generic_pt_test_cases[] = {
	KUNIT_CASE_FMT(test_init),
	KUNIT_CASE_FMT(test_bitops),
	KUNIT_CASE_FMT(test_best_pgsize),
	KUNIT_CASE_FMT(test_pgsz_count),
	KUNIT_CASE_FMT(test_table_ptr),
	KUNIT_CASE_FMT(test_max_va),
	KUNIT_CASE_FMT(test_table_radix),
	KUNIT_CASE_FMT(test_entry_possible_sizes),
	KUNIT_CASE_FMT(test_entry_oa),
	KUNIT_CASE_FMT(test_attr_from_entry),
#ifdef pt_entry_is_write_dirty
	KUNIT_CASE_FMT(test_dirty),
#endif
#ifdef pt_sw_bit
	KUNIT_CASE_FMT(test_sw_bit_leaf),
	KUNIT_CASE_FMT(test_sw_bit_table),
#endif
	/* Terminating empty entry */
	{},
};
804 
805 static int pt_kunit_generic_pt_init(struct kunit *test)
806 {
807 	struct kunit_iommu_priv *priv;
808 	int ret;
809 
810 	priv = kunit_kzalloc(test, sizeof(*priv), GFP_KERNEL);
811 	if (!priv)
812 		return -ENOMEM;
813 	ret = pt_kunit_priv_init(test, priv);
814 	if (ret) {
815 		kunit_kfree(test, priv);
816 		return ret;
817 	}
818 	test->priv = priv;
819 	return 0;
820 }
821 
822 static void pt_kunit_generic_pt_exit(struct kunit *test)
823 {
824 	struct kunit_iommu_priv *priv = test->priv;
825 
826 	if (!test->priv)
827 		return;
828 
829 	pt_iommu_deinit(priv->iommu);
830 	kunit_kfree(test, test->priv);
831 }
832 
/*
 * KUnit suite for the format API tests. NOTE(review): NS() presumably adds
 * the per-format namespace prefix so each format gets its own suite name -
 * confirm.
 */
static struct kunit_suite NS(generic_pt_suite) = {
	.name = __stringify(NS(fmt_test)),
	.init = pt_kunit_generic_pt_init,
	.exit = pt_kunit_generic_pt_exit,
	.test_cases = generic_pt_test_cases,
};
kunit_test_suites(&NS(generic_pt_suite));
840