xref: /linux/drivers/iommu/generic_pt/kunit_iommu_pt.h (revision ce5cfb0fa20dc6454da039612e34325b7b4a8243)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
4  */
5 #include "kunit_iommu.h"
6 #include "pt_iter.h"
7 #include <linux/generic_pt/iommu.h>
8 #include <linux/iommu.h>
9 
10 static void do_map(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa,
11 		   pt_vaddr_t len);
12 
13 struct count_valids {
14 	u64 per_size[PT_VADDR_MAX_LG2];
15 };
16 
__count_valids(struct pt_range * range,void * arg,unsigned int level,struct pt_table_p * table)17 static int __count_valids(struct pt_range *range, void *arg, unsigned int level,
18 			  struct pt_table_p *table)
19 {
20 	struct pt_state pts = pt_init(range, level, table);
21 	struct count_valids *valids = arg;
22 
23 	for_each_pt_level_entry(&pts) {
24 		if (pts.type == PT_ENTRY_TABLE) {
25 			pt_descend(&pts, arg, __count_valids);
26 			continue;
27 		}
28 		if (pts.type == PT_ENTRY_OA) {
29 			valids->per_size[pt_entry_oa_lg2sz(&pts)]++;
30 			continue;
31 		}
32 	}
33 	return 0;
34 }
35 
36 /*
37  * Number of valid table entries. This counts contiguous entries as a single
38  * valid.
39  */
count_valids(struct kunit * test)40 static unsigned int count_valids(struct kunit *test)
41 {
42 	struct kunit_iommu_priv *priv = test->priv;
43 	struct pt_range range = pt_top_range(priv->common);
44 	struct count_valids valids = {};
45 	u64 total = 0;
46 	unsigned int i;
47 
48 	KUNIT_ASSERT_NO_ERRNO(test,
49 			      pt_walk_range(&range, __count_valids, &valids));
50 
51 	for (i = 0; i != ARRAY_SIZE(valids.per_size); i++)
52 		total += valids.per_size[i];
53 	return total;
54 }
55 
56 /* Only a single page size is present, count the number of valid entries */
count_valids_single(struct kunit * test,pt_vaddr_t pgsz)57 static unsigned int count_valids_single(struct kunit *test, pt_vaddr_t pgsz)
58 {
59 	struct kunit_iommu_priv *priv = test->priv;
60 	struct pt_range range = pt_top_range(priv->common);
61 	struct count_valids valids = {};
62 	u64 total = 0;
63 	unsigned int i;
64 
65 	KUNIT_ASSERT_NO_ERRNO(test,
66 			      pt_walk_range(&range, __count_valids, &valids));
67 
68 	for (i = 0; i != ARRAY_SIZE(valids.per_size); i++) {
69 		if ((1ULL << i) == pgsz)
70 			total = valids.per_size[i];
71 		else
72 			KUNIT_ASSERT_EQ(test, valids.per_size[i], 0);
73 	}
74 	return total;
75 }
76 
/*
 * Unmap @len bytes at @va from the test domain and assert that iommu_unmap()
 * reports exactly @len bytes unmapped.
 */
static void do_unmap(struct kunit *test, pt_vaddr_t va, pt_vaddr_t len)
{
	struct kunit_iommu_priv *priv = test->priv;
	size_t ret = iommu_unmap(&priv->domain, va, len);

	KUNIT_ASSERT_EQ(test, ret, len);
}
85 
/*
 * Verify that [@va, @va + @len) translates linearly to @pa by querying
 * iommu_iova_to_phys() once per smallest page.
 */
