xref: /linux/drivers/iommu/iommufd/pages.c (revision f22cc6f766f84496b260347d4f0d92cf95f30699)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFNs have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing an in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/file.h>
49 #include <linux/highmem.h>
50 #include <linux/iommu.h>
51 #include <linux/iommufd.h>
52 #include <linux/kthread.h>
53 #include <linux/overflow.h>
54 #include <linux/slab.h>
55 #include <linux/sched/mm.h>
56 
57 #include "double_span.h"
58 #include "io_pagetable.h"
59 
/*
 * TEMP_MEMORY_LIMIT caps, in bytes, the first allocation attempted by
 * temp_kmalloc(). Test builds take the limit from iommufd_test_memory_limit
 * instead of the fixed value.
 */
#ifndef CONFIG_IOMMUFD_TEST
#define TEMP_MEMORY_LIMIT 65536
#else
#define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
#endif
/*
 * NOTE(review): presumably the element count of the on-stack backup arrays
 * passed to batch_init_backup() - confirm against the callers.
 */
#define BATCH_BACKUP_SIZE 32
66 
67 /*
68  * More memory makes pin_user_pages() and the batching more efficient, but as
69  * this is only a performance optimization don't try too hard to get it. A 64k
70  * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
71  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
72  * stack memory as a backup contingency. If backup_len is given this cannot
73  * fail.
74  */
75 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
76 {
77 	void *res;
78 
79 	if (WARN_ON(*size == 0))
80 		return NULL;
81 
82 	if (*size < backup_len)
83 		return backup;
84 
85 	if (!backup && iommufd_should_fail())
86 		return NULL;
87 
88 	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
89 	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
90 	if (res)
91 		return res;
92 	*size = PAGE_SIZE;
93 	if (backup_len) {
94 		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
95 		if (res)
96 			return res;
97 		*size = backup_len;
98 		return backup;
99 	}
100 	return kmalloc(*size, GFP_KERNEL);
101 }
102 
103 void interval_tree_double_span_iter_update(
104 	struct interval_tree_double_span_iter *iter)
105 {
106 	unsigned long last_hole = ULONG_MAX;
107 	unsigned int i;
108 
109 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
110 		if (interval_tree_span_iter_done(&iter->spans[i])) {
111 			iter->is_used = -1;
112 			return;
113 		}
114 
115 		if (iter->spans[i].is_hole) {
116 			last_hole = min(last_hole, iter->spans[i].last_hole);
117 			continue;
118 		}
119 
120 		iter->is_used = i + 1;
121 		iter->start_used = iter->spans[i].start_used;
122 		iter->last_used = min(iter->spans[i].last_used, last_hole);
123 		return;
124 	}
125 
126 	iter->is_used = 0;
127 	iter->start_hole = iter->spans[0].start_hole;
128 	iter->last_hole =
129 		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
130 }
131 
132 void interval_tree_double_span_iter_first(
133 	struct interval_tree_double_span_iter *iter,
134 	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
135 	unsigned long first_index, unsigned long last_index)
136 {
137 	unsigned int i;
138 
139 	iter->itrees[0] = itree1;
140 	iter->itrees[1] = itree2;
141 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
142 		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
143 					      first_index, last_index);
144 	interval_tree_double_span_iter_update(iter);
145 }
146 
147 void interval_tree_double_span_iter_next(
148 	struct interval_tree_double_span_iter *iter)
149 {
150 	unsigned int i;
151 
152 	if (iter->is_used == -1 ||
153 	    iter->last_hole == iter->spans[0].last_index) {
154 		iter->is_used = -1;
155 		return;
156 	}
157 
158 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
159 		interval_tree_span_iter_advance(
160 			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
161 	interval_tree_double_span_iter_update(iter);
162 }
163 
164 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
165 {
166 	int rc;
167 
168 	rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
169 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
170 		WARN_ON(rc || pages->npinned > pages->npages);
171 }
172 
173 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
174 {
175 	int rc;
176 
177 	rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
178 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
179 		WARN_ON(rc || pages->npinned > pages->npages);
180 }
181 
/*
 * Error unwind helper: unpin the pages in page_list, which cover the inclusive
 * index range [start_index, last_index], and drop them from the pinned count.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	unsigned long count = last_index - start_index + 1;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
192 
193 /*
194  * index is the number of PAGE_SIZE units from the start of the area's
195  * iopt_pages. If the iova is sub page-size then the area has an iova that
196  * covers a portion of the first and last pages in the range.
197  */
198 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
199 					     unsigned long index)
200 {
201 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
202 		WARN_ON(index < iopt_area_index(area) ||
203 			index > iopt_area_last_index(area));
204 	index -= iopt_area_index(area);
205 	if (index == 0)
206 		return iopt_area_iova(area);
207 	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
208 }
209 
210 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
211 						  unsigned long index)
212 {
213 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
214 		WARN_ON(index < iopt_area_index(area) ||
215 			index > iopt_area_last_index(area));
216 	if (index == iopt_area_last_index(area))
217 		return iopt_area_last_iova(area);
218 	return iopt_area_iova(area) - area->page_offset +
219 	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
220 }
221 
222 static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
223 			       size_t size)
224 {
225 	size_t ret;
226 
227 	ret = iommu_unmap(domain, iova, size);
228 	/*
229 	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
230 	 * something other than exactly as requested. This implies that the
231 	 * iommu driver may not fail unmap for reasons beyond bad agruments.
232 	 * Particularly, the iommu driver may not do a memory allocation on the
233 	 * unmap path.
234 	 */
235 	WARN_ON(ret != size);
236 }
237 
/*
 * Unmap the IOVA that backs the inclusive page index range
 * [start_index, last_index] of the area from the domain.
 */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first_iova = iopt_area_index_to_iova(area, start_index);
	unsigned long final_iova =
		iopt_area_index_to_iova_last(area, last_index);

	/* final_iova is inclusive, hence the + 1 to form a length */
	iommu_unmap_nofail(domain, first_iova, final_iova - first_iova + 1);
}
249 
250 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
251 						     unsigned long index)
252 {
253 	struct interval_tree_node *node;
254 
255 	node = interval_tree_iter_first(&pages->domains_itree, index, index);
256 	if (!node)
257 		return NULL;
258 	return container_of(node, struct iopt_area, pages_node);
259 }
260 
/*
 * A simple datastructure to hold a vector of PFNs, optimized for contiguous
 * PFNs. This is used as a temporary holding memory for shuttling pfns from one
 * place to another. Generally everything is made more efficient if operations
 * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
 * better cache locality, etc
 *
 * The PFNs are stored as runs: entry i describes npfns[i] physically
 * contiguous PFNs starting at pfns[i]. Both arrays live in one allocation, the
 * npfns array is placed directly after the array_size pfns entries.
 */
struct pfn_batch {
	/* First PFN of each run */
	unsigned long *pfns;
	/* Number of contiguous PFNs in each run */
	u32 *npfns;
	/* Number of runs the arrays have room for */
	unsigned int array_size;
	/* Number of runs currently in use */
	unsigned int end;
	/* Sum of npfns[] over the runs in use */
	unsigned int total_pfns;
};
275 
276 static void batch_clear(struct pfn_batch *batch)
277 {
278 	batch->total_pfns = 0;
279 	batch->end = 0;
280 	batch->pfns[0] = 0;
281 	batch->npfns[0] = 0;
282 }
283 
284 /*
285  * Carry means we carry a portion of the final hugepage over to the front of the
286  * batch
287  */
288 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
289 {
290 	if (!keep_pfns)
291 		return batch_clear(batch);
292 
293 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
294 		WARN_ON(!batch->end ||
295 			batch->npfns[batch->end - 1] < keep_pfns);
296 
297 	batch->total_pfns = keep_pfns;
298 	batch->pfns[0] = batch->pfns[batch->end - 1] +
299 			 (batch->npfns[batch->end - 1] - keep_pfns);
300 	batch->npfns[0] = keep_pfns;
301 	batch->end = 1;
302 }
303 
304 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
305 {
306 	if (!batch->total_pfns)
307 		return;
308 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
309 		WARN_ON(batch->total_pfns != batch->npfns[0]);
310 	skip_pfns = min(batch->total_pfns, skip_pfns);
311 	batch->pfns[0] += skip_pfns;
312 	batch->npfns[0] -= skip_pfns;
313 	batch->total_pfns -= skip_pfns;
314 }
315 
316 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
317 			size_t backup_len)
318 {
319 	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
320 	size_t size = max_pages * elmsz;
321 
322 	batch->pfns = temp_kmalloc(&size, backup, backup_len);
323 	if (!batch->pfns)
324 		return -ENOMEM;
325 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
326 		return -EINVAL;
327 	batch->array_size = size / elmsz;
328 	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
329 	batch_clear(batch);
330 	return 0;
331 }
332 
/*
 * Initialize a batch for up to max_pages runs without any backup memory.
 * Returns 0 on success or a negative errno from __batch_init().
 */
static int batch_init(struct pfn_batch *batch, size_t max_pages)
{
	return __batch_init(batch, max_pages, NULL, 0);
}
337 
/*
 * Initialize a batch for up to max_pages runs, falling back to the caller
 * provided backup buffer of backup_len bytes if allocation fails. The return
 * value of __batch_init() is deliberately not propagated.
 */
static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
			      void *backup, size_t backup_len)
{
	__batch_init(batch, max_pages, backup, backup_len);
}
343 
344 static void batch_destroy(struct pfn_batch *batch, void *backup)
345 {
346 	if (batch->pfns != backup)
347 		kfree(batch->pfns);
348 }
349 
350 static bool batch_add_pfn_num(struct pfn_batch *batch, unsigned long pfn,
351 			      u32 nr)
352 {
353 	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
354 	unsigned int end = batch->end;
355 
356 	if (end && pfn == batch->pfns[end - 1] + batch->npfns[end - 1] &&
357 	    nr <= MAX_NPFNS - batch->npfns[end - 1]) {
358 		batch->npfns[end - 1] += nr;
359 	} else if (end < batch->array_size) {
360 		batch->pfns[end] = pfn;
361 		batch->npfns[end] = nr;
362 		batch->end++;
363 	} else {
364 		return false;
365 	}
366 
367 	batch->total_pfns += nr;
368 	return true;
369 }
370 
371 static void batch_remove_pfn_num(struct pfn_batch *batch, unsigned long nr)
372 {
373 	batch->npfns[batch->end - 1] -= nr;
374 	if (batch->npfns[batch->end - 1] == 0)
375 		batch->end--;
376 	batch->total_pfns -= nr;
377 }
378 
/*
 * Append a single PFN to the batch via batch_add_pfn_num().
 * Returns true if the pfn was added, false otherwise.
 */
static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
{
	return batch_add_pfn_num(batch, pfn, 1);
}
384 
385 /*
386  * Fill the batch with pfns from the domain. When the batch is full, or it
387  * reaches last_index, the function will return. The caller should use
388  * batch->total_pfns to determine the starting point for the next iteration.
389  */
390 static void batch_from_domain(struct pfn_batch *batch,
391 			      struct iommu_domain *domain,
392 			      struct iopt_area *area, unsigned long start_index,
393 			      unsigned long last_index)
394 {
395 	unsigned int page_offset = 0;
396 	unsigned long iova;
397 	phys_addr_t phys;
398 
399 	iova = iopt_area_index_to_iova(area, start_index);
400 	if (start_index == iopt_area_index(area))
401 		page_offset = area->page_offset;
402 	while (start_index <= last_index) {
403 		/*
404 		 * This is pretty slow, it would be nice to get the page size
405 		 * back from the driver, or have the driver directly fill the
406 		 * batch.
407 		 */
408 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
409 		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
410 			return;
411 		iova += PAGE_SIZE - page_offset;
412 		page_offset = 0;
413 		start_index++;
414 	}
415 }
416 
417 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
418 					   struct iopt_area *area,
419 					   unsigned long start_index,
420 					   unsigned long last_index,
421 					   struct page **out_pages)
422 {
423 	unsigned int page_offset = 0;
424 	unsigned long iova;
425 	phys_addr_t phys;
426 
427 	iova = iopt_area_index_to_iova(area, start_index);
428 	if (start_index == iopt_area_index(area))
429 		page_offset = area->page_offset;
430 	while (start_index <= last_index) {
431 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
432 		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
433 		iova += PAGE_SIZE - page_offset;
434 		page_offset = 0;
435 		start_index++;
436 	}
437 	return out_pages;
438 }
439 
440 /* Continues reading a domain until we reach a discontinuity in the pfns. */
441 static void batch_from_domain_continue(struct pfn_batch *batch,
442 				       struct iommu_domain *domain,
443 				       struct iopt_area *area,
444 				       unsigned long start_index,
445 				       unsigned long last_index)
446 {
447 	unsigned int array_size = batch->array_size;
448 
449 	batch->array_size = batch->end;
450 	batch_from_domain(batch, domain, area, start_index, last_index);
451 	batch->array_size = array_size;
452 }
453 
454 /*
455  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
456  * mode permits splitting a mapped area up, and then one of the splits is
457  * unmapped. Doing this normally would cause us to violate our invariant of
458  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
459  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
460  * PAGE_SIZE units, not larger or smaller.
461  */
462 static int batch_iommu_map_small(struct iommu_domain *domain,
463 				 unsigned long iova, phys_addr_t paddr,
464 				 size_t size, int prot)
465 {
466 	unsigned long start_iova = iova;
467 	int rc;
468 
469 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
470 		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
471 			size % PAGE_SIZE);
472 
473 	while (size) {
474 		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
475 			       GFP_KERNEL_ACCOUNT);
476 		if (rc)
477 			goto err_unmap;
478 		iova += PAGE_SIZE;
479 		paddr += PAGE_SIZE;
480 		size -= PAGE_SIZE;
481 	}
482 	return 0;
483 
484 err_unmap:
485 	if (start_iova != iova)
486 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
487 	return rc;
488 }
489 
490 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
491 			   struct iopt_area *area, unsigned long start_index)
492 {
493 	bool disable_large_pages = area->iopt->disable_large_pages;
494 	unsigned long last_iova = iopt_area_last_iova(area);
495 	unsigned int page_offset = 0;
496 	unsigned long start_iova;
497 	unsigned long next_iova;
498 	unsigned int cur = 0;
499 	unsigned long iova;
500 	int rc;
501 
502 	/* The first index might be a partial page */
503 	if (start_index == iopt_area_index(area))
504 		page_offset = area->page_offset;
505 	next_iova = iova = start_iova =
506 		iopt_area_index_to_iova(area, start_index);
507 	while (cur < batch->end) {
508 		next_iova = min(last_iova + 1,
509 				next_iova + batch->npfns[cur] * PAGE_SIZE -
510 					page_offset);
511 		if (disable_large_pages)
512 			rc = batch_iommu_map_small(
513 				domain, iova,
514 				PFN_PHYS(batch->pfns[cur]) + page_offset,
515 				next_iova - iova, area->iommu_prot);
516 		else
517 			rc = iommu_map(domain, iova,
518 				       PFN_PHYS(batch->pfns[cur]) + page_offset,
519 				       next_iova - iova, area->iommu_prot,
520 				       GFP_KERNEL_ACCOUNT);
521 		if (rc)
522 			goto err_unmap;
523 		iova = next_iova;
524 		page_offset = 0;
525 		cur++;
526 	}
527 	return 0;
528 err_unmap:
529 	if (start_iova != iova)
530 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
531 	return rc;
532 }
533 
534 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
535 			      unsigned long start_index,
536 			      unsigned long last_index)
537 {
538 	XA_STATE(xas, xa, start_index);
539 	void *entry;
540 
541 	rcu_read_lock();
542 	while (true) {
543 		entry = xas_next(&xas);
544 		if (xas_retry(&xas, entry))
545 			continue;
546 		WARN_ON(!xa_is_value(entry));
547 		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
548 		    start_index == last_index)
549 			break;
550 		start_index++;
551 	}
552 	rcu_read_unlock();
553 }
554 
555 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
556 				    unsigned long start_index,
557 				    unsigned long last_index)
558 {
559 	XA_STATE(xas, xa, start_index);
560 	void *entry;
561 
562 	xas_lock(&xas);
563 	while (true) {
564 		entry = xas_next(&xas);
565 		if (xas_retry(&xas, entry))
566 			continue;
567 		WARN_ON(!xa_is_value(entry));
568 		if (!batch_add_pfn(batch, xa_to_value(entry)))
569 			break;
570 		xas_store(&xas, NULL);
571 		if (start_index == last_index)
572 			break;
573 		start_index++;
574 	}
575 	xas_unlock(&xas);
576 }
577 
578 static void clear_xarray(struct xarray *xa, unsigned long start_index,
579 			 unsigned long last_index)
580 {
581 	XA_STATE(xas, xa, start_index);
582 	void *entry;
583 
584 	xas_lock(&xas);
585 	xas_for_each(&xas, entry, last_index)
586 		xas_store(&xas, NULL);
587 	xas_unlock(&xas);
588 }
589 
/*
 * Store the PFN of each page in 'pages' as a value entry in the xarray, at
 * consecutive indexes covering [start_index, last_index].
 *
 * Returns 0 on success. On error the entries stored by this call are erased
 * again and the xarray error is returned.
 */
static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
			   unsigned long last_index, struct page **pages)
{
	struct page **end_pages = pages + (last_index - start_index) + 1;
	/* Midpoint of the page list, used only as the fault injection point */
	struct page **half_pages = pages + (end_pages - pages) / 2;
	XA_STATE(xas, xa, start_index);

	/*
	 * Stores happen under the xarray lock. If a store fails the lock is
	 * dropped and xas_nomem() decides whether the loop is run again,
	 * resuming at the page that failed.
	 */
	do {
		void *old;

		xas_lock(&xas);
		while (pages != end_pages) {
			/* xarray does not participate in fault injection */
			if (pages == half_pages && iommufd_should_fail()) {
				xas_set_err(&xas, -EINVAL);
				xas_unlock(&xas);
				/* aka xas_destroy() */
				xas_nomem(&xas, GFP_KERNEL);
				goto err_clear;
			}

			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
			if (xas_error(&xas))
				break;
			/* The slot is expected to have been empty */
			WARN_ON(old);
			pages++;
			xas_next(&xas);
		}
		xas_unlock(&xas);
	} while (xas_nomem(&xas, GFP_KERNEL));

err_clear:
	if (xas_error(&xas)) {
		/* Erase whatever was stored before the failing index */
		if (xas.xa_index != start_index)
			clear_xarray(xa, start_index, xas.xa_index - 1);
		return xas_error(&xas);
	}
	return 0;
}
629 
630 static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
631 			     size_t npages)
632 {
633 	struct page **end = pages + npages;
634 
635 	for (; pages != end; pages++)
636 		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
637 			break;
638 }
639 
/*
 * Append up to npages PFNs to the batch, taken from the folio array at
 * *folios_p starting *offset_p pages into its first folio. Each folio
 * contributes one run, and one extra pin is added to the folio for every page
 * beyond the first.
 *
 * On return *folios_p and *offset_p identify the folio, and the page offset
 * within it, at which processing stopped. Returns 0, or the error from
 * folio_add_pins(), in which case the run of the failing folio is removed
 * from the batch again.
 */
static int batch_from_folios(struct pfn_batch *batch, struct folio ***folios_p,
			     unsigned long *offset_p, unsigned long npages)
{
	struct folio **cursor = *folios_p;
	unsigned long first_off = *offset_p;
	int rc = 0;

	while (npages) {
		struct folio *folio = *cursor;
		unsigned long nr = folio_nr_pages(folio) - first_off;
		unsigned long pfn = page_to_pfn(folio_page(folio, first_off));

		nr = min(nr, npages);
		npages -= nr;

		/* Batch is full, resume from this folio next time */
		if (!batch_add_pfn_num(batch, pfn, nr))
			break;

		if (nr > 1) {
			rc = folio_add_pins(folio, nr - 1);
			if (rc) {
				batch_remove_pfn_num(batch, nr);
				break;
			}
		}

		cursor++;
		first_off = 0;
	}

	*folios_p = cursor;
	*offset_p = first_off;
	return rc;
}
674 
675 static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
676 			unsigned int first_page_off, size_t npages)
677 {
678 	unsigned int cur = 0;
679 
680 	while (first_page_off) {
681 		if (batch->npfns[cur] > first_page_off)
682 			break;
683 		first_page_off -= batch->npfns[cur];
684 		cur++;
685 	}
686 
687 	while (npages) {
688 		size_t to_unpin = min_t(size_t, npages,
689 					batch->npfns[cur] - first_page_off);
690 
691 		unpin_user_page_range_dirty_lock(
692 			pfn_to_page(batch->pfns[cur] + first_page_off),
693 			to_unpin, pages->writable);
694 		iopt_pages_sub_npinned(pages, to_unpin);
695 		cur++;
696 		first_page_off = 0;
697 		npages -= to_unpin;
698 	}
699 }
700 
701 static void copy_data_page(struct page *page, void *data, unsigned long offset,
702 			   size_t length, unsigned int flags)
703 {
704 	void *mem;
705 
706 	mem = kmap_local_page(page);
707 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
708 		memcpy(mem + offset, data, length);
709 		set_page_dirty_lock(page);
710 	} else {
711 		memcpy(data, mem + offset, length);
712 	}
713 	kunmap_local(mem);
714 }
715 
716 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
717 			      unsigned long offset, unsigned long length,
718 			      unsigned int flags)
719 {
720 	unsigned long copied = 0;
721 	unsigned int npage = 0;
722 	unsigned int cur = 0;
723 
724 	while (cur < batch->end) {
725 		unsigned long bytes = min(length, PAGE_SIZE - offset);
726 
727 		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
728 			       offset, bytes, flags);
729 		offset = 0;
730 		length -= bytes;
731 		data += bytes;
732 		copied += bytes;
733 		npage++;
734 		if (npage == batch->npfns[cur]) {
735 			npage = 0;
736 			cur++;
737 		}
738 		if (!length)
739 			break;
740 	}
741 	return copied;
742 }
743 
/*
 * pfn_reader_user is just the pin_user_pages() path. It holds the temporary
 * page/folio arrays and the mm reference/lock state used while pinning.
 */
struct pfn_reader_user {
	/* Page array filled by the pin functions, unused if file != NULL */
	struct page **upages;
	/* Size of the upages allocation in bytes */
	size_t upages_len;
	/* Index of the first page covered by the most recent pin */
	unsigned long upages_start;
	/* Index one past the last page covered by the most recent pin */
	unsigned long upages_end;
	/* FOLL_* flags for pinning, FOLL_WRITE is set for writable pages */
	unsigned int gup_flags;
	/*
	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
	 * neither
	 */
	int locked;

	/* The following are only valid if file != NULL. */
	struct file *file;
	/* Folio array filled by memfd_pin_folios() */
	struct folio **ufolios;
	/* Size of the ufolios allocation in bytes */
	size_t ufolios_len;
	/* Offset, in pages, into the first pinned folio */
	unsigned long ufolios_offset;
	/*
	 * Reset to the start of ufolios after each pin.
	 * NOTE(review): presumably the read cursor advanced via
	 * batch_from_folios() - confirm against the caller.
	 */
	struct folio **ufolios_next;
};
764 
765 static void pfn_reader_user_init(struct pfn_reader_user *user,
766 				 struct iopt_pages *pages)
767 {
768 	user->upages = NULL;
769 	user->upages_len = 0;
770 	user->upages_start = 0;
771 	user->upages_end = 0;
772 	user->locked = -1;
773 	user->gup_flags = FOLL_LONGTERM;
774 	if (pages->writable)
775 		user->gup_flags |= FOLL_WRITE;
776 
777 	user->file = (pages->type == IOPT_ADDRESS_FILE) ? pages->file : NULL;
778 	user->ufolios = NULL;
779 	user->ufolios_len = 0;
780 	user->ufolios_next = NULL;
781 	user->ufolios_offset = 0;
782 }
783 
784 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
785 				    struct iopt_pages *pages)
786 {
787 	if (user->locked != -1) {
788 		if (user->locked)
789 			mmap_read_unlock(pages->source_mm);
790 		if (!user->file && pages->source_mm != current->mm)
791 			mmput(pages->source_mm);
792 		user->locked = -1;
793 	}
794 
795 	kfree(user->upages);
796 	user->upages = NULL;
797 	kfree(user->ufolios);
798 	user->ufolios = NULL;
799 }
800 
801 static long pin_memfd_pages(struct pfn_reader_user *user, unsigned long start,
802 			    unsigned long npages)
803 {
804 	unsigned long i;
805 	unsigned long offset;
806 	unsigned long npages_out = 0;
807 	struct page **upages = user->upages;
808 	unsigned long end = start + (npages << PAGE_SHIFT) - 1;
809 	long nfolios = user->ufolios_len / sizeof(*user->ufolios);
810 
811 	/*
812 	 * todo: memfd_pin_folios should return the last pinned offset so
813 	 * we can compute npages pinned, and avoid looping over folios here
814 	 * if upages == NULL.
815 	 */
816 	nfolios = memfd_pin_folios(user->file, start, end, user->ufolios,
817 				   nfolios, &offset);
818 	if (nfolios <= 0)
819 		return nfolios;
820 
821 	offset >>= PAGE_SHIFT;
822 	user->ufolios_next = user->ufolios;
823 	user->ufolios_offset = offset;
824 
825 	for (i = 0; i < nfolios; i++) {
826 		struct folio *folio = user->ufolios[i];
827 		unsigned long nr = folio_nr_pages(folio);
828 		unsigned long npin = min(nr - offset, npages);
829 
830 		npages -= npin;
831 		npages_out += npin;
832 
833 		if (upages) {
834 			if (npin == 1) {
835 				*upages++ = folio_page(folio, offset);
836 			} else {
837 				int rc = folio_add_pins(folio, npin - 1);
838 
839 				if (rc)
840 					return rc;
841 
842 				while (npin--)
843 					*upages++ = folio_page(folio, offset++);
844 			}
845 		}
846 
847 		offset = 0;
848 	}
849 
850 	return npages_out;
851 }
852 
/*
 * Pin pages of the iopt_pages starting at start_index. At most the range
 * [start_index, last_index] is pinned, and fewer pages may be pinned if the
 * temporary array is smaller than the range or the pin function returns a
 * short count. On success the pinned range is recorded in user->upages_start
 * and user->upages_end (exclusive) and added to pages->npinned.
 *
 * Returns 0 on success, -ENOMEM if the temporary array cannot be allocated,
 * -EFAULT if the source mm is gone or nothing could be pinned, or the error
 * from the pin function.
 */
static int pfn_reader_user_pin(struct pfn_reader_user *user,
			       struct iopt_pages *pages,
			       unsigned long start_index,
			       unsigned long last_index)
{
	bool remote_mm = pages->source_mm != current->mm;
	unsigned long npages = last_index - start_index + 1;
	unsigned long start;
	unsigned long unum;
	uintptr_t uptr;
	long rc;

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
	    WARN_ON(last_index < start_index))
		return -EINVAL;

	/* The temporary arrays are allocated on first use and then reused */
	if (!user->file && !user->upages) {
		/* All undone in pfn_reader_destroy() */
		user->upages_len = npages * sizeof(*user->upages);
		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
		if (!user->upages)
			return -ENOMEM;
	}

	if (user->file && !user->ufolios) {
		user->ufolios_len = npages * sizeof(*user->ufolios);
		user->ufolios = temp_kmalloc(&user->ufolios_len, NULL, 0);
		if (!user->ufolios)
			return -ENOMEM;
	}

	if (user->locked == -1) {
		/*
		 * The majority of usages will run the map task within the mm
		 * providing the pages, so we can optimize into
		 * get_user_pages_fast()
		 */
		if (!user->file && remote_mm) {
			if (!mmget_not_zero(pages->source_mm))
				return -EFAULT;
		}
		user->locked = 0;
	}

	/* Never pin more entries than the temporary array can hold */
	unum = user->file ? user->ufolios_len / sizeof(*user->ufolios) :
			    user->upages_len / sizeof(*user->upages);
	npages = min_t(unsigned long, npages, unum);

	if (iommufd_should_fail())
		return -EFAULT;

	if (user->file) {
		start = pages->start + (start_index * PAGE_SIZE);
		rc = pin_memfd_pages(user, start, npages);
	} else if (!remote_mm) {
		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
					 user->upages);
	} else {
		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
		/* The remote pin is called with the mmap read lock held */
		if (!user->locked) {
			mmap_read_lock(pages->source_mm);
			user->locked = 1;
		}
		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
					   user->gup_flags, user->upages,
					   &user->locked);
	}
	if (rc <= 0) {
		/* Pinning zero pages is not expected, report it as a fault */
		if (WARN_ON(!rc))
			return -EFAULT;
		return rc;
	}
	iopt_pages_add_npinned(pages, rc);
	user->upages_start = start_index;
	user->upages_end = start_index + rc;
	return 0;
}
931 
932 /* This is the "modern" and faster accounting method used by io_uring */
933 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
934 {
935 	unsigned long lock_limit;
936 	unsigned long cur_pages;
937 	unsigned long new_pages;
938 
939 	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
940 		     PAGE_SHIFT;
941 
942 	cur_pages = atomic_long_read(&pages->source_user->locked_vm);
943 	do {
944 		new_pages = cur_pages + npages;
945 		if (new_pages > lock_limit)
946 			return -ENOMEM;
947 	} while (!atomic_long_try_cmpxchg(&pages->source_user->locked_vm,
948 					  &cur_pages, new_pages));
949 	return 0;
950 }
951 
952 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
953 {
954 	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
955 		return;
956 	atomic_long_sub(npages, &pages->source_user->locked_vm);
957 }
958 
/*
 * This is the accounting method used for compatibility with VFIO.
 *
 * Add (inc) or remove npages from the locked_vm accounting of
 * pages->source_mm via __account_locked_vm(), which is called with the mmap
 * write lock held. 'user' may be NULL. If it holds the mmap read lock that
 * lock is dropped first and user->locked is updated to say so.
 *
 * Returns the result of __account_locked_vm(), or -EINVAL if a reference to
 * the source mm could not be obtained.
 */
static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
			       bool inc, struct pfn_reader_user *user)
{
	bool do_put = false;
	int rc;

	if (user && user->locked) {
		/* The read lock must be released before taking the write lock */
		mmap_read_unlock(pages->source_mm);
		user->locked = 0;
		/* If we had the lock then we also have a get */

	} else if ((!user || (!user->upages && !user->ufolios)) &&
		   pages->source_mm != current->mm) {
		/* No pin has set up 'user' yet, so take a temporary reference */
		if (!mmget_not_zero(pages->source_mm))
			return -EINVAL;
		do_put = true;
	}

	mmap_write_lock(pages->source_mm);
	rc = __account_locked_vm(pages->source_mm, npages, inc,
				 pages->source_task, false);
	mmap_write_unlock(pages->source_mm);

	if (do_put)
		mmput(pages->source_mm);
	return rc;
}
987 
988 int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
989 			     bool inc, struct pfn_reader_user *user)
990 {
991 	int rc = 0;
992 
993 	switch (pages->account_mode) {
994 	case IOPT_PAGES_ACCOUNT_NONE:
995 		break;
996 	case IOPT_PAGES_ACCOUNT_USER:
997 		if (inc)
998 			rc = incr_user_locked_vm(pages, npages);
999 		else
1000 			decr_user_locked_vm(pages, npages);
1001 		break;
1002 	case IOPT_PAGES_ACCOUNT_MM:
1003 		rc = update_mm_locked_vm(pages, npages, inc, user);
1004 		break;
1005 	}
1006 	if (rc)
1007 		return rc;
1008 
1009 	pages->last_npinned = pages->npinned;
1010 	if (inc)
1011 		atomic64_add(npages, &pages->source_mm->pinned_vm);
1012 	else
1013 		atomic64_sub(npages, &pages->source_mm->pinned_vm);
1014 	return 0;
1015 }
1016 
1017 static void update_unpinned(struct iopt_pages *pages)
1018 {
1019 	if (WARN_ON(pages->npinned > pages->last_npinned))
1020 		return;
1021 	if (pages->npinned == pages->last_npinned)
1022 		return;
1023 	iopt_pages_update_pinned(pages, pages->last_npinned - pages->npinned,
1024 				 false, NULL);
1025 }
1026 
1027 /*
1028  * Changes in the number of pages pinned is done after the pages have been read
1029  * and processed. If the user lacked the limit then the error unwind will unpin
1030  * everything that was just pinned. This is because it is expensive to calculate
1031  * how many pages we have already pinned within a range to generate an accurate
1032  * prediction in advance of doing the work to actually pin them.
1033  */
1034 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
1035 					 struct iopt_pages *pages)
1036 {
1037 	unsigned long npages;
1038 	bool inc;
1039 
1040 	lockdep_assert_held(&pages->mutex);
1041 
1042 	if (pages->npinned == pages->last_npinned)
1043 		return 0;
1044 
1045 	if (pages->npinned < pages->last_npinned) {
1046 		npages = pages->last_npinned - pages->npinned;
1047 		inc = false;
1048 	} else {
1049 		if (iommufd_should_fail())
1050 			return -ENOMEM;
1051 		npages = pages->npinned - pages->last_npinned;
1052 		inc = true;
1053 	}
1054 	return iopt_pages_update_pinned(pages, npages, inc, user);
1055 }
1056 
/*
 * PFNs are stored in three places, in order of preference:
 * - The iopt_pages xarray. This is only populated if there is a
 *   iopt_pages_access
 * - The iommu_domain under an area
 * - The original PFN source, ie pages->source_mm
 *
 * This iterator reads the pfns optimizing to load according to the
 * above order.
 */
struct pfn_reader {
	/* The iopt_pages being read */
	struct iopt_pages *pages;
	/* Walk over pages->access_itree and pages->domains_itree */
	struct interval_tree_double_span_iter span;
	/* PFNs for the indexes [batch_start_index, batch_end_index) */
	struct pfn_batch batch;
	unsigned long batch_start_index;
	unsigned long batch_end_index;
	/* Last index of the range being read, see pfn_reader_done() */
	unsigned long last_index;

	/* State used when pages have to be pinned from the original source */
	struct pfn_reader_user user;
};
1077 
1078 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
1079 {
1080 	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
1081 }
1082 
1083 /*
1084  * The batch can contain a mixture of pages that are still in use and pages that
1085  * need to be unpinned. Unpin only pages that are not held anywhere else.
1086  */
1087 static void pfn_reader_unpin(struct pfn_reader *pfns)
1088 {
1089 	unsigned long last = pfns->batch_end_index - 1;
1090 	unsigned long start = pfns->batch_start_index;
1091 	struct interval_tree_double_span_iter span;
1092 	struct iopt_pages *pages = pfns->pages;
1093 
1094 	lockdep_assert_held(&pages->mutex);
1095 
1096 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1097 					   &pages->domains_itree, start, last) {
1098 		if (span.is_used)
1099 			continue;
1100 
1101 		batch_unpin(&pfns->batch, pages, span.start_hole - start,
1102 			    span.last_hole - span.start_hole + 1);
1103 	}
1104 }
1105 
1106 /* Process a single span to load it from the proper storage */
1107 static int pfn_reader_fill_span(struct pfn_reader *pfns)
1108 {
1109 	struct interval_tree_double_span_iter *span = &pfns->span;
1110 	unsigned long start_index = pfns->batch_end_index;
1111 	struct pfn_reader_user *user = &pfns->user;
1112 	unsigned long npages;
1113 	struct iopt_area *area;
1114 	int rc;
1115 
1116 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1117 	    WARN_ON(span->last_used < start_index))
1118 		return -EINVAL;
1119 
1120 	if (span->is_used == 1) {
1121 		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
1122 				  start_index, span->last_used);
1123 		return 0;
1124 	}
1125 
1126 	if (span->is_used == 2) {
1127 		/*
1128 		 * Pull as many pages from the first domain we find in the
1129 		 * target span. If it is too small then we will be called again
1130 		 * and we'll find another area.
1131 		 */
1132 		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1133 		if (WARN_ON(!area))
1134 			return -EINVAL;
1135 
1136 		/* The storage_domain cannot change without the pages mutex */
1137 		batch_from_domain(
1138 			&pfns->batch, area->storage_domain, area, start_index,
1139 			min(iopt_area_last_index(area), span->last_used));
1140 		return 0;
1141 	}
1142 
1143 	if (start_index >= pfns->user.upages_end) {
1144 		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1145 					 span->last_hole);
1146 		if (rc)
1147 			return rc;
1148 	}
1149 
1150 	npages = user->upages_end - start_index;
1151 	start_index -= user->upages_start;
1152 	rc = 0;
1153 
1154 	if (!user->file)
1155 		batch_from_pages(&pfns->batch, user->upages + start_index,
1156 				 npages);
1157 	else
1158 		rc = batch_from_folios(&pfns->batch, &user->ufolios_next,
1159 				       &user->ufolios_offset, npages);
1160 	return rc;
1161 }
1162 
1163 static bool pfn_reader_done(struct pfn_reader *pfns)
1164 {
1165 	return pfns->batch_start_index == pfns->last_index + 1;
1166 }
1167 
1168 static int pfn_reader_next(struct pfn_reader *pfns)
1169 {
1170 	int rc;
1171 
1172 	batch_clear(&pfns->batch);
1173 	pfns->batch_start_index = pfns->batch_end_index;
1174 
1175 	while (pfns->batch_end_index != pfns->last_index + 1) {
1176 		unsigned int npfns = pfns->batch.total_pfns;
1177 
1178 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1179 		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1180 			return -EINVAL;
1181 
1182 		rc = pfn_reader_fill_span(pfns);
1183 		if (rc)
1184 			return rc;
1185 
1186 		if (WARN_ON(!pfns->batch.total_pfns))
1187 			return -EINVAL;
1188 
1189 		pfns->batch_end_index =
1190 			pfns->batch_start_index + pfns->batch.total_pfns;
1191 		if (pfns->batch_end_index == pfns->span.last_used + 1)
1192 			interval_tree_double_span_iter_next(&pfns->span);
1193 
1194 		/* Batch is full */
1195 		if (npfns == pfns->batch.total_pfns)
1196 			return 0;
1197 	}
1198 	return 0;
1199 }
1200 
1201 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1202 			   unsigned long start_index, unsigned long last_index)
1203 {
1204 	int rc;
1205 
1206 	lockdep_assert_held(&pages->mutex);
1207 
1208 	pfns->pages = pages;
1209 	pfns->batch_start_index = start_index;
1210 	pfns->batch_end_index = start_index;
1211 	pfns->last_index = last_index;
1212 	pfn_reader_user_init(&pfns->user, pages);
1213 	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1214 	if (rc)
1215 		return rc;
1216 	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1217 					     &pages->domains_itree, start_index,
1218 					     last_index);
1219 	return 0;
1220 }
1221 
1222 /*
1223  * There are many assertions regarding the state of pages->npinned vs
1224  * pages->last_pinned, for instance something like unmapping a domain must only
1225  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1226  * the pins are updated. This is fine for success flows, but error flows
1227  * sometimes need to release the pins held inside the pfn_reader before going on
1228  * to complete unmapping and releasing pins held in domains.
1229  */
1230 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1231 {
1232 	struct iopt_pages *pages = pfns->pages;
1233 	struct pfn_reader_user *user = &pfns->user;
1234 
1235 	if (user->upages_end > pfns->batch_end_index) {
1236 		/* Any pages not transferred to the batch are just unpinned */
1237 
1238 		unsigned long npages = user->upages_end - pfns->batch_end_index;
1239 		unsigned long start_index = pfns->batch_end_index -
1240 					    user->upages_start;
1241 
1242 		if (!user->file) {
1243 			unpin_user_pages(user->upages + start_index, npages);
1244 		} else {
1245 			long n = user->ufolios_len / sizeof(*user->ufolios);
1246 
1247 			unpin_folios(user->ufolios_next,
1248 				     user->ufolios + n - user->ufolios_next);
1249 		}
1250 		iopt_pages_sub_npinned(pages, npages);
1251 		user->upages_end = pfns->batch_end_index;
1252 	}
1253 	if (pfns->batch_start_index != pfns->batch_end_index) {
1254 		pfn_reader_unpin(pfns);
1255 		pfns->batch_start_index = pfns->batch_end_index;
1256 	}
1257 }
1258 
1259 static void pfn_reader_destroy(struct pfn_reader *pfns)
1260 {
1261 	struct iopt_pages *pages = pfns->pages;
1262 
1263 	pfn_reader_release_pins(pfns);
1264 	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1265 	batch_destroy(&pfns->batch, NULL);
1266 	WARN_ON(pages->last_npinned != pages->npinned);
1267 }
1268 
1269 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1270 			    unsigned long start_index, unsigned long last_index)
1271 {
1272 	int rc;
1273 
1274 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1275 	    WARN_ON(last_index < start_index))
1276 		return -EINVAL;
1277 
1278 	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1279 	if (rc)
1280 		return rc;
1281 	rc = pfn_reader_next(pfns);
1282 	if (rc) {
1283 		pfn_reader_destroy(pfns);
1284 		return rc;
1285 	}
1286 	return 0;
1287 }
1288 
1289 static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
1290 					   unsigned long length, bool writable)
1291 {
1292 	struct iopt_pages *pages;
1293 
1294 	/*
1295 	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1296 	 * below from overflow
1297 	 */
1298 	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1299 		return ERR_PTR(-EINVAL);
1300 
1301 	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1302 	if (!pages)
1303 		return ERR_PTR(-ENOMEM);
1304 
1305 	kref_init(&pages->kref);
1306 	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1307 	mutex_init(&pages->mutex);
1308 	pages->source_mm = current->mm;
1309 	mmgrab(pages->source_mm);
1310 	pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
1311 	pages->access_itree = RB_ROOT_CACHED;
1312 	pages->domains_itree = RB_ROOT_CACHED;
1313 	pages->writable = writable;
1314 	if (capable(CAP_IPC_LOCK))
1315 		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1316 	else
1317 		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1318 	pages->source_task = current->group_leader;
1319 	get_task_struct(current->group_leader);
1320 	pages->source_user = get_uid(current_user());
1321 	return pages;
1322 }
1323 
1324 struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
1325 					 unsigned long length, bool writable)
1326 {
1327 	struct iopt_pages *pages;
1328 	unsigned long end;
1329 	void __user *uptr_down =
1330 		(void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1331 
1332 	if (check_add_overflow((unsigned long)uptr, length, &end))
1333 		return ERR_PTR(-EOVERFLOW);
1334 
1335 	pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
1336 	if (IS_ERR(pages))
1337 		return pages;
1338 	pages->uptr = uptr_down;
1339 	pages->type = IOPT_ADDRESS_USER;
1340 	return pages;
1341 }
1342 
1343 struct iopt_pages *iopt_alloc_file_pages(struct file *file, unsigned long start,
1344 					 unsigned long length, bool writable)
1345 
1346 {
1347 	struct iopt_pages *pages;
1348 	unsigned long start_down = ALIGN_DOWN(start, PAGE_SIZE);
1349 	unsigned long end;
1350 
1351 	if (length && check_add_overflow(start, length - 1, &end))
1352 		return ERR_PTR(-EOVERFLOW);
1353 
1354 	pages = iopt_alloc_pages(start - start_down, length, writable);
1355 	if (IS_ERR(pages))
1356 		return pages;
1357 	pages->file = get_file(file);
1358 	pages->start = start_down;
1359 	pages->type = IOPT_ADDRESS_FILE;
1360 	return pages;
1361 }
1362 
1363 void iopt_release_pages(struct kref *kref)
1364 {
1365 	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1366 
1367 	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1368 	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1369 	WARN_ON(pages->npinned);
1370 	WARN_ON(!xa_empty(&pages->pinned_pfns));
1371 	mmdrop(pages->source_mm);
1372 	mutex_destroy(&pages->mutex);
1373 	put_task_struct(pages->source_task);
1374 	free_uid(pages->source_user);
1375 	if (pages->type == IOPT_ADDRESS_FILE)
1376 		fput(pages->file);
1377 	kfree(pages);
1378 }
1379 
/*
 * Unmap from domain and unpin the indexes start_index..last_index of the area,
 * one batch at a time.
 *
 * @batch: Working batch. It may already hold PFNs carried over from an earlier
 *         call, see batch_clear_carry() at the end of each iteration.
 * @unmapped_end_index: In/out. The first index that has not been unmapped from
 *                      the domain yet. It can run ahead of last_index because
 *                      unmaps are extended to a cut point.
 * @real_last_index: Last index of the whole operation. Reading ahead to find a
 *                   cut point never goes beyond it.
 */
static void
iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
		       struct iopt_pages *pages, struct iommu_domain *domain,
		       unsigned long start_index, unsigned long last_index,
		       unsigned long *unmapped_end_index,
		       unsigned long real_last_index)
{
	while (start_index <= last_index) {
		unsigned long batch_last_index;

		if (*unmapped_end_index <= last_index) {
			/*
			 * Part of the range is still mapped. Read PFNs from the
			 * domain beginning at the first index that has not been
			 * unmapped (and so is not already in the batch).
			 */
			unsigned long start =
				max(start_index, *unmapped_end_index);

			/* Carried PFNs must line up exactly with start_index */
			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
			    batch->total_pfns)
				WARN_ON(*unmapped_end_index -
						batch->total_pfns !=
					start_index);
			batch_from_domain(batch, domain, area, start,
					  last_index);
			batch_last_index = start_index + batch->total_pfns - 1;
		} else {
			/*
			 * The whole range was unmapped by an earlier
			 * read-ahead, nothing is read from the domain here.
			 */
			batch_last_index = last_index;
		}

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
			WARN_ON(batch_last_index > real_last_index);

		/*
		 * unmaps must always 'cut' at a place where the pfns are not
		 * contiguous to pair with the maps that always install
		 * contiguous pages. Thus, if we have to stop unpinning in the
		 * middle of the domains we need to keep reading pfns until we
		 * find a cut point to do the unmap. The pfns we read are
		 * carried over and either skipped or integrated into the next
		 * batch.
		 */
		if (batch_last_index == last_index &&
		    last_index != real_last_index)
			batch_from_domain_continue(batch, domain, area,
						   last_index + 1,
						   real_last_index);

		/* Unmap everything the batch now covers, read-ahead included */
		if (*unmapped_end_index <= batch_last_index) {
			iopt_area_unmap_domain_range(
				area, domain, *unmapped_end_index,
				start_index + batch->total_pfns - 1);
			*unmapped_end_index = start_index + batch->total_pfns;
		}

		/* unpin must follow unmap */
		batch_unpin(batch, pages, 0,
			    batch_last_index - start_index + 1);
		start_index = batch_last_index + 1;

		/*
		 * Keep the PFNs that were unmapped beyond batch_last_index but
		 * not unpinned as the carry for the next iteration or call.
		 */
		batch_clear_carry(batch,
				  *unmapped_end_index - batch_last_index - 1);
	}
}
1440 
/*
 * Unmap the area's indexes from its first index up to and including last_index
 * from domain, and unpin those PFNs that are in neither the domains_itree nor
 * the access_itree. The pinned accounting is updated at the end.
 *
 * Must be called with pages->mutex held.
 */
static void __iopt_area_unfill_domain(struct iopt_area *area,
				      struct iopt_pages *pages,
				      struct iommu_domain *domain,
				      unsigned long last_index)
{
	struct interval_tree_double_span_iter span;
	unsigned long start_index = iopt_area_index(area);
	/* First index not yet unmapped from the domain */
	unsigned long unmapped_end_index = start_index;
	u64 backup[BATCH_BACKUP_SIZE];
	struct pfn_batch batch;

	lockdep_assert_held(&pages->mutex);

	/*
	 * For security we must not unpin something that is still DMA mapped,
	 * so this must unmap any IOVA before we go ahead and unpin the pages.
	 * This creates a complexity where we need to skip over unpinning pages
	 * held in the xarray, but continue to unmap from the domain.
	 *
	 * The domain unmap cannot stop in the middle of a contiguous range of
	 * PFNs. To solve this problem the unpinning step will read ahead to the
	 * end of any contiguous span, unmap that whole span, and then only
	 * unpin the leading part that does not have any accesses. The residual
	 * PFNs that were unmapped but not unpinned are called a "carry" in the
	 * batch as they are moved to the front of the PFN list and continue on
	 * to the next iteration(s).
	 */
	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
	interval_tree_for_each_double_span(&span, &pages->domains_itree,
					   &pages->access_itree, start_index,
					   last_index) {
		if (span.is_used) {
			/* Still in use: drop these from the carry, no unpin */
			batch_skip_carry(&batch,
					 span.last_used - span.start_used + 1);
			continue;
		}
		iopt_area_unpin_domain(&batch, area, pages, domain,
				       span.start_hole, span.last_hole,
				       &unmapped_end_index, last_index);
	}
	/*
	 * If the range ends in a access then we do the residual unmap without
	 * any unpins.
	 */
	if (unmapped_end_index != last_index + 1)
		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
					     last_index);
	/* Every carried PFN must have been consumed by now */
	WARN_ON(batch.total_pfns);
	batch_destroy(&batch, backup);
	update_unpinned(pages);
}
1492 
/*
 * Error unwind helper: unfill the part of the area below end_index (exclusive)
 * from domain. Nothing is done when end_index is the first index of the area,
 * i.e. when nothing was filled.
 */
static void iopt_area_unfill_partial_domain(struct iopt_area *area,
					    struct iopt_pages *pages,
					    struct iommu_domain *domain,
					    unsigned long end_index)
{
	if (end_index == iopt_area_index(area))
		return;

	__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
}
1501 
1502 /**
1503  * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1504  * @area: The IOVA range to unmap
1505  * @domain: The domain to unmap
1506  *
1507  * The caller must know that unpinning is not required, usually because there
1508  * are other domains in the iopt.
1509  */
1510 void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1511 {
1512 	iommu_unmap_nofail(domain, iopt_area_iova(area),
1513 			   iopt_area_length(area));
1514 }
1515 
/**
 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
 * @area: IOVA area to use
 * @pages: page supplier for the area (area->pages is NULL)
 * @domain: Domain to unmap from
 *
 * The caller should take the domain out of the domains_itree first. The whole
 * area is always unmapped from the domain. PFNs that still have accesses are
 * left pinned.
 */
void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
			     struct iommu_domain *domain)
{
	unsigned long last_index = iopt_area_last_index(area);

	__iopt_area_unfill_domain(area, pages, domain, last_index);
}
1532 
1533 /**
1534  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1535  * @area: IOVA area to use
1536  * @domain: Domain to load PFNs into
1537  *
1538  * Read the pfns from the area's underlying iopt_pages and map them into the
1539  * given domain. Called when attaching a new domain to an io_pagetable.
1540  */
1541 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1542 {
1543 	unsigned long done_end_index;
1544 	struct pfn_reader pfns;
1545 	int rc;
1546 
1547 	lockdep_assert_held(&area->pages->mutex);
1548 
1549 	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1550 			      iopt_area_last_index(area));
1551 	if (rc)
1552 		return rc;
1553 
1554 	while (!pfn_reader_done(&pfns)) {
1555 		done_end_index = pfns.batch_start_index;
1556 		rc = batch_to_domain(&pfns.batch, domain, area,
1557 				     pfns.batch_start_index);
1558 		if (rc)
1559 			goto out_unmap;
1560 		done_end_index = pfns.batch_end_index;
1561 
1562 		rc = pfn_reader_next(&pfns);
1563 		if (rc)
1564 			goto out_unmap;
1565 	}
1566 
1567 	rc = pfn_reader_update_pinned(&pfns);
1568 	if (rc)
1569 		goto out_unmap;
1570 	goto out_destroy;
1571 
1572 out_unmap:
1573 	pfn_reader_release_pins(&pfns);
1574 	iopt_area_unfill_partial_domain(area, area->pages, domain,
1575 					done_end_index);
1576 out_destroy:
1577 	pfn_reader_destroy(&pfns);
1578 	return rc;
1579 }
1580 
/**
 * iopt_area_fill_domains() - Install PFNs into the area's domains
 * @area: The area to act on
 * @pages: The pages associated with the area (area->pages is NULL)
 *
 * Called during area creation. The area is freshly created and not inserted in
 * the domains_itree yet. PFNs are read and loaded into every domain held in the
 * area's io_pagetable and the area is installed in the domains_itree.
 *
 * The caller must hold the domains_rwsem of the area's io_pagetable. Nothing is
 * done and 0 is returned if the io_pagetable has no domains.
 *
 * On failure all domains are left unchanged.
 *
 * Return: 0 on success or a negative errno.
 */
int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
{
	/* End index (exclusive) mapped into the domains before the failing one */
	unsigned long done_first_end_index;
	/* End index (exclusive) mapped into every domain */
	unsigned long done_all_end_index;
	struct iommu_domain *domain;
	unsigned long unmap_index;
	struct pfn_reader pfns;
	unsigned long index;
	int rc;

	lockdep_assert_held(&area->iopt->domains_rwsem);

	if (xa_empty(&area->iopt->domains))
		return 0;

	mutex_lock(&pages->mutex);
	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
			      iopt_area_last_index(area));
	if (rc)
		goto out_unlock;

	while (!pfn_reader_done(&pfns)) {
		done_first_end_index = pfns.batch_end_index;
		done_all_end_index = pfns.batch_start_index;
		/* Map the same batch into every domain */
		xa_for_each(&area->iopt->domains, index, domain) {
			rc = batch_to_domain(&pfns.batch, domain, area,
					     pfns.batch_start_index);
			if (rc)
				goto out_unmap;
		}
		done_all_end_index = done_first_end_index;

		rc = pfn_reader_next(&pfns);
		if (rc)
			goto out_unmap;
	}
	rc = pfn_reader_update_pinned(&pfns);
	if (rc)
		goto out_unmap;

	/* The domain at xarray index 0 is the one PFNs are read back from */
	area->storage_domain = xa_load(&area->iopt->domains, 0);
	interval_tree_insert(&area->pages_node, &pages->domains_itree);
	goto out_destroy;

out_unmap:
	pfn_reader_release_pins(&pfns);
	xa_for_each(&area->iopt->domains, unmap_index, domain) {
		unsigned long end_index;

		/*
		 * If batch_to_domain() failed, index is still the xarray index
		 * of the failing domain: the domains before it also got the
		 * current batch. For the other failures both end indexes are
		 * equal, so the choice makes no difference.
		 */
		if (unmap_index < index)
			end_index = done_first_end_index;
		else
			end_index = done_all_end_index;

		/*
		 * The area is not yet part of the domains_itree so we have to
		 * manage the unpinning specially. The last domain does the
		 * unpin, every other domain is just unmapped.
		 */
		if (unmap_index != area->iopt->next_domain_id - 1) {
			if (end_index != iopt_area_index(area))
				iopt_area_unmap_domain_range(
					area, domain, iopt_area_index(area),
					end_index - 1);
		} else {
			iopt_area_unfill_partial_domain(area, pages, domain,
							end_index);
		}
	}
out_destroy:
	pfn_reader_destroy(&pfns);
out_unlock:
	mutex_unlock(&pages->mutex);
	return rc;
}
1667 
1668 /**
1669  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1670  * @area: The area to act on
1671  * @pages: The pages associated with the area (area->pages is NULL)
1672  *
1673  * Called during area destruction. This unmaps the iova's covered by all the
1674  * area's domains and releases the PFNs.
1675  */
1676 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1677 {
1678 	struct io_pagetable *iopt = area->iopt;
1679 	struct iommu_domain *domain;
1680 	unsigned long index;
1681 
1682 	lockdep_assert_held(&iopt->domains_rwsem);
1683 
1684 	mutex_lock(&pages->mutex);
1685 	if (!area->storage_domain)
1686 		goto out_unlock;
1687 
1688 	xa_for_each(&iopt->domains, index, domain)
1689 		if (domain != area->storage_domain)
1690 			iopt_area_unmap_domain_range(
1691 				area, domain, iopt_area_index(area),
1692 				iopt_area_last_index(area));
1693 
1694 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1695 		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1696 	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1697 	iopt_area_unfill_domain(area, pages, area->storage_domain);
1698 	area->storage_domain = NULL;
1699 out_unlock:
1700 	mutex_unlock(&pages->mutex);
1701 }
1702 
1703 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1704 				    struct iopt_pages *pages,
1705 				    unsigned long start_index,
1706 				    unsigned long end_index)
1707 {
1708 	while (start_index <= end_index) {
1709 		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1710 					end_index);
1711 		batch_unpin(batch, pages, 0, batch->total_pfns);
1712 		start_index += batch->total_pfns;
1713 		batch_clear(batch);
1714 	}
1715 }
1716 
1717 /**
1718  * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1719  * @pages: The pages to act on
1720  * @start_index: Starting PFN index
1721  * @last_index: Last PFN index
1722  *
1723  * Called when an iopt_pages_access is removed, removes pages from the itree.
1724  * The access should already be removed from the access_itree.
1725  */
1726 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1727 			      unsigned long start_index,
1728 			      unsigned long last_index)
1729 {
1730 	struct interval_tree_double_span_iter span;
1731 	u64 backup[BATCH_BACKUP_SIZE];
1732 	struct pfn_batch batch;
1733 	bool batch_inited = false;
1734 
1735 	lockdep_assert_held(&pages->mutex);
1736 
1737 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1738 					   &pages->domains_itree, start_index,
1739 					   last_index) {
1740 		if (!span.is_used) {
1741 			if (!batch_inited) {
1742 				batch_init_backup(&batch,
1743 						  last_index - start_index + 1,
1744 						  backup, sizeof(backup));
1745 				batch_inited = true;
1746 			}
1747 			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1748 						span.last_hole);
1749 		} else if (span.is_used == 2) {
1750 			/* Covered by a domain */
1751 			clear_xarray(&pages->pinned_pfns, span.start_used,
1752 				     span.last_used);
1753 		}
1754 		/* Otherwise covered by an existing access */
1755 	}
1756 	if (batch_inited)
1757 		batch_destroy(&batch, backup);
1758 	update_unpinned(pages);
1759 }
1760 
1761 /**
1762  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1763  * @pages: The pages to act on
1764  * @start_index: The first page index in the range
1765  * @last_index: The last page index in the range
1766  * @out_pages: The output array to return the pages
1767  *
1768  * This can be called if the caller is holding a refcount on an
1769  * iopt_pages_access that is known to have already been filled. It quickly reads
1770  * the pages directly from the xarray.
1771  *
1772  * This is part of the SW iommu interface to read pages for in-kernel use.
1773  */
1774 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1775 				 unsigned long start_index,
1776 				 unsigned long last_index,
1777 				 struct page **out_pages)
1778 {
1779 	XA_STATE(xas, &pages->pinned_pfns, start_index);
1780 	void *entry;
1781 
1782 	rcu_read_lock();
1783 	while (start_index <= last_index) {
1784 		entry = xas_next(&xas);
1785 		if (xas_retry(&xas, entry))
1786 			continue;
1787 		WARN_ON(!xa_is_value(entry));
1788 		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1789 		start_index++;
1790 	}
1791 	rcu_read_unlock();
1792 }
1793 
1794 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1795 				       unsigned long start_index,
1796 				       unsigned long last_index,
1797 				       struct page **out_pages)
1798 {
1799 	while (start_index != last_index + 1) {
1800 		unsigned long domain_last;
1801 		struct iopt_area *area;
1802 
1803 		area = iopt_pages_find_domain_area(pages, start_index);
1804 		if (WARN_ON(!area))
1805 			return -EINVAL;
1806 
1807 		domain_last = min(iopt_area_last_index(area), last_index);
1808 		out_pages = raw_pages_from_domain(area->storage_domain, area,
1809 						  start_index, domain_last,
1810 						  out_pages);
1811 		start_index = domain_last + 1;
1812 	}
1813 	return 0;
1814 }
1815 
1816 static int iopt_pages_fill(struct iopt_pages *pages,
1817 			   struct pfn_reader_user *user,
1818 			   unsigned long start_index,
1819 			   unsigned long last_index,
1820 			   struct page **out_pages)
1821 {
1822 	unsigned long cur_index = start_index;
1823 	int rc;
1824 
1825 	while (cur_index != last_index + 1) {
1826 		user->upages = out_pages + (cur_index - start_index);
1827 		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1828 		if (rc)
1829 			goto out_unpin;
1830 		cur_index = user->upages_end;
1831 	}
1832 	return 0;
1833 
1834 out_unpin:
1835 	if (start_index != cur_index)
1836 		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1837 				     out_pages);
1838 	return rc;
1839 }
1840 
1841 /**
1842  * iopt_pages_fill_xarray() - Read PFNs
1843  * @pages: The pages to act on
1844  * @start_index: The first page index in the range
1845  * @last_index: The last page index in the range
1846  * @out_pages: The output array to return the pages, may be NULL
1847  *
1848  * This populates the xarray and returns the pages in out_pages. As the slow
1849  * path this is able to copy pages from other storage tiers into the xarray.
1850  *
1851  * On failure the xarray is left unchanged.
1852  *
1853  * This is part of the SW iommu interface to read pages for in-kernel use.
1854  */
1855 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1856 			   unsigned long last_index, struct page **out_pages)
1857 {
1858 	struct interval_tree_double_span_iter span;
1859 	unsigned long xa_end = start_index;
1860 	struct pfn_reader_user user;
1861 	int rc;
1862 
1863 	lockdep_assert_held(&pages->mutex);
1864 
1865 	pfn_reader_user_init(&user, pages);
1866 	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1867 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1868 					   &pages->domains_itree, start_index,
1869 					   last_index) {
1870 		struct page **cur_pages;
1871 
1872 		if (span.is_used == 1) {
1873 			cur_pages = out_pages + (span.start_used - start_index);
1874 			iopt_pages_fill_from_xarray(pages, span.start_used,
1875 						    span.last_used, cur_pages);
1876 			continue;
1877 		}
1878 
1879 		if (span.is_used == 2) {
1880 			cur_pages = out_pages + (span.start_used - start_index);
1881 			iopt_pages_fill_from_domain(pages, span.start_used,
1882 						    span.last_used, cur_pages);
1883 			rc = pages_to_xarray(&pages->pinned_pfns,
1884 					     span.start_used, span.last_used,
1885 					     cur_pages);
1886 			if (rc)
1887 				goto out_clean_xa;
1888 			xa_end = span.last_used + 1;
1889 			continue;
1890 		}
1891 
1892 		/* hole */
1893 		cur_pages = out_pages + (span.start_hole - start_index);
1894 		rc = iopt_pages_fill(pages, &user, span.start_hole,
1895 				     span.last_hole, cur_pages);
1896 		if (rc)
1897 			goto out_clean_xa;
1898 		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1899 				     span.last_hole, cur_pages);
1900 		if (rc) {
1901 			iopt_pages_err_unpin(pages, span.start_hole,
1902 					     span.last_hole, cur_pages);
1903 			goto out_clean_xa;
1904 		}
1905 		xa_end = span.last_hole + 1;
1906 	}
1907 	rc = pfn_reader_user_update_pinned(&user, pages);
1908 	if (rc)
1909 		goto out_clean_xa;
1910 	user.upages = NULL;
1911 	pfn_reader_user_destroy(&user, pages);
1912 	return 0;
1913 
1914 out_clean_xa:
1915 	if (start_index != xa_end)
1916 		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1917 	user.upages = NULL;
1918 	pfn_reader_user_destroy(&user, pages);
1919 	return rc;
1920 }
1921 
1922 /*
1923  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1924  * do every scenario and is fully consistent with what an iommu_domain would
1925  * see.
1926  */
1927 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1928 			      unsigned long start_index,
1929 			      unsigned long last_index, unsigned long offset,
1930 			      void *data, unsigned long length,
1931 			      unsigned int flags)
1932 {
1933 	struct pfn_reader pfns;
1934 	int rc;
1935 
1936 	mutex_lock(&pages->mutex);
1937 
1938 	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1939 	if (rc)
1940 		goto out_unlock;
1941 
1942 	while (!pfn_reader_done(&pfns)) {
1943 		unsigned long done;
1944 
1945 		done = batch_rw(&pfns.batch, data, offset, length, flags);
1946 		data += done;
1947 		length -= done;
1948 		offset = 0;
1949 		pfn_reader_unpin(&pfns);
1950 
1951 		rc = pfn_reader_next(&pfns);
1952 		if (rc)
1953 			goto out_destroy;
1954 	}
1955 	if (WARN_ON(length != 0))
1956 		rc = -EINVAL;
1957 out_destroy:
1958 	pfn_reader_destroy(&pfns);
1959 out_unlock:
1960 	mutex_unlock(&pages->mutex);
1961 	return rc;
1962 }
1963 
1964 /*
1965  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1966  * memory allocations or interval tree searches.
1967  */
1968 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1969 			      unsigned long offset, void *data,
1970 			      unsigned long length, unsigned int flags)
1971 {
1972 	struct page *page = NULL;
1973 	int rc;
1974 
1975 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1976 	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
1977 		return -EINVAL;
1978 
1979 	if (!mmget_not_zero(pages->source_mm))
1980 		return iopt_pages_rw_slow(pages, index, index, offset, data,
1981 					  length, flags);
1982 
1983 	if (iommufd_should_fail()) {
1984 		rc = -EINVAL;
1985 		goto out_mmput;
1986 	}
1987 
1988 	mmap_read_lock(pages->source_mm);
1989 	rc = pin_user_pages_remote(
1990 		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1991 		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1992 		NULL);
1993 	mmap_read_unlock(pages->source_mm);
1994 	if (rc != 1) {
1995 		if (WARN_ON(rc >= 0))
1996 			rc = -EINVAL;
1997 		goto out_mmput;
1998 	}
1999 	copy_data_page(page, data, offset, length, flags);
2000 	unpin_user_page(page);
2001 	rc = 0;
2002 
2003 out_mmput:
2004 	mmput(pages->source_mm);
2005 	return rc;
2006 }
2007 
2008 /**
2009  * iopt_pages_rw_access - Copy to/from a linear slice of the pages
2010  * @pages: pages to act on
2011  * @start_byte: First byte of pages to copy to/from
2012  * @data: Kernel buffer to get/put the data
2013  * @length: Number of bytes to copy
2014  * @flags: IOMMUFD_ACCESS_RW_* flags
2015  *
2016  * This will find each page in the range, kmap it and then memcpy to/from
2017  * the given kernel buffer.
2018  */
2019 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
2020 			 void *data, unsigned long length, unsigned int flags)
2021 {
2022 	unsigned long start_index = start_byte / PAGE_SIZE;
2023 	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
2024 	bool change_mm = current->mm != pages->source_mm;
2025 	int rc = 0;
2026 
2027 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2028 	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
2029 		change_mm = true;
2030 
2031 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2032 		return -EPERM;
2033 
2034 	if (pages->type == IOPT_ADDRESS_FILE)
2035 		return iopt_pages_rw_slow(pages, start_index, last_index,
2036 					  start_byte % PAGE_SIZE, data, length,
2037 					  flags);
2038 
2039 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2040 	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
2041 		return -EINVAL;
2042 
2043 	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
2044 		if (start_index == last_index)
2045 			return iopt_pages_rw_page(pages, start_index,
2046 						  start_byte % PAGE_SIZE, data,
2047 						  length, flags);
2048 		return iopt_pages_rw_slow(pages, start_index, last_index,
2049 					  start_byte % PAGE_SIZE, data, length,
2050 					  flags);
2051 	}
2052 
2053 	/*
2054 	 * Try to copy using copy_to_user(). We do this as a fast path and
2055 	 * ignore any pinning inconsistencies, unlike a real DMA path.
2056 	 */
2057 	if (change_mm) {
2058 		if (!mmget_not_zero(pages->source_mm))
2059 			return iopt_pages_rw_slow(pages, start_index,
2060 						  last_index,
2061 						  start_byte % PAGE_SIZE, data,
2062 						  length, flags);
2063 		kthread_use_mm(pages->source_mm);
2064 	}
2065 
2066 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
2067 		if (copy_to_user(pages->uptr + start_byte, data, length))
2068 			rc = -EFAULT;
2069 	} else {
2070 		if (copy_from_user(data, pages->uptr + start_byte, length))
2071 			rc = -EFAULT;
2072 	}
2073 
2074 	if (change_mm) {
2075 		kthread_unuse_mm(pages->source_mm);
2076 		mmput(pages->source_mm);
2077 	}
2078 
2079 	return rc;
2080 }
2081 
2082 static struct iopt_pages_access *
2083 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
2084 			    unsigned long last)
2085 {
2086 	struct interval_tree_node *node;
2087 
2088 	lockdep_assert_held(&pages->mutex);
2089 
2090 	/* There can be overlapping ranges in this interval tree */
2091 	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
2092 	     node; node = interval_tree_iter_next(node, index, last))
2093 		if (node->start == index && node->last == last)
2094 			return container_of(node, struct iopt_pages_access,
2095 					    node);
2096 	return NULL;
2097 }
2098 
2099 /**
2100  * iopt_area_add_access() - Record an in-knerel access for PFNs
2101  * @area: The source of PFNs
2102  * @start_index: First page index
2103  * @last_index: Inclusive last page index
2104  * @out_pages: Output list of struct page's representing the PFNs
2105  * @flags: IOMMUFD_ACCESS_RW_* flags
2106  * @lock_area: Fail userspace munmap on this area
2107  *
2108  * Record that an in-kernel access will be accessing the pages, ensure they are
2109  * pinned, and return the PFNs as a simple list of 'struct page *'.
2110  *
2111  * This should be undone through a matching call to iopt_area_remove_access()
2112  */
2113 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
2114 			 unsigned long last_index, struct page **out_pages,
2115 			 unsigned int flags, bool lock_area)
2116 {
2117 	struct iopt_pages *pages = area->pages;
2118 	struct iopt_pages_access *access;
2119 	int rc;
2120 
2121 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2122 		return -EPERM;
2123 
2124 	mutex_lock(&pages->mutex);
2125 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2126 	if (access) {
2127 		area->num_accesses++;
2128 		if (lock_area)
2129 			area->num_locks++;
2130 		access->users++;
2131 		iopt_pages_fill_from_xarray(pages, start_index, last_index,
2132 					    out_pages);
2133 		mutex_unlock(&pages->mutex);
2134 		return 0;
2135 	}
2136 
2137 	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
2138 	if (!access) {
2139 		rc = -ENOMEM;
2140 		goto err_unlock;
2141 	}
2142 
2143 	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
2144 	if (rc)
2145 		goto err_free;
2146 
2147 	access->node.start = start_index;
2148 	access->node.last = last_index;
2149 	access->users = 1;
2150 	area->num_accesses++;
2151 	if (lock_area)
2152 		area->num_locks++;
2153 	interval_tree_insert(&access->node, &pages->access_itree);
2154 	mutex_unlock(&pages->mutex);
2155 	return 0;
2156 
2157 err_free:
2158 	kfree(access);
2159 err_unlock:
2160 	mutex_unlock(&pages->mutex);
2161 	return rc;
2162 }
2163 
2164 /**
2165  * iopt_area_remove_access() - Release an in-kernel access for PFNs
2166  * @area: The source of PFNs
2167  * @start_index: First page index
2168  * @last_index: Inclusive last page index
2169  * @unlock_area: Must match the matching iopt_area_add_access()'s lock_area
2170  *
2171  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
2172  * must stop using the PFNs before calling this.
2173  */
2174 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
2175 			     unsigned long last_index, bool unlock_area)
2176 {
2177 	struct iopt_pages *pages = area->pages;
2178 	struct iopt_pages_access *access;
2179 
2180 	mutex_lock(&pages->mutex);
2181 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2182 	if (WARN_ON(!access))
2183 		goto out_unlock;
2184 
2185 	WARN_ON(area->num_accesses == 0 || access->users == 0);
2186 	if (unlock_area) {
2187 		WARN_ON(area->num_locks == 0);
2188 		area->num_locks--;
2189 	}
2190 	area->num_accesses--;
2191 	access->users--;
2192 	if (access->users)
2193 		goto out_unlock;
2194 
2195 	interval_tree_remove(&access->node, &pages->access_itree);
2196 	iopt_pages_unfill_xarray(pages, start_index, last_index);
2197 	kfree(access);
2198 out_unlock:
2199 	mutex_unlock(&pages->mutex);
2200 }
2201