xref: /linux/drivers/iommu/iommufd/selftest.c (revision 4b660dbd9ee2059850fd30e0df420ca7a38a1856)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * Kernel side components to support tools/testing/selftests/iommu
5  */
6 #include <linux/slab.h>
7 #include <linux/iommu.h>
8 #include <linux/xarray.h>
9 #include <linux/file.h>
10 #include <linux/anon_inodes.h>
11 #include <linux/fault-inject.h>
12 #include <linux/platform_device.h>
13 #include <uapi/linux/iommufd.h>
14 
15 #include "../iommu-priv.h"
16 #include "io_pagetable.h"
17 #include "iommufd_private.h"
18 #include "iommufd_test.h"
19 
/*
 * Fault-injection attribute for the selftest. NOTE(review): presumably this is
 * what iommufd_should_fail() consults; its definition is outside this view.
 */
static DECLARE_FAULT_ATTR(fail_iommufd);
/* Presumably the selftest's debugfs directory; set up outside this view */
static struct dentry *dbgfs_root;
/* Platform device backing the mock IOMMU; set up outside this view */
static struct platform_device *selftest_iommu_dev;
/* Forward declarations: both tables are defined further down in this file */
static const struct iommu_ops mock_ops;
static struct iommu_domain_ops domain_nested_ops;

/*
 * Non-static, so it is visible to other iommufd objects. NOTE(review): the
 * unit and the consumer of this limit are not visible here - confirm.
 */
size_t iommufd_test_memory_limit = 65536;
27 
/*
 * Private bus type for the mock devices. NOTE(review): the use of @nb is not
 * visible in this part of the file.
 */
struct mock_bus_type {
	struct bus_type bus;
	struct notifier_block nb;
};

/* The mock bus; mock_probe_device() only accepts devices sitting on .bus */
static struct mock_bus_type iommufd_mock_bus_type = {
	.bus = {
		.name = "iommufd_mock",
	},
};

/* Allocator for mock_dev::id, used in the "iommufd_mock%u" device name */
static DEFINE_IDA(mock_dev_ida);
40 
/*
 * Constants for the mock IOMMU domains.
 *
 * MOCK_DIRTY_TRACK is a bit in mock_iommu_domain::flags. The MOCK_PFN_* values
 * describe each xarray entry in mock_iommu_domain::pfns: the low bits (up to
 * MOCK_PFN_MASK) hold the physical address divided by MOCK_IO_PAGE_SIZE, and
 * the bits above MOCK_PFN_MASK carry per-entry metadata.
 */
enum {
	/* Set in mock_iommu_domain::flags while dirty tracking is enabled */
	MOCK_DIRTY_TRACK = 1,
	/* Mock IO page size: half of a CPU page */
	MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,
	/* Extra page size offered to MOCK_FLAGS_DEVICE_HUGE_IOVA devices */
	MOCK_HUGE_PAGE_SIZE = 512 * MOCK_IO_PAGE_SIZE,

	/*
	 * Like a real page table alignment requires the low bits of the address
	 * to be zero. xarray also requires the high bit to be zero, so we store
	 * the pfns shifted. The upper bits are used for metadata.
	 */
	MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,

	/* First metadata bit, directly above the largest storable pfn */
	_MOCK_PFN_START = MOCK_PFN_MASK + 1,
	/* START and LAST share the same bit: "boundary of a map_pages call" */
	MOCK_PFN_START_IOVA = _MOCK_PFN_START,
	MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
	/* Entry is dirty; tested/cleared by mock_test_and_clear_dirty() */
	MOCK_PFN_DIRTY_IOVA = _MOCK_PFN_START << 1,
	/* Entry belongs to a page mapped with a size other than the IO page */
	MOCK_PFN_HUGE_IOVA = _MOCK_PFN_START << 2,
};
59 
60 /*
61  * Syzkaller has trouble randomizing the correct iova to use since it is linked
62  * to the map ioctl's output, and it has no ide about that. So, simplify things.
63  * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
64  * value. This has a much smaller randomization space and syzkaller can hit it.
65  */
66 static unsigned long __iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
67 						  u64 *iova)
68 {
69 	struct syz_layout {
70 		__u32 nth_area;
71 		__u32 offset;
72 	};
73 	struct syz_layout *syz = (void *)iova;
74 	unsigned int nth = syz->nth_area;
75 	struct iopt_area *area;
76 
77 	down_read(&iopt->iova_rwsem);
78 	for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
79 	     area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
80 		if (nth == 0) {
81 			up_read(&iopt->iova_rwsem);
82 			return iopt_area_iova(area) + syz->offset;
83 		}
84 		nth--;
85 	}
86 	up_read(&iopt->iova_rwsem);
87 
88 	return 0;
89 }
90 
91 static unsigned long iommufd_test_syz_conv_iova(struct iommufd_access *access,
92 						u64 *iova)
93 {
94 	unsigned long ret;
95 
96 	mutex_lock(&access->ioas_lock);
97 	if (!access->ioas) {
98 		mutex_unlock(&access->ioas_lock);
99 		return 0;
100 	}
101 	ret = __iommufd_test_syz_conv_iova(&access->ioas->iopt, iova);
102 	mutex_unlock(&access->ioas_lock);
103 	return ret;
104 }
105 
106 void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
107 				   unsigned int ioas_id, u64 *iova, u32 *flags)
108 {
109 	struct iommufd_ioas *ioas;
110 
111 	if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
112 		return;
113 	*flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;
114 
115 	ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
116 	if (IS_ERR(ioas))
117 		return;
118 	*iova = __iommufd_test_syz_conv_iova(&ioas->iopt, iova);
119 	iommufd_put_object(ucmd->ictx, &ioas->obj);
120 }
121 
/* A mock paging domain; the embedded iommu_domain is what the core sees */
struct mock_iommu_domain {
	/* MOCK_DIRTY_TRACK while dirty tracking is enabled */
	unsigned long flags;
	struct iommu_domain domain;
	/* Index iova / MOCK_IO_PAGE_SIZE -> xa_mk_value(pfn | MOCK_PFN_* bits) */
	struct xarray pfns;
};
127 
/* A mock nested domain allocated on top of a mock paging parent */
struct mock_iommu_domain_nested {
	struct iommu_domain domain;
	struct mock_iommu_domain *parent;
	/* Fake IOTLB: seeded from user data at alloc, zeroed by invalidation */
	u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
};
133 
/* Discriminator for the union inside struct selftest_obj */
enum selftest_obj_type {
	TYPE_IDEV,
};
137 
/* A mock device created on iommufd_mock_bus_type by mock_dev_create() */
struct mock_dev {
	struct device dev;
	/* MOCK_FLAGS_DEVICE_* bits requested at creation */
	unsigned long flags;
	/* Allocated from mock_dev_ida; appears in the device name */
	int id;
};
143 
/* IOMMUFD_OBJ_SELFTEST object; for TYPE_IDEV it owns a bound mock device */
struct selftest_obj {
	struct iommufd_object obj;
	enum selftest_obj_type type;

	union {
		/* Valid when type == TYPE_IDEV */
		struct {
			struct iommufd_device *idev;
			struct iommufd_ctx *ictx;
			struct mock_dev *mock_dev;
		} idev;
	};
};
156 
157 static int mock_domain_nop_attach(struct iommu_domain *domain,
158 				  struct device *dev)
159 {
160 	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
161 
162 	if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
163 		return -EINVAL;
164 
165 	return 0;
166 }
167 
/* Ops of the blocked domain: attach only runs the common mock check */
static const struct iommu_domain_ops mock_blocking_ops = {
	.attach_dev = mock_domain_nop_attach,
};

/*
 * Statically allocated blocked domain, used by mock_ops as both the default
 * and the blocked domain.
 */
static struct iommu_domain mock_blocking_domain = {
	.type = IOMMU_DOMAIN_BLOCKED,
	.ops = &mock_blocking_ops,
};
176 
177 static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
178 {
179 	struct iommu_test_hw_info *info;
180 
181 	info = kzalloc(sizeof(*info), GFP_KERNEL);
182 	if (!info)
183 		return ERR_PTR(-ENOMEM);
184 
185 	info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
186 	*length = sizeof(*info);
187 	*type = IOMMU_HW_INFO_TYPE_SELFTEST;
188 
189 	return info;
190 }
191 
192 static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
193 					  bool enable)
194 {
195 	struct mock_iommu_domain *mock =
196 		container_of(domain, struct mock_iommu_domain, domain);
197 	unsigned long flags = mock->flags;
198 
199 	if (enable && !domain->dirty_ops)
200 		return -EINVAL;
201 
202 	/* No change? */
203 	if (!(enable ^ !!(flags & MOCK_DIRTY_TRACK)))
204 		return 0;
205 
206 	flags = (enable ? flags | MOCK_DIRTY_TRACK : flags & ~MOCK_DIRTY_TRACK);
207 
208 	mock->flags = flags;
209 	return 0;
210 }
211 
212 static bool mock_test_and_clear_dirty(struct mock_iommu_domain *mock,
213 				      unsigned long iova, size_t page_size,
214 				      unsigned long flags)
215 {
216 	unsigned long cur, end = iova + page_size - 1;
217 	bool dirty = false;
218 	void *ent, *old;
219 
220 	for (cur = iova; cur < end; cur += MOCK_IO_PAGE_SIZE) {
221 		ent = xa_load(&mock->pfns, cur / MOCK_IO_PAGE_SIZE);
222 		if (!ent || !(xa_to_value(ent) & MOCK_PFN_DIRTY_IOVA))
223 			continue;
224 
225 		dirty = true;
226 		/* Clear dirty */
227 		if (!(flags & IOMMU_DIRTY_NO_CLEAR)) {
228 			unsigned long val;
229 
230 			val = xa_to_value(ent) & ~MOCK_PFN_DIRTY_IOVA;
231 			old = xa_store(&mock->pfns, cur / MOCK_IO_PAGE_SIZE,
232 				       xa_mk_value(val), GFP_KERNEL);
233 			WARN_ON_ONCE(ent != old);
234 		}
235 	}
236 
237 	return dirty;
238 }
239 
/*
 * read_and_clear_dirty op: walk [iova, iova + size) and record in @dirty every
 * mapped IO page, or whole huge page, that mock_test_and_clear_dirty() reports
 * as dirty. That helper also clears the dirty bit unless @flags has
 * IOMMU_DIRTY_NO_CLEAR.
 *
 * Returns -EINVAL if a bitmap is supplied while tracking is disabled, else 0.
 */
static int mock_domain_read_and_clear_dirty(struct iommu_domain *domain,
					    unsigned long iova, size_t size,
					    unsigned long flags,
					    struct iommu_dirty_bitmap *dirty)
{
	struct mock_iommu_domain *mock =
		container_of(domain, struct mock_iommu_domain, domain);
	unsigned long end = iova + size;
	void *ent;

	if (!(mock->flags & MOCK_DIRTY_TRACK) && dirty->bitmap)
		return -EINVAL;

	/* do/while: at least one IO page is always examined */
	do {
		unsigned long pgsize = MOCK_IO_PAGE_SIZE;
		unsigned long head;

		ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
		if (!ent) {
			/* Hole: step one IO page; continue re-tests iova < end */
			iova += pgsize;
			continue;
		}

		/* Huge mappings are scanned and reported from their head */
		if (xa_to_value(ent) & MOCK_PFN_HUGE_IOVA)
			pgsize = MOCK_HUGE_PAGE_SIZE;
		head = iova & ~(pgsize - 1);

		/* Clear dirty */
		if (mock_test_and_clear_dirty(mock, head, pgsize, flags))
			iommu_dirty_bitmap_record(dirty, head, pgsize);
		iova = head + pgsize;
	} while (iova < end);

	return 0;
}
275 
276 const struct iommu_dirty_ops dirty_ops = {
277 	.set_dirty_tracking = mock_domain_set_dirty_tracking,
278 	.read_and_clear_dirty = mock_domain_read_and_clear_dirty,
279 };
280 
281 static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
282 {
283 	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
284 	struct mock_iommu_domain *mock;
285 
286 	mock = kzalloc(sizeof(*mock), GFP_KERNEL);
287 	if (!mock)
288 		return NULL;
289 	mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
290 	mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
291 	mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
292 	if (dev && mdev->flags & MOCK_FLAGS_DEVICE_HUGE_IOVA)
293 		mock->domain.pgsize_bitmap |= MOCK_HUGE_PAGE_SIZE;
294 	mock->domain.ops = mock_ops.default_domain_ops;
295 	mock->domain.type = IOMMU_DOMAIN_UNMANAGED;
296 	xa_init(&mock->pfns);
297 	return &mock->domain;
298 }
299 
300 static struct iommu_domain *
301 __mock_domain_alloc_nested(struct mock_iommu_domain *mock_parent,
302 			   const struct iommu_hwpt_selftest *user_cfg)
303 {
304 	struct mock_iommu_domain_nested *mock_nested;
305 	int i;
306 
307 	mock_nested = kzalloc(sizeof(*mock_nested), GFP_KERNEL);
308 	if (!mock_nested)
309 		return ERR_PTR(-ENOMEM);
310 	mock_nested->parent = mock_parent;
311 	mock_nested->domain.ops = &domain_nested_ops;
312 	mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
313 	for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
314 		mock_nested->iotlb[i] = user_cfg->iotlb;
315 	return &mock_nested->domain;
316 }
317 
318 static struct iommu_domain *
319 mock_domain_alloc_user(struct device *dev, u32 flags,
320 		       struct iommu_domain *parent,
321 		       const struct iommu_user_data *user_data)
322 {
323 	struct mock_iommu_domain *mock_parent;
324 	struct iommu_hwpt_selftest user_cfg;
325 	int rc;
326 
327 	/* must be mock_domain */
328 	if (!parent) {
329 		struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
330 		bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
331 		bool no_dirty_ops = mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY;
332 		struct iommu_domain *domain;
333 
334 		if (flags & (~(IOMMU_HWPT_ALLOC_NEST_PARENT |
335 			       IOMMU_HWPT_ALLOC_DIRTY_TRACKING)))
336 			return ERR_PTR(-EOPNOTSUPP);
337 		if (user_data || (has_dirty_flag && no_dirty_ops))
338 			return ERR_PTR(-EOPNOTSUPP);
339 		domain = mock_domain_alloc_paging(dev);
340 		if (!domain)
341 			return ERR_PTR(-ENOMEM);
342 		if (has_dirty_flag)
343 			container_of(domain, struct mock_iommu_domain, domain)
344 				->domain.dirty_ops = &dirty_ops;
345 		return domain;
346 	}
347 
348 	/* must be mock_domain_nested */
349 	if (user_data->type != IOMMU_HWPT_DATA_SELFTEST || flags)
350 		return ERR_PTR(-EOPNOTSUPP);
351 	if (!parent || parent->ops != mock_ops.default_domain_ops)
352 		return ERR_PTR(-EINVAL);
353 
354 	mock_parent = container_of(parent, struct mock_iommu_domain, domain);
355 	if (!mock_parent)
356 		return ERR_PTR(-EINVAL);
357 
358 	rc = iommu_copy_struct_from_user(&user_cfg, user_data,
359 					 IOMMU_HWPT_DATA_SELFTEST, iotlb);
360 	if (rc)
361 		return ERR_PTR(rc);
362 
363 	return __mock_domain_alloc_nested(mock_parent, &user_cfg);
364 }
365 
366 static void mock_domain_free(struct iommu_domain *domain)
367 {
368 	struct mock_iommu_domain *mock =
369 		container_of(domain, struct mock_iommu_domain, domain);
370 
371 	WARN_ON(!xa_empty(&mock->pfns));
372 	kfree(mock);
373 }
374 
/*
 * map_pages op: map @pgcount pages of @pgsize bytes from @iova to @paddr,
 * storing one xarray entry per MOCK_IO_PAGE_SIZE chunk. Each entry is the
 * physical address divided by MOCK_IO_PAGE_SIZE, ORed with metadata:
 *  - MOCK_PFN_START_IOVA on the first chunk;
 *  - MOCK_PFN_LAST_IOVA on the final chunk;
 *  - MOCK_PFN_HUGE_IOVA on every chunk when @pgsize is not MOCK_IO_PAGE_SIZE.
 *
 * *@mapped grows by MOCK_IO_PAGE_SIZE per stored chunk. If xa_store() fails,
 * the chunks stored by this call are erased again and the xarray error is
 * returned; *@mapped is not reduced. Returns -ENOENT when the selftest failure
 * point fires, 0 otherwise.
 */
static int mock_domain_map_pages(struct iommu_domain *domain,
				 unsigned long iova, phys_addr_t paddr,
				 size_t pgsize, size_t pgcount, int prot,
				 gfp_t gfp, size_t *mapped)
{
	struct mock_iommu_domain *mock =
		container_of(domain, struct mock_iommu_domain, domain);
	unsigned long flags = MOCK_PFN_START_IOVA;
	unsigned long start_iova = iova;

	/*
	 * xarray does not reliably work with fault injection because it does a
	 * retry allocation, so put our own failure point.
	 */
	if (iommufd_should_fail())
		return -ENOENT;

	WARN_ON(iova % MOCK_IO_PAGE_SIZE);
	WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
	for (; pgcount; pgcount--) {
		size_t cur;

		for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
			void *old;

			/*
			 * Final chunk of the whole request. START and LAST are
			 * the same bit, so overwriting START here is harmless
			 * for a single-chunk map.
			 */
			if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
				flags = MOCK_PFN_LAST_IOVA;
			if (pgsize != MOCK_IO_PAGE_SIZE) {
				flags |= MOCK_PFN_HUGE_IOVA;
			}
			old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
				       xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
						   flags),
				       gfp);
			if (xa_is_err(old)) {
				/* Roll back every chunk this call stored */
				for (; start_iova != iova;
				     start_iova += MOCK_IO_PAGE_SIZE)
					xa_erase(&mock->pfns,
						 start_iova /
							 MOCK_IO_PAGE_SIZE);
				return xa_err(old);
			}
			/* The IOVA must not already have been mapped */
			WARN_ON(old);
			iova += MOCK_IO_PAGE_SIZE;
			paddr += MOCK_IO_PAGE_SIZE;
			*mapped += MOCK_IO_PAGE_SIZE;
			/* Only the very first chunk keeps the START marker */
			flags = 0;
		}
	}
	return 0;
}
426 
/*
 * unmap_pages op: erase @pgcount pages of @pgsize bytes starting at @iova.
 * Returns the number of bytes walked (MOCK_IO_PAGE_SIZE per chunk), whether
 * or not each chunk was actually mapped.
 *
 * Unless the domain offers MOCK_HUGE_PAGE_SIZE, this WARNs when the first or
 * last erased entry lacks the START/LAST marker set by mock_domain_map_pages().
 */
static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
				      unsigned long iova, size_t pgsize,
				      size_t pgcount,
				      struct iommu_iotlb_gather *iotlb_gather)
{
	struct mock_iommu_domain *mock =
		container_of(domain, struct mock_iommu_domain, domain);
	bool first = true;
	size_t ret = 0;
	void *ent;

	WARN_ON(iova % MOCK_IO_PAGE_SIZE);
	WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);

	for (; pgcount; pgcount--) {
		size_t cur;

		for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
			ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);

			/*
			 * iommufd generates unmaps that must be a strict
			 * superset of the maps performed. So every
			 * starting/ending IOVA should have been an iova passed
			 * to map.
			 *
			 * This simple logic doesn't work when the HUGE_PAGE is
			 * turned on since the core code will automatically
			 * switch between the two page sizes creating a break in
			 * the unmap calls. The break can land in the middle of
			 * contiguous IOVA.
			 */
			if (!(domain->pgsize_bitmap & MOCK_HUGE_PAGE_SIZE)) {
				if (first) {
					WARN_ON(ent && !(xa_to_value(ent) &
							 MOCK_PFN_START_IOVA));
					first = false;
				}
				if (pgcount == 1 &&
				    cur + MOCK_IO_PAGE_SIZE == pgsize)
					WARN_ON(ent && !(xa_to_value(ent) &
							 MOCK_PFN_LAST_IOVA));
			}

			iova += MOCK_IO_PAGE_SIZE;
			ret += MOCK_IO_PAGE_SIZE;
		}
	}
	return ret;
}
477 
478 static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
479 					    dma_addr_t iova)
480 {
481 	struct mock_iommu_domain *mock =
482 		container_of(domain, struct mock_iommu_domain, domain);
483 	void *ent;
484 
485 	WARN_ON(iova % MOCK_IO_PAGE_SIZE);
486 	ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
487 	WARN_ON(!ent);
488 	return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
489 }
490 
491 static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
492 {
493 	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
494 
495 	switch (cap) {
496 	case IOMMU_CAP_CACHE_COHERENCY:
497 		return true;
498 	case IOMMU_CAP_DIRTY_TRACKING:
499 		return !(mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY);
500 	default:
501 		break;
502 	}
503 
504 	return false;
505 }
506 
/* The single mock IOMMU instance returned by mock_probe_device() */
static struct iommu_device mock_iommu_device = {
};
509 
510 static struct iommu_device *mock_probe_device(struct device *dev)
511 {
512 	if (dev->bus != &iommufd_mock_bus_type.bus)
513 		return ERR_PTR(-ENODEV);
514 	return &mock_iommu_device;
515 }
516 
/* Ops table of the mock IOMMU driver */
static const struct iommu_ops mock_ops = {
	/*
	 * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
	 * because it is zero.
	 */
	.default_domain = &mock_blocking_domain,
	.blocked_domain = &mock_blocking_domain,
	.owner = THIS_MODULE,
	.pgsize_bitmap = MOCK_IO_PAGE_SIZE,
	.hw_info = mock_domain_hw_info,
	.domain_alloc_paging = mock_domain_alloc_paging,
	.domain_alloc_user = mock_domain_alloc_user,
	.capable = mock_domain_capable,
	.device_group = generic_device_group,
	.probe_device = mock_probe_device,
	/*
	 * Ops of mock paging domains. Code in this file recognises a mock
	 * paging domain by comparing domain->ops against this pointer.
	 */
	.default_domain_ops =
		&(struct iommu_domain_ops){
			.free = mock_domain_free,
			.attach_dev = mock_domain_nop_attach,
			.map_pages = mock_domain_map_pages,
			.unmap_pages = mock_domain_unmap_pages,
			.iova_to_phys = mock_domain_iova_to_phys,
		},
};
541 
542 static void mock_domain_free_nested(struct iommu_domain *domain)
543 {
544 	struct mock_iommu_domain_nested *mock_nested =
545 		container_of(domain, struct mock_iommu_domain_nested, domain);
546 
547 	kfree(mock_nested);
548 }
549 
550 static int
551 mock_domain_cache_invalidate_user(struct iommu_domain *domain,
552 				  struct iommu_user_data_array *array)
553 {
554 	struct mock_iommu_domain_nested *mock_nested =
555 		container_of(domain, struct mock_iommu_domain_nested, domain);
556 	struct iommu_hwpt_invalidate_selftest inv;
557 	u32 processed = 0;
558 	int i = 0, j;
559 	int rc = 0;
560 
561 	if (array->type != IOMMU_HWPT_INVALIDATE_DATA_SELFTEST) {
562 		rc = -EINVAL;
563 		goto out;
564 	}
565 
566 	for ( ; i < array->entry_num; i++) {
567 		rc = iommu_copy_struct_from_user_array(&inv, array,
568 						       IOMMU_HWPT_INVALIDATE_DATA_SELFTEST,
569 						       i, iotlb_id);
570 		if (rc)
571 			break;
572 
573 		if (inv.flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
574 			rc = -EOPNOTSUPP;
575 			break;
576 		}
577 
578 		if (inv.iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX) {
579 			rc = -EINVAL;
580 			break;
581 		}
582 
583 		if (inv.flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
584 			/* Invalidate all mock iotlb entries and ignore iotlb_id */
585 			for (j = 0; j < MOCK_NESTED_DOMAIN_IOTLB_NUM; j++)
586 				mock_nested->iotlb[j] = 0;
587 		} else {
588 			mock_nested->iotlb[inv.iotlb_id] = 0;
589 		}
590 
591 		processed++;
592 	}
593 
594 out:
595 	array->entry_num = processed;
596 	return rc;
597 }
598 
/*
 * Ops of nested mock domains; get_md_pagetable_nested() identifies a nested
 * mock domain by comparing domain->ops against this table.
 */
static struct iommu_domain_ops domain_nested_ops = {
	.free = mock_domain_free_nested,
	.attach_dev = mock_domain_nop_attach,
	.cache_invalidate_user = mock_domain_cache_invalidate_user,
};
604 
605 static inline struct iommufd_hw_pagetable *
606 __get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, u32 hwpt_type)
607 {
608 	struct iommufd_object *obj;
609 
610 	obj = iommufd_get_object(ucmd->ictx, mockpt_id, hwpt_type);
611 	if (IS_ERR(obj))
612 		return ERR_CAST(obj);
613 	return container_of(obj, struct iommufd_hw_pagetable, obj);
614 }
615 
616 static inline struct iommufd_hw_pagetable *
617 get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
618 		 struct mock_iommu_domain **mock)
619 {
620 	struct iommufd_hw_pagetable *hwpt;
621 
622 	hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_PAGING);
623 	if (IS_ERR(hwpt))
624 		return hwpt;
625 	if (hwpt->domain->type != IOMMU_DOMAIN_UNMANAGED ||
626 	    hwpt->domain->ops != mock_ops.default_domain_ops) {
627 		iommufd_put_object(ucmd->ictx, &hwpt->obj);
628 		return ERR_PTR(-EINVAL);
629 	}
630 	*mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
631 	return hwpt;
632 }
633 
634 static inline struct iommufd_hw_pagetable *
635 get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
636 			struct mock_iommu_domain_nested **mock_nested)
637 {
638 	struct iommufd_hw_pagetable *hwpt;
639 
640 	hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_NESTED);
641 	if (IS_ERR(hwpt))
642 		return hwpt;
643 	if (hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
644 	    hwpt->domain->ops != &domain_nested_ops) {
645 		iommufd_put_object(ucmd->ictx, &hwpt->obj);
646 		return ERR_PTR(-EINVAL);
647 	}
648 	*mock_nested = container_of(hwpt->domain,
649 				    struct mock_iommu_domain_nested, domain);
650 	return hwpt;
651 }
652 
653 static void mock_dev_release(struct device *dev)
654 {
655 	struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
656 
657 	ida_free(&mock_dev_ida, mdev->id);
658 	kfree(mdev);
659 }
660 
661 static struct mock_dev *mock_dev_create(unsigned long dev_flags)
662 {
663 	struct mock_dev *mdev;
664 	int rc;
665 
666 	if (dev_flags &
667 	    ~(MOCK_FLAGS_DEVICE_NO_DIRTY | MOCK_FLAGS_DEVICE_HUGE_IOVA))
668 		return ERR_PTR(-EINVAL);
669 
670 	mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
671 	if (!mdev)
672 		return ERR_PTR(-ENOMEM);
673 
674 	device_initialize(&mdev->dev);
675 	mdev->flags = dev_flags;
676 	mdev->dev.release = mock_dev_release;
677 	mdev->dev.bus = &iommufd_mock_bus_type.bus;
678 
679 	rc = ida_alloc(&mock_dev_ida, GFP_KERNEL);
680 	if (rc < 0)
681 		goto err_put;
682 	mdev->id = rc;
683 
684 	rc = dev_set_name(&mdev->dev, "iommufd_mock%u", mdev->id);
685 	if (rc)
686 		goto err_put;
687 
688 	rc = device_add(&mdev->dev);
689 	if (rc)
690 		goto err_put;
691 	return mdev;
692 
693 err_put:
694 	put_device(&mdev->dev);
695 	return ERR_PTR(rc);
696 }
697 
698 static void mock_dev_destroy(struct mock_dev *mdev)
699 {
700 	device_unregister(&mdev->dev);
701 }
702 
703 bool iommufd_selftest_is_mock_dev(struct device *dev)
704 {
705 	return dev->release == mock_dev_release;
706 }
707 
/*
 * Create an hw_pagetable with the mock domain so we can test the domain ops.
 *
 * Allocates an IOMMUFD_OBJ_SELFTEST object and creates a mock device. For
 * IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS the device uses
 * cmd->mock_domain_flags.dev_flags. The device is bound to the iommufd
 * context and attached to cmd->id. On success the HWPT id, selftest object id
 * and idev id are returned via cmd->mock_domain. The error labels unwind
 * setup in exact reverse order.
 */
static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
				    struct iommu_test_cmd *cmd)
{
	struct iommufd_device *idev;
	struct selftest_obj *sobj;
	u32 pt_id = cmd->id;
	u32 dev_flags = 0;
	u32 idev_id;
	int rc;

	sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
	if (IS_ERR(sobj))
		return PTR_ERR(sobj);

	sobj->idev.ictx = ucmd->ictx;
	sobj->type = TYPE_IDEV;

	if (cmd->op == IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS)
		dev_flags = cmd->mock_domain_flags.dev_flags;

	sobj->idev.mock_dev = mock_dev_create(dev_flags);
	if (IS_ERR(sobj->idev.mock_dev)) {
		rc = PTR_ERR(sobj->idev.mock_dev);
		goto out_sobj;
	}

	idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
				   &idev_id);
	if (IS_ERR(idev)) {
		rc = PTR_ERR(idev);
		goto out_mdev;
	}
	sobj->idev.idev = idev;

	/* pt_id is in/out: it receives the id of the HWPT actually attached */
	rc = iommufd_device_attach(idev, &pt_id);
	if (rc)
		goto out_unbind;

	/* Userspace must destroy the device_id to destroy the object */
	cmd->mock_domain.out_hwpt_id = pt_id;
	cmd->mock_domain.out_stdev_id = sobj->obj.id;
	cmd->mock_domain.out_idev_id = idev_id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_detach;
	iommufd_object_finalize(ucmd->ictx, &sobj->obj);
	return 0;

out_detach:
	iommufd_device_detach(idev);
out_unbind:
	iommufd_device_unbind(idev);
out_mdev:
	mock_dev_destroy(sobj->idev.mock_dev);
out_sobj:
	iommufd_object_abort(ucmd->ictx, &sobj->obj);
	return rc;
}
767 
768 /* Replace the mock domain with a manually allocated hw_pagetable */
769 static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
770 					    unsigned int device_id, u32 pt_id,
771 					    struct iommu_test_cmd *cmd)
772 {
773 	struct iommufd_object *dev_obj;
774 	struct selftest_obj *sobj;
775 	int rc;
776 
777 	/*
778 	 * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
779 	 * it doesn't race with detach, which is not allowed.
780 	 */
781 	dev_obj =
782 		iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
783 	if (IS_ERR(dev_obj))
784 		return PTR_ERR(dev_obj);
785 
786 	sobj = container_of(dev_obj, struct selftest_obj, obj);
787 	if (sobj->type != TYPE_IDEV) {
788 		rc = -EINVAL;
789 		goto out_dev_obj;
790 	}
791 
792 	rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
793 	if (rc)
794 		goto out_dev_obj;
795 
796 	cmd->mock_domain_replace.pt_id = pt_id;
797 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
798 
799 out_dev_obj:
800 	iommufd_put_object(ucmd->ictx, dev_obj);
801 	return rc;
802 }
803 
804 /* Add an additional reserved IOVA to the IOAS */
805 static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
806 				     unsigned int mockpt_id,
807 				     unsigned long start, size_t length)
808 {
809 	struct iommufd_ioas *ioas;
810 	int rc;
811 
812 	ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
813 	if (IS_ERR(ioas))
814 		return PTR_ERR(ioas);
815 	down_write(&ioas->iopt.iova_rwsem);
816 	rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
817 	up_write(&ioas->iopt.iova_rwsem);
818 	iommufd_put_object(ucmd->ictx, &ioas->obj);
819 	return rc;
820 }
821 
/*
 * Check that every pfn under each iova matches the pfn under a user VA.
 *
 * Walks [iova, iova + length) of the mock paging HWPT @mockpt_id in
 * MOCK_IO_PAGE_SIZE steps. At each step it compares the physical address in
 * the mock xarray with the physical address backing the same offset of @uptr.
 * iova, length and uptr must all be MOCK_IO_PAGE_SIZE aligned. Returns
 * -EINVAL on misalignment, address overflow, or the first hole/mismatch.
 */
static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
				    unsigned int mockpt_id, unsigned long iova,
				    size_t length, void __user *uptr)
{
	struct iommufd_hw_pagetable *hwpt;
	struct mock_iommu_domain *mock;
	uintptr_t end;	/* only needed for the overflow check */
	int rc;

	if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
	    (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
	    check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
		return -EINVAL;

	hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
	if (IS_ERR(hwpt))
		return PTR_ERR(hwpt);

	for (; length; length -= MOCK_IO_PAGE_SIZE) {
		struct page *pages[1];
		unsigned long pfn;
		long npages;
		void *ent;

		/* GUP works on whole CPU pages; the sub-page offset is added below */
		npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
					     pages);
		if (npages < 0) {
			rc = npages;
			goto out_put;
		}
		if (WARN_ON(npages != 1)) {
			rc = -EFAULT;
			goto out_put;
		}
		pfn = page_to_pfn(pages[0]);
		put_page(pages[0]);

		ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
		if (!ent ||
		    (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
			    pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
			rc = -EINVAL;
			goto out_put;
		}
		iova += MOCK_IO_PAGE_SIZE;
		uptr += MOCK_IO_PAGE_SIZE;
	}
	rc = 0;

out_put:
	iommufd_put_object(ucmd->ictx, &hwpt->obj);
	return rc;
}
876 
877 /* Check that the page ref count matches, to look for missing pin/unpins */
878 static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
879 				      void __user *uptr, size_t length,
880 				      unsigned int refs)
881 {
882 	uintptr_t end;
883 
884 	if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
885 	    check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
886 		return -EINVAL;
887 
888 	for (; length; length -= PAGE_SIZE) {
889 		struct page *pages[1];
890 		long npages;
891 
892 		npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
893 		if (npages < 0)
894 			return npages;
895 		if (WARN_ON(npages != 1))
896 			return -EFAULT;
897 		if (!PageCompound(pages[0])) {
898 			unsigned int count;
899 
900 			count = page_ref_count(pages[0]);
901 			if (count / GUP_PIN_COUNTING_BIAS != refs) {
902 				put_page(pages[0]);
903 				return -EIO;
904 			}
905 		}
906 		put_page(pages[0]);
907 		uptr += PAGE_SIZE;
908 	}
909 	return 0;
910 }
911 
912 static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd,
913 				       u32 mockpt_id, unsigned int iotlb_id,
914 				       u32 iotlb)
915 {
916 	struct mock_iommu_domain_nested *mock_nested;
917 	struct iommufd_hw_pagetable *hwpt;
918 	int rc = 0;
919 
920 	hwpt = get_md_pagetable_nested(ucmd, mockpt_id, &mock_nested);
921 	if (IS_ERR(hwpt))
922 		return PTR_ERR(hwpt);
923 
924 	mock_nested = container_of(hwpt->domain,
925 				   struct mock_iommu_domain_nested, domain);
926 
927 	if (iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX ||
928 	    mock_nested->iotlb[iotlb_id] != iotlb)
929 		rc = -EINVAL;
930 	iommufd_put_object(ucmd->ictx, &hwpt->obj);
931 	return rc;
932 }
933 
/* State behind a selftest access file descriptor */
struct selftest_access {
	/* NULL until iommufd_test_create_access() fully succeeds */
	struct iommufd_access *access;
	/* Anonymous file whose private_data points back to this struct */
	struct file *file;
	/* Held while walking or modifying @items */
	struct mutex lock;
	/* struct selftest_access_item entries, linked through items_elm */
	struct list_head items;
	/* NOTE(review): next_id and destroying are used outside this view */
	unsigned int next_id;
	bool destroying;
};
942 
/* One pinned IOVA range tracked on selftest_access::items */
struct selftest_access_item {
	struct list_head items_elm;
	unsigned long iova;
	size_t length;
	/* Identifier looked up by iommufd_test_access_item_destroy() */
	unsigned int id;
};
949 
950 static const struct file_operations iommfd_test_staccess_fops;
951 
952 static struct selftest_access *iommufd_access_get(int fd)
953 {
954 	struct file *file;
955 
956 	file = fget(fd);
957 	if (!file)
958 		return ERR_PTR(-EBADFD);
959 
960 	if (file->f_op != &iommfd_test_staccess_fops) {
961 		fput(file);
962 		return ERR_PTR(-EBADFD);
963 	}
964 	return file->private_data;
965 }
966 
967 static void iommufd_test_access_unmap(void *data, unsigned long iova,
968 				      unsigned long length)
969 {
970 	unsigned long iova_last = iova + length - 1;
971 	struct selftest_access *staccess = data;
972 	struct selftest_access_item *item;
973 	struct selftest_access_item *tmp;
974 
975 	mutex_lock(&staccess->lock);
976 	list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
977 		if (iova > item->iova + item->length - 1 ||
978 		    iova_last < item->iova)
979 			continue;
980 		list_del(&item->items_elm);
981 		iommufd_access_unpin_pages(staccess->access, item->iova,
982 					   item->length);
983 		kfree(item);
984 	}
985 	mutex_unlock(&staccess->lock);
986 }
987 
988 static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
989 					    unsigned int access_id,
990 					    unsigned int item_id)
991 {
992 	struct selftest_access_item *item;
993 	struct selftest_access *staccess;
994 
995 	staccess = iommufd_access_get(access_id);
996 	if (IS_ERR(staccess))
997 		return PTR_ERR(staccess);
998 
999 	mutex_lock(&staccess->lock);
1000 	list_for_each_entry(item, &staccess->items, items_elm) {
1001 		if (item->id == item_id) {
1002 			list_del(&item->items_elm);
1003 			iommufd_access_unpin_pages(staccess->access, item->iova,
1004 						   item->length);
1005 			mutex_unlock(&staccess->lock);
1006 			kfree(item);
1007 			fput(staccess->file);
1008 			return 0;
1009 		}
1010 	}
1011 	mutex_unlock(&staccess->lock);
1012 	fput(staccess->file);
1013 	return -ENOENT;
1014 }
1015 
1016 static int iommufd_test_staccess_release(struct inode *inode,
1017 					 struct file *filep)
1018 {
1019 	struct selftest_access *staccess = filep->private_data;
1020 
1021 	if (staccess->access) {
1022 		iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
1023 		iommufd_access_destroy(staccess->access);
1024 	}
1025 	mutex_destroy(&staccess->lock);
1026 	kfree(staccess);
1027 	return 0;
1028 }
1029 
/* Access ops used when MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES is given */
static const struct iommufd_access_ops selftest_access_ops_pin = {
	.needs_pin_pages = 1,
	.unmap = iommufd_test_access_unmap,
};

/* Access ops used when page pinning is not requested */
static const struct iommufd_access_ops selftest_access_ops = {
	.unmap = iommufd_test_access_unmap,
};

/* Selftest access files only implement release */
static const struct file_operations iommfd_test_staccess_fops = {
	.release = iommufd_test_staccess_release,
};
1042 
1043 static struct selftest_access *iommufd_test_alloc_access(void)
1044 {
1045 	struct selftest_access *staccess;
1046 	struct file *filep;
1047 
1048 	staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
1049 	if (!staccess)
1050 		return ERR_PTR(-ENOMEM);
1051 	INIT_LIST_HEAD(&staccess->items);
1052 	mutex_init(&staccess->lock);
1053 
1054 	filep = anon_inode_getfile("[iommufd_test_staccess]",
1055 				   &iommfd_test_staccess_fops, staccess,
1056 				   O_RDWR);
1057 	if (IS_ERR(filep)) {
1058 		kfree(staccess);
1059 		return ERR_CAST(filep);
1060 	}
1061 	staccess->file = filep;
1062 	return staccess;
1063 }
1064 
1065 static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
1066 				      unsigned int ioas_id, unsigned int flags)
1067 {
1068 	struct iommu_test_cmd *cmd = ucmd->cmd;
1069 	struct selftest_access *staccess;
1070 	struct iommufd_access *access;
1071 	u32 id;
1072 	int fdno;
1073 	int rc;
1074 
1075 	if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
1076 		return -EOPNOTSUPP;
1077 
1078 	staccess = iommufd_test_alloc_access();
1079 	if (IS_ERR(staccess))
1080 		return PTR_ERR(staccess);
1081 
1082 	fdno = get_unused_fd_flags(O_CLOEXEC);
1083 	if (fdno < 0) {
1084 		rc = -ENOMEM;
1085 		goto out_free_staccess;
1086 	}
1087 
1088 	access = iommufd_access_create(
1089 		ucmd->ictx,
1090 		(flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
1091 			&selftest_access_ops_pin :
1092 			&selftest_access_ops,
1093 		staccess, &id);
1094 	if (IS_ERR(access)) {
1095 		rc = PTR_ERR(access);
1096 		goto out_put_fdno;
1097 	}
1098 	rc = iommufd_access_attach(access, ioas_id);
1099 	if (rc)
1100 		goto out_destroy;
1101 	cmd->create_access.out_access_fd = fdno;
1102 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1103 	if (rc)
1104 		goto out_destroy;
1105 
1106 	staccess->access = access;
1107 	fd_install(fdno, staccess->file);
1108 	return 0;
1109 
1110 out_destroy:
1111 	iommufd_access_destroy(access);
1112 out_put_fdno:
1113 	put_unused_fd(fdno);
1114 out_free_staccess:
1115 	fput(staccess->file);
1116 	return rc;
1117 }
1118 
1119 static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
1120 					    unsigned int access_id,
1121 					    unsigned int ioas_id)
1122 {
1123 	struct selftest_access *staccess;
1124 	int rc;
1125 
1126 	staccess = iommufd_access_get(access_id);
1127 	if (IS_ERR(staccess))
1128 		return PTR_ERR(staccess);
1129 
1130 	rc = iommufd_access_replace(staccess->access, ioas_id);
1131 	fput(staccess->file);
1132 	return rc;
1133 }
1134 
1135 /* Check that the pages in a page array match the pages in the user VA */
1136 static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
1137 				    size_t npages)
1138 {
1139 	for (; npages; npages--) {
1140 		struct page *tmp_pages[1];
1141 		long rc;
1142 
1143 		rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
1144 		if (rc < 0)
1145 			return rc;
1146 		if (WARN_ON(rc != 1))
1147 			return -EFAULT;
1148 		put_page(tmp_pages[0]);
1149 		if (tmp_pages[0] != *pages)
1150 			return -EBADE;
1151 		pages++;
1152 		uptr += PAGE_SIZE;
1153 	}
1154 	return 0;
1155 }
1156 
/*
 * IOMMU_TEST_OP_ACCESS_PAGES: pin [@iova, @iova + @length) through the
 * access @access_id using iommufd_access_pin_pages(), leave it pinned, and
 * record the range as a selftest_access_item whose id is returned in
 * cmd->access_pages.out_access_pages_id.
 *
 * Only accesses created with selftest_access_ops_pin are accepted.
 * MOCK_FLAGS_ACCESS_WRITE is forwarded to the pin call; MOCK_FLAGS_ACCESS_SYZ
 * replaces @iova with iommufd_test_syz_conv_iova(). When @uptr is non-NULL
 * the pinned pages are compared against the pages backing @uptr.
 *
 * Returns 0 on success or a negative errno. Every failure after the pin
 * unpins the range again, so nothing stays pinned on error.
 */
static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
				     unsigned int access_id, unsigned long iova,
				     size_t length, void __user *uptr,
				     u32 flags)
{
	struct iommu_test_cmd *cmd = ucmd->cmd;
	struct selftest_access_item *item;
	struct selftest_access *staccess;
	struct page **pages;
	size_t npages;
	int rc;

	/* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
	if (length > 16*1024*1024)
		return -ENOMEM;

	if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
		return -EOPNOTSUPP;

	staccess = iommufd_access_get(access_id);
	if (IS_ERR(staccess))
		return PTR_ERR(staccess);

	/* Pinning is only allowed on accesses that declared needs_pin_pages */
	if (staccess->access->ops != &selftest_access_ops_pin) {
		rc = -EOPNOTSUPP;
		goto out_put;
	}

	if (flags & MOCK_FLAGS_ACCESS_SYZ)
		iova = iommufd_test_syz_conv_iova(staccess->access,
					&cmd->access_pages.iova);

	/*
	 * Number of PAGE_SIZE pages touched by the possibly unaligned range.
	 * The unsigned subtraction stays correct even if iova + length wraps,
	 * because length is bounded above.
	 */
	npages = (ALIGN(iova + length, PAGE_SIZE) -
		  ALIGN_DOWN(iova, PAGE_SIZE)) /
		 PAGE_SIZE;
	pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages) {
		rc = -ENOMEM;
		goto out_put;
	}

	/*
	 * Drivers will need to think very carefully about this locking. The
	 * core code can do multiple unmaps instantaneously after
	 * iommufd_access_pin_pages() and *all* the unmaps must not return until
	 * the range is unpinned. This simple implementation puts a global lock
	 * around the pin, which may not suit drivers that want this to be a
	 * performance path. Drivers that get this wrong will trigger WARN_ON
	 * races and cause EDEADLOCK failures to userspace.
	 */
	mutex_lock(&staccess->lock);
	rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
				      flags & MOCK_FLAGS_ACCESS_WRITE);
	if (rc)
		goto out_unlock;

	/* For syzkaller allow uptr to be NULL to skip this check */
	if (uptr) {
		/* Step uptr back by iova's offset so both start page aligned */
		rc = iommufd_test_check_pages(
			uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
			npages);
		if (rc)
			goto out_unaccess;
	}

	item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
	if (!item) {
		rc = -ENOMEM;
		goto out_unaccess;
	}

	/* Record the pinned range so it can be found again by its id */
	item->iova = iova;
	item->length = length;
	item->id = staccess->next_id++;
	list_add_tail(&item->items_elm, &staccess->items);

	cmd->access_pages.out_access_pages_id = item->id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_free_item;
	/*
	 * Success: keep the pin and the item; only the lock and the scratch
	 * pages[] array are released below.
	 */
	goto out_unlock;

out_free_item:
	list_del(&item->items_elm);
	kfree(item);
out_unaccess:
	iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
	mutex_unlock(&staccess->lock);
	kvfree(pages);
out_put:
	fput(staccess->file);
	return rc;
}
1251 
1252 static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
1253 				  unsigned int access_id, unsigned long iova,
1254 				  size_t length, void __user *ubuf,
1255 				  unsigned int flags)
1256 {
1257 	struct iommu_test_cmd *cmd = ucmd->cmd;
1258 	struct selftest_access *staccess;
1259 	void *tmp;
1260 	int rc;
1261 
1262 	/* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
1263 	if (length > 16*1024*1024)
1264 		return -ENOMEM;
1265 
1266 	if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
1267 		      MOCK_FLAGS_ACCESS_SYZ))
1268 		return -EOPNOTSUPP;
1269 
1270 	staccess = iommufd_access_get(access_id);
1271 	if (IS_ERR(staccess))
1272 		return PTR_ERR(staccess);
1273 
1274 	tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
1275 	if (!tmp) {
1276 		rc = -ENOMEM;
1277 		goto out_put;
1278 	}
1279 
1280 	if (flags & MOCK_ACCESS_RW_WRITE) {
1281 		if (copy_from_user(tmp, ubuf, length)) {
1282 			rc = -EFAULT;
1283 			goto out_free;
1284 		}
1285 	}
1286 
1287 	if (flags & MOCK_FLAGS_ACCESS_SYZ)
1288 		iova = iommufd_test_syz_conv_iova(staccess->access,
1289 				&cmd->access_rw.iova);
1290 
1291 	rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
1292 	if (rc)
1293 		goto out_free;
1294 	if (!(flags & MOCK_ACCESS_RW_WRITE)) {
1295 		if (copy_to_user(ubuf, tmp, length)) {
1296 			rc = -EFAULT;
1297 			goto out_free;
1298 		}
1299 	}
1300 
1301 out_free:
1302 	kvfree(tmp);
1303 out_put:
1304 	fput(staccess->file);
1305 	return rc;
1306 }
1307 static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
1308 static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
1309 	      __IOMMUFD_ACCESS_RW_SLOW_PATH);
1310 
1311 static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
1312 			      unsigned long iova, size_t length,
1313 			      unsigned long page_size, void __user *uptr,
1314 			      u32 flags)
1315 {
1316 	unsigned long bitmap_size, i, max;
1317 	struct iommu_test_cmd *cmd = ucmd->cmd;
1318 	struct iommufd_hw_pagetable *hwpt;
1319 	struct mock_iommu_domain *mock;
1320 	int rc, count = 0;
1321 	void *tmp;
1322 
1323 	if (!page_size || !length || iova % page_size || length % page_size ||
1324 	    !uptr)
1325 		return -EINVAL;
1326 
1327 	hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
1328 	if (IS_ERR(hwpt))
1329 		return PTR_ERR(hwpt);
1330 
1331 	if (!(mock->flags & MOCK_DIRTY_TRACK)) {
1332 		rc = -EINVAL;
1333 		goto out_put;
1334 	}
1335 
1336 	max = length / page_size;
1337 	bitmap_size = max / BITS_PER_BYTE;
1338 
1339 	tmp = kvzalloc(bitmap_size, GFP_KERNEL_ACCOUNT);
1340 	if (!tmp) {
1341 		rc = -ENOMEM;
1342 		goto out_put;
1343 	}
1344 
1345 	if (copy_from_user(tmp, uptr, bitmap_size)) {
1346 		rc = -EFAULT;
1347 		goto out_free;
1348 	}
1349 
1350 	for (i = 0; i < max; i++) {
1351 		unsigned long cur = iova + i * page_size;
1352 		void *ent, *old;
1353 
1354 		if (!test_bit(i, (unsigned long *)tmp))
1355 			continue;
1356 
1357 		ent = xa_load(&mock->pfns, cur / page_size);
1358 		if (ent) {
1359 			unsigned long val;
1360 
1361 			val = xa_to_value(ent) | MOCK_PFN_DIRTY_IOVA;
1362 			old = xa_store(&mock->pfns, cur / page_size,
1363 				       xa_mk_value(val), GFP_KERNEL);
1364 			WARN_ON_ONCE(ent != old);
1365 			count++;
1366 		}
1367 	}
1368 
1369 	cmd->dirty.out_nr_dirty = count;
1370 	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1371 out_free:
1372 	kvfree(tmp);
1373 out_put:
1374 	iommufd_put_object(ucmd->ictx, &hwpt->obj);
1375 	return rc;
1376 }
1377 
1378 void iommufd_selftest_destroy(struct iommufd_object *obj)
1379 {
1380 	struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
1381 
1382 	switch (sobj->type) {
1383 	case TYPE_IDEV:
1384 		iommufd_device_detach(sobj->idev.idev);
1385 		iommufd_device_unbind(sobj->idev.idev);
1386 		mock_dev_destroy(sobj->idev.mock_dev);
1387 		break;
1388 	}
1389 }
1390 
1391 int iommufd_test(struct iommufd_ucmd *ucmd)
1392 {
1393 	struct iommu_test_cmd *cmd = ucmd->cmd;
1394 
1395 	switch (cmd->op) {
1396 	case IOMMU_TEST_OP_ADD_RESERVED:
1397 		return iommufd_test_add_reserved(ucmd, cmd->id,
1398 						 cmd->add_reserved.start,
1399 						 cmd->add_reserved.length);
1400 	case IOMMU_TEST_OP_MOCK_DOMAIN:
1401 	case IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS:
1402 		return iommufd_test_mock_domain(ucmd, cmd);
1403 	case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
1404 		return iommufd_test_mock_domain_replace(
1405 			ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
1406 	case IOMMU_TEST_OP_MD_CHECK_MAP:
1407 		return iommufd_test_md_check_pa(
1408 			ucmd, cmd->id, cmd->check_map.iova,
1409 			cmd->check_map.length,
1410 			u64_to_user_ptr(cmd->check_map.uptr));
1411 	case IOMMU_TEST_OP_MD_CHECK_REFS:
1412 		return iommufd_test_md_check_refs(
1413 			ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
1414 			cmd->check_refs.length, cmd->check_refs.refs);
1415 	case IOMMU_TEST_OP_MD_CHECK_IOTLB:
1416 		return iommufd_test_md_check_iotlb(ucmd, cmd->id,
1417 						   cmd->check_iotlb.id,
1418 						   cmd->check_iotlb.iotlb);
1419 	case IOMMU_TEST_OP_CREATE_ACCESS:
1420 		return iommufd_test_create_access(ucmd, cmd->id,
1421 						  cmd->create_access.flags);
1422 	case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
1423 		return iommufd_test_access_replace_ioas(
1424 			ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
1425 	case IOMMU_TEST_OP_ACCESS_PAGES:
1426 		return iommufd_test_access_pages(
1427 			ucmd, cmd->id, cmd->access_pages.iova,
1428 			cmd->access_pages.length,
1429 			u64_to_user_ptr(cmd->access_pages.uptr),
1430 			cmd->access_pages.flags);
1431 	case IOMMU_TEST_OP_ACCESS_RW:
1432 		return iommufd_test_access_rw(
1433 			ucmd, cmd->id, cmd->access_rw.iova,
1434 			cmd->access_rw.length,
1435 			u64_to_user_ptr(cmd->access_rw.uptr),
1436 			cmd->access_rw.flags);
1437 	case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
1438 		return iommufd_test_access_item_destroy(
1439 			ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
1440 	case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
1441 		/* Protect _batch_init(), can not be less than elmsz */
1442 		if (cmd->memory_limit.limit <
1443 		    sizeof(unsigned long) + sizeof(u32))
1444 			return -EINVAL;
1445 		iommufd_test_memory_limit = cmd->memory_limit.limit;
1446 		return 0;
1447 	case IOMMU_TEST_OP_DIRTY:
1448 		return iommufd_test_dirty(ucmd, cmd->id, cmd->dirty.iova,
1449 					  cmd->dirty.length,
1450 					  cmd->dirty.page_size,
1451 					  u64_to_user_ptr(cmd->dirty.uptr),
1452 					  cmd->dirty.flags);
1453 	default:
1454 		return -EOPNOTSUPP;
1455 	}
1456 }
1457 
1458 bool iommufd_should_fail(void)
1459 {
1460 	return should_fail(&fail_iommufd, 1);
1461 }
1462 
/*
 * Set up the selftest environment. In order, this creates:
 *  - the "fail_iommufd" fault-injection attribute,
 *  - the "iommufd_selftest_iommu" platform device,
 *  - the "iommufd_mock" bus,
 *  - the sysfs entry for mock_iommu_device,
 * and then registers mock_iommu_device (with mock_ops) for that bus.
 *
 * On failure, every step already completed is undone in reverse order.
 *
 * Returns 0 on success or a negative errno.
 */
int __init iommufd_test_init(void)
{
	struct platform_device_info pdevinfo = {
		.name = "iommufd_selftest_iommu",
	};
	int rc;

	/*
	 * The result is deliberately not checked. NOTE(review): this
	 * presumably treats the debugfs attribute as optional and relies on
	 * debugfs_remove_recursive() accepting an error/NULL dentry; confirm.
	 */
	dbgfs_root =
		fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);

	selftest_iommu_dev = platform_device_register_full(&pdevinfo);
	if (IS_ERR(selftest_iommu_dev)) {
		rc = PTR_ERR(selftest_iommu_dev);
		goto err_dbgfs;
	}

	rc = bus_register(&iommufd_mock_bus_type.bus);
	if (rc)
		goto err_platform;

	/* The mock IOMMU's sysfs node is parented to the platform device */
	rc = iommu_device_sysfs_add(&mock_iommu_device,
				    &selftest_iommu_dev->dev, NULL, "%s",
				    dev_name(&selftest_iommu_dev->dev));
	if (rc)
		goto err_bus;

	rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
				  &iommufd_mock_bus_type.bus,
				  &iommufd_mock_bus_type.nb);
	if (rc)
		goto err_sysfs;
	return 0;

	/* Unwind labels: each one undoes the step before its goto site */
err_sysfs:
	iommu_device_sysfs_remove(&mock_iommu_device);
err_bus:
	bus_unregister(&iommufd_mock_bus_type.bus);
err_platform:
	platform_device_unregister(selftest_iommu_dev);
err_dbgfs:
	debugfs_remove_recursive(dbgfs_root);
	return rc;
}
1506 
/*
 * Tear down what iommufd_test_init() set up:
 *  - the mock IOMMU's sysfs entry and its bus registration,
 *  - the "iommufd_mock" bus,
 *  - the selftest platform device,
 *  - the fault-injection debugfs attribute.
 *
 * NOTE(review): the sysfs entry is removed before
 * iommu_device_unregister_bus(). That is not the exact reverse of init,
 * which adds sysfs before registering. This is presumably intentional;
 * confirm before reordering.
 */
void iommufd_test_exit(void)
{
	iommu_device_sysfs_remove(&mock_iommu_device);
	iommu_device_unregister_bus(&mock_iommu_device,
				    &iommufd_mock_bus_type.bus,
				    &iommufd_mock_bus_type.nb);
	bus_unregister(&iommufd_mock_bus_type.bus);
	platform_device_unregister(selftest_iommu_dev);
	debugfs_remove_recursive(dbgfs_root);
}
1517