static void check_iova(struct kunit *test, pt_vaddr_t va, pt_oaddr_t pa,
		       pt_vaddr_t len)
{
	struct kunit_iommu_priv *priv = test->priv;
	pt_vaddr_t cur_pfn = log2_div(va, priv->smallest_pgsz_lg2);
	pt_vaddr_t stop_pfn =
		cur_pfn + log2_div(len, priv->smallest_pgsz_lg2);

	while (cur_pfn != stop_pfn) {
		phys_addr_t res = iommu_iova_to_phys(
			&priv->domain, cur_pfn * priv->smallest_pgsz);

		KUNIT_ASSERT_EQ(test, res, (phys_addr_t)pa);
		/*
		 * The assertion compares against pa cast to phys_addr_t, this
		 * check uses pa uncast; stop at the first difference.
		 */
		if (res != pa)
			return;

		pa += priv->smallest_pgsz;
		cur_pfn++;
	}
}
103 
test_increase_level(struct kunit * test)104 static void test_increase_level(struct kunit *test)
105 {
106 	struct kunit_iommu_priv *priv = test->priv;
107 	struct pt_common *common = priv->common;
108 
109 	if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP))
110 		kunit_skip(test, "PT_FEAT_DYNAMIC_TOP not set for this format");
111 
112 	if (IS_32BIT)
113 		kunit_skip(test, "Unable to test on 32bit");
114 
115 	KUNIT_ASSERT_GT(test, common->max_vasz_lg2,
116 			pt_top_range(common).max_vasz_lg2);
117 
118 	/* Add every possible level to the max */
119 	while (common->max_vasz_lg2 != pt_top_range(common).max_vasz_lg2) {
120 		struct pt_range top_range = pt_top_range(common);
121 
122 		if (top_range.va == 0)
123 			do_map(test, top_range.last_va + 1, 0,
124 			       priv->smallest_pgsz);
125 		else
126 			do_map(test, top_range.va - priv->smallest_pgsz, 0,
127 			       priv->smallest_pgsz);
128 
129 		KUNIT_ASSERT_EQ(test, pt_top_range(common).top_level,
130 				top_range.top_level + 1);
131 		KUNIT_ASSERT_GE(test, common->max_vasz_lg2,
132 				pt_top_range(common).max_vasz_lg2);
133 	}
134 }
135 
test_map_simple(struct kunit * test)136 static void test_map_simple(struct kunit *test)
137 {
138 	struct kunit_iommu_priv *priv = test->priv;
139 	struct pt_range range = pt_top_range(priv->common);
140 	struct count_valids valids = {};
141 	pt_vaddr_t pgsize_bitmap = priv->safe_pgsize_bitmap;
142 	unsigned int pgsz_lg2;
143 	pt_vaddr_t cur_va;
144 
145 	/* Map every reported page size */
146 	cur_va = range.va + priv->smallest_pgsz * 256;
147 	for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) {
148 		pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, pgsz_lg2);
149 		u64 len = log2_to_int(pgsz_lg2);
150 
151 		if (!(pgsize_bitmap & len))
152 			continue;
153 
154 		cur_va = ALIGN(cur_va, len);
155 		do_map(test, cur_va, paddr, len);
156 		if (len <= SZ_2G)
157 			check_iova(test, cur_va, paddr, len);
158 		cur_va += len;
159 	}
160 
161 	/* The read interface reports that every page size was created */
162 	range = pt_top_range(priv->common);
163 	KUNIT_ASSERT_NO_ERRNO(test,
164 			      pt_walk_range(&range, __count_valids, &valids));
165 	for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) {
166 		if (pgsize_bitmap & (1ULL << pgsz_lg2))
167 			KUNIT_ASSERT_EQ(test, valids.per_size[pgsz_lg2], 1);
168 		else
169 			KUNIT_ASSERT_EQ(test, valids.per_size[pgsz_lg2], 0);
170 	}
171 
172 	/* Unmap works */
173 	range = pt_top_range(priv->common);
174 	cur_va = range.va + priv->smallest_pgsz * 256;
175 	for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) {
176 		u64 len = log2_to_int(pgsz_lg2);
177 
178 		if (!(pgsize_bitmap & len))
179 			continue;
180 		cur_va = ALIGN(cur_va, len);
181 		do_unmap(test, cur_va, len);
182 		cur_va += len;
183 	}
184 	KUNIT_ASSERT_EQ(test, count_valids(test), 0);
185 }
186 
187 /*
188  * Test to convert a table pointer into an OA by mapping something small,
189  * unmapping it so as to leave behind a table pointer, then mapping something
190  * larger that will convert the table into an OA.
191  */
test_map_table_to_oa(struct kunit * test)192 static void test_map_table_to_oa(struct kunit *test)
193 {
194 	struct kunit_iommu_priv *priv = test->priv;
195 	pt_vaddr_t limited_pgbitmap =
196 		priv->info.pgsize_bitmap % (IS_32BIT ? SZ_2G : SZ_16G);
197 	struct pt_range range = pt_top_range(priv->common);
198 	unsigned int pgsz_lg2;
199 	pt_vaddr_t max_pgsize;
200 	pt_vaddr_t cur_va;
201 
202 	max_pgsize = 1ULL << (vafls(limited_pgbitmap) - 1);
203 	KUNIT_ASSERT_TRUE(test, priv->info.pgsize_bitmap & max_pgsize);
204 
205 	for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) {
206 		pt_oaddr_t paddr = log2_set_mod(priv->test_oa, 0, pgsz_lg2);
207 		u64 len = log2_to_int(pgsz_lg2);
208 		pt_vaddr_t offset;
209 
210 		if (!(priv->info.pgsize_bitmap & len))
211 			continue;
212 		if (len > max_pgsize)
213 			break;
214 
215 		cur_va = ALIGN(range.va + priv->smallest_pgsz * 256,
216 			       max_pgsize);
217 		for (offset = 0; offset != max_pgsize; offset += len)
218 			do_map(test, cur_va + offset, paddr + offset, len);
219 		check_iova(test, cur_va, paddr, max_pgsize);
220 		KUNIT_ASSERT_EQ(test, count_valids_single(test, len),
221 				log2_div(max_pgsize, pgsz_lg2));
222 
223 		if (len == max_pgsize) {
224 			do_unmap(test, cur_va, max_pgsize);
225 		} else {
226 			do_unmap(test, cur_va, max_pgsize / 2);
227 			for (offset = max_pgsize / 2; offset != max_pgsize;
228 			     offset += len)
229 				do_unmap(test, cur_va + offset, len);
230 		}
231 
232 		KUNIT_ASSERT_EQ(test, count_valids(test), 0);
233 	}
234 }
235 
236 /*
237  * Test unmapping a small page at the start of a large page. This always unmaps
238  * the large page.
239  */
test_unmap_split(struct kunit * test)240 static void test_unmap_split(struct kunit *test)
241 {
242 	struct kunit_iommu_priv *priv = test->priv;
243 	struct pt_range top_range = pt_top_range(priv->common);
244 	pt_vaddr_t pgsize_bitmap = priv->safe_pgsize_bitmap;
245 	unsigned int pgsz_lg2;
246 	unsigned int count = 0;
247 
248 	for (pgsz_lg2 = 0; pgsz_lg2 != PT_VADDR_MAX_LG2; pgsz_lg2++) {
249 		pt_vaddr_t base_len = log2_to_int(pgsz_lg2);
250 		unsigned int next_pgsz_lg2;
251 
252 		if (!(pgsize_bitmap & base_len))
253 			continue;
254 
255 		for (next_pgsz_lg2 = pgsz_lg2 + 1;
256 		     next_pgsz_lg2 != PT_VADDR_MAX_LG2; next_pgsz_lg2++) {
257 			pt_vaddr_t next_len = log2_to_int(next_pgsz_lg2);
258 			pt_vaddr_t vaddr = top_range.va;
259 			pt_oaddr_t paddr = 0;
260 			size_t gnmapped;
261 
262 			if (!(pgsize_bitmap & next_len))
263 				continue;
264 
265 			do_map(test, vaddr, paddr, next_len);
266 			gnmapped = iommu_unmap(&priv->domain, vaddr, base_len);
267 			KUNIT_ASSERT_EQ(test, gnmapped, next_len);
268 
269 			/* Make sure unmap doesn't keep going */
270 			do_map(test, vaddr, paddr, next_len);
271 			do_map(test, vaddr + next_len, paddr, next_len);
272 			gnmapped = iommu_unmap(&priv->domain, vaddr, base_len);
273 			KUNIT_ASSERT_EQ(test, gnmapped, next_len);
274 			gnmapped = iommu_unmap(&priv->domain, vaddr + next_len,
275 					       next_len);
276 			KUNIT_ASSERT_EQ(test, gnmapped, next_len);
277 
278 			count++;
279 		}
280 	}
281 
282 	if (count == 0)
283 		kunit_skip(test, "Test needs two page sizes");
284 }
285 
/*
 * Remove every range recorded in @mt that intersects [@start, @last]: erase
 * it from the tree, check it still translates as expected, then unmap it from
 * the domain.
 *
 * The tree lock is dropped around check_iova()/do_unmap() and retaken before
 * the iteration continues; mas_pause() is called before each unlock.
 */
static void unmap_collisions(struct kunit *test, struct maple_tree *mt,
			     pt_vaddr_t start, pt_vaddr_t last)
{
	struct kunit_iommu_priv *priv = test->priv;
	MA_STATE(mas, mt, start, last);
	void *entry;

	mtree_lock(mt);
	mas_for_each(&mas, entry, last) {
		/* Snapshot the entry's bounds before it is erased */
		pt_vaddr_t first_va = mas.index;
		pt_vaddr_t span = mas.last - first_va + 1;
		/* Same VA-derived output address used when it was mapped */
		pt_oaddr_t oa =
			oalog2_mod(first_va, priv->common->max_oasz_lg2);

		mas_erase(&mas);
		mas_pause(&mas);
		mtree_unlock(mt);

		check_iova(test, first_va, oa, span);
		do_unmap(test, first_va, span);

		mtree_lock(mt);
	}
	mtree_unlock(mt);
}
310 
clamp_range(struct kunit * test,struct pt_range * range)311 static void clamp_range(struct kunit *test, struct pt_range *range)
312 {
313 	struct kunit_iommu_priv *priv = test->priv;
314 
315 	if (range->last_va - range->va > SZ_1G)
316 		range->last_va = range->va + SZ_1G;
317 	KUNIT_ASSERT_NE(test, range->last_va, PT_VADDR_MAX);
318 	if (range->va <= MAPLE_RESERVED_RANGE)
319 		range->va =
320 			ALIGN(MAPLE_RESERVED_RANGE, priv->smallest_pgsz);
321 }
322 
323 /*
324  * Randomly map and unmap ranges that can large physical pages. If a random
325  * range overlaps with existing ranges then unmap them. This hits all the
326  * special cases.
327  */
test_random_map(struct kunit * test)328 static void test_random_map(struct kunit *test)
329 {
330 	struct kunit_iommu_priv *priv = test->priv;
331 	struct pt_range upper_range = pt_upper_range(priv->common);
332 	struct pt_range top_range = pt_top_range(priv->common);
333 	struct maple_tree mt;
334 	unsigned int iter;
335 
336 	mt_init(&mt);
337 
338 	/*
339 	 * Shrink the range so randomization is more likely to have
340 	 * intersections
341 	 */
342 	clamp_range(test, &top_range);
343 	clamp_range(test, &upper_range);
344 
345 	for (iter = 0; iter != 1000; iter++) {
346 		struct pt_range *range = &top_range;
347 		pt_oaddr_t paddr;
348 		pt_vaddr_t start;
349 		pt_vaddr_t end;
350 		int ret;
351 
352 		if (pt_feature(priv->common, PT_FEAT_SIGN_EXTEND) &&
353 		    ULONG_MAX >= PT_VADDR_MAX && get_random_u32_inclusive(0, 1))
354 			range = &upper_range;
355 
356 		start = get_random_u32_below(
357 			min(U32_MAX, range->last_va - range->va));
358 		end = get_random_u32_below(
359 			min(U32_MAX, range->last_va - start));
360 
361 		start = ALIGN_DOWN(start, priv->smallest_pgsz);
362 		end = ALIGN(end, priv->smallest_pgsz);
363 		start += range->va;
364 		end += start;
365 		if (start < range->va || end > range->last_va + 1 ||
366 		    start >= end)
367 			continue;
368 
369 		/* Try overmapping to test the failure handling */
370 		paddr = oalog2_mod(start, priv->common->max_oasz_lg2);
371 		ret = iommu_map(&priv->domain, start, paddr, end - start,
372 				IOMMU_READ | IOMMU_WRITE, GFP_KERNEL);
373 		if (ret) {
374 			KUNIT_ASSERT_EQ(test, ret, -EADDRINUSE);
375 			unmap_collisions(test, &mt, start, end - 1);
376 			do_map(test, start, paddr, end - start);
377 		}
378 
379 		KUNIT_ASSERT_NO_ERRNO_FN(test, "mtree_insert_range",
380 					 mtree_insert_range(&mt, start, end - 1,
381 							    XA_ZERO_ENTRY,
382 							    GFP_KERNEL));
383 
384 		check_iova(test, start, paddr, end - start);
385 		if (iter % 100)
386 			cond_resched();
387 	}
388 
389 	unmap_collisions(test, &mt, 0, PT_VADDR_MAX);
390 	KUNIT_ASSERT_EQ(test, count_valids(test), 0);
391 
392 	mtree_destroy(&mt);
393 }
394 
395 /* See https://lore.kernel.org/r/b9b18a03-63a2-4065-a27e-d92dd5c860bc@amd.com */
test_pgsize_boundary(struct kunit * test)396 static void test_pgsize_boundary(struct kunit *test)
397 {
398 	struct kunit_iommu_priv *priv = test->priv;
399 	struct pt_range top_range = pt_top_range(priv->common);
400 
401 	if (top_range.va != 0 || top_range.last_va < 0xfef9ffff ||
402 	    priv->smallest_pgsz != SZ_4K)
403 		kunit_skip(test, "Format does not have the required range");
404 
405 	do_map(test, 0xfef80000, 0x208b95d000, 0xfef9ffff - 0xfef80000 + 1);
406 }
407 
408 /* See https://lore.kernel.org/r/20250826143816.38686-1-eugkoira@amazon.com */
test_mixed(struct kunit * test)409 static void test_mixed(struct kunit *test)
410 {
411 	struct kunit_iommu_priv *priv = test->priv;
412 	struct pt_range top_range = pt_top_range(priv->common);
413 	u64 start = 0x3fe400ULL << 12;
414 	u64 end = 0x4c0600ULL << 12;
415 	pt_vaddr_t len = end - start;
416 	pt_oaddr_t oa = start;
417 
418 	if (top_range.last_va <= start || sizeof(unsigned long) == 4)
419 		kunit_skip(test, "range is too small");
420 	if ((priv->safe_pgsize_bitmap & GENMASK(30, 21)) != (BIT(30) | BIT(21)))
421 		kunit_skip(test, "incompatible psize");
422 
423 	do_map(test, start, oa, len);
424 	/* 14 2M, 3 1G, 3 2M */
425 	KUNIT_ASSERT_EQ(test, count_valids(test), 20);
426 	check_iova(test, start, oa, len);
427 }
428 
429 static struct kunit_case iommu_test_cases[] = {
430 	KUNIT_CASE_FMT(test_increase_level),
431 	KUNIT_CASE_FMT(test_map_simple),
432 	KUNIT_CASE_FMT(test_map_table_to_oa),
433 	KUNIT_CASE_FMT(test_unmap_split),
434 	KUNIT_CASE_FMT(test_random_map),
435 	KUNIT_CASE_FMT(test_pgsize_boundary),
436 	KUNIT_CASE_FMT(test_mixed),
437 	{},
438 };
439 
pt_kunit_iommu_init(struct kunit * test)440 static int pt_kunit_iommu_init(struct kunit *test)
441 {
442 	struct kunit_iommu_priv *priv;
443 	int ret;
444 
445 	priv = kunit_kzalloc(test, sizeof(*priv), GFP_KERNEL);
446 	if (!priv)
447 		return -ENOMEM;
448 
449 	priv->orig_nr_secondary_pagetable =
450 		global_node_page_state(NR_SECONDARY_PAGETABLE);
451 	ret = pt_kunit_priv_init(test, priv);
452 	if (ret) {
453 		kunit_kfree(test, priv);
454 		return ret;
455 	}
456 	test->priv = priv;
457 	return 0;
458 }
459 
pt_kunit_iommu_exit(struct kunit * test)460 static void pt_kunit_iommu_exit(struct kunit *test)
461 {
462 	struct kunit_iommu_priv *priv = test->priv;
463 
464 	if (!test->priv)
465 		return;
466 
467 	pt_iommu_deinit(priv->iommu);
468 	/*
469 	 * Look for memory leaks, assumes kunit is running isolated and nothing
470 	 * else is using secondary page tables.
471 	 */
472 	KUNIT_ASSERT_EQ(test, priv->orig_nr_secondary_pagetable,
473 			global_node_page_state(NR_SECONDARY_PAGETABLE));
474 	kunit_kfree(test, test->priv);
475 }
476 
477 static struct kunit_suite NS(iommu_suite) = {
478 	.name = __stringify(NS(iommu_test)),
479 	.init = pt_kunit_iommu_init,
480 	.exit = pt_kunit_iommu_exit,
481 	.test_cases = iommu_test_cases,
482 };
483 kunit_test_suites(&NS(iommu_suite));
484 
485 MODULE_LICENSE("GPL");
486 MODULE_DESCRIPTION("Kunit for generic page table");
487 MODULE_IMPORT_NS("GENERIC_PT_IOMMU");
488