xref: /linux/drivers/iommu/iommufd/pages.c (revision fbf5df34a4dbcd09d433dd4f0916bf9b2ddb16de)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFN have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing a in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/dma-buf.h>
49 #include <linux/dma-resv.h>
50 #include <linux/file.h>
51 #include <linux/highmem.h>
52 #include <linux/iommu.h>
53 #include <linux/iommufd.h>
54 #include <linux/kthread.h>
55 #include <linux/overflow.h>
56 #include <linux/slab.h>
57 #include <linux/sched/mm.h>
58 #include <linux/vfio_pci_core.h>
59 
60 #include "double_span.h"
61 #include "io_pagetable.h"
62 
63 #ifndef CONFIG_IOMMUFD_TEST
64 #define TEMP_MEMORY_LIMIT 65536
65 #else
66 #define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
67 #endif
68 #define BATCH_BACKUP_SIZE 32
69 
70 /*
71  * More memory makes pin_user_pages() and the batching more efficient, but as
72  * this is only a performance optimization don't try too hard to get it. A 64k
73  * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
74  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
75  * stack memory as a backup contingency. If backup_len is given this cannot
76  * fail.
77  */
78 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
79 {
80 	void *res;
81 
82 	if (WARN_ON(*size == 0))
83 		return NULL;
84 
85 	if (*size < backup_len)
86 		return backup;
87 
88 	if (!backup && iommufd_should_fail())
89 		return NULL;
90 
91 	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
92 	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
93 	if (res)
94 		return res;
95 	*size = PAGE_SIZE;
96 	if (backup_len) {
97 		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
98 		if (res)
99 			return res;
100 		*size = backup_len;
101 		return backup;
102 	}
103 	return kmalloc(*size, GFP_KERNEL);
104 }
105 
106 void interval_tree_double_span_iter_update(
107 	struct interval_tree_double_span_iter *iter)
108 {
109 	unsigned long last_hole = ULONG_MAX;
110 	unsigned int i;
111 
112 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
113 		if (interval_tree_span_iter_done(&iter->spans[i])) {
114 			iter->is_used = -1;
115 			return;
116 		}
117 
118 		if (iter->spans[i].is_hole) {
119 			last_hole = min(last_hole, iter->spans[i].last_hole);
120 			continue;
121 		}
122 
123 		iter->is_used = i + 1;
124 		iter->start_used = iter->spans[i].start_used;
125 		iter->last_used = min(iter->spans[i].last_used, last_hole);
126 		return;
127 	}
128 
129 	iter->is_used = 0;
130 	iter->start_hole = iter->spans[0].start_hole;
131 	iter->last_hole =
132 		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
133 }
134 
135 void interval_tree_double_span_iter_first(
136 	struct interval_tree_double_span_iter *iter,
137 	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
138 	unsigned long first_index, unsigned long last_index)
139 {
140 	unsigned int i;
141 
142 	iter->itrees[0] = itree1;
143 	iter->itrees[1] = itree2;
144 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
145 		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
146 					      first_index, last_index);
147 	interval_tree_double_span_iter_update(iter);
148 }
149 
150 void interval_tree_double_span_iter_next(
151 	struct interval_tree_double_span_iter *iter)
152 {
153 	unsigned int i;
154 
155 	if (iter->is_used == -1 ||
156 	    iter->last_hole == iter->spans[0].last_index) {
157 		iter->is_used = -1;
158 		return;
159 	}
160 
161 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
162 		interval_tree_span_iter_advance(
163 			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
164 	interval_tree_double_span_iter_update(iter);
165 }
166 
167 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
168 {
169 	int rc;
170 
171 	rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
172 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
173 		WARN_ON(rc || pages->npinned > pages->npages);
174 }
175 
176 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
177 {
178 	int rc;
179 
180 	rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
181 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
182 		WARN_ON(rc || pages->npinned > pages->npages);
183 }
184 
185 static void iopt_pages_err_unpin(struct iopt_pages *pages,
186 				 unsigned long start_index,
187 				 unsigned long last_index,
188 				 struct page **page_list)
189 {
190 	unsigned long npages = last_index - start_index + 1;
191 
192 	unpin_user_pages(page_list, npages);
193 	iopt_pages_sub_npinned(pages, npages);
194 }
195 
196 /*
197  * index is the number of PAGE_SIZE units from the start of the area's
198  * iopt_pages. If the iova is sub page-size then the area has an iova that
199  * covers a portion of the first and last pages in the range.
200  */
201 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
202 					     unsigned long index)
203 {
204 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
205 		WARN_ON(index < iopt_area_index(area) ||
206 			index > iopt_area_last_index(area));
207 	index -= iopt_area_index(area);
208 	if (index == 0)
209 		return iopt_area_iova(area);
210 	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
211 }
212 
213 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
214 						  unsigned long index)
215 {
216 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
217 		WARN_ON(index < iopt_area_index(area) ||
218 			index > iopt_area_last_index(area));
219 	if (index == iopt_area_last_index(area))
220 		return iopt_area_last_iova(area);
221 	return iopt_area_iova(area) - area->page_offset +
222 	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
223 }
224 
225 static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
226 			       size_t size)
227 {
228 	size_t ret;
229 
230 	ret = iommu_unmap(domain, iova, size);
231 	/*
232 	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
233 	 * something other than exactly as requested. This implies that the
234 	 * iommu driver may not fail unmap for reasons beyond bad agruments.
235 	 * Particularly, the iommu driver may not do a memory allocation on the
236 	 * unmap path.
237 	 */
238 	WARN_ON(ret != size);
239 }
240 
241 static void iopt_area_unmap_domain_range(struct iopt_area *area,
242 					 struct iommu_domain *domain,
243 					 unsigned long start_index,
244 					 unsigned long last_index)
245 {
246 	unsigned long start_iova = iopt_area_index_to_iova(area, start_index);
247 
248 	iommu_unmap_nofail(domain, start_iova,
249 			   iopt_area_index_to_iova_last(area, last_index) -
250 				   start_iova + 1);
251 }
252 
253 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
254 						     unsigned long index)
255 {
256 	struct interval_tree_node *node;
257 
258 	node = interval_tree_iter_first(&pages->domains_itree, index, index);
259 	if (!node)
260 		return NULL;
261 	return container_of(node, struct iopt_area, pages_node);
262 }
263 
264 enum batch_kind {
265 	BATCH_CPU_MEMORY = 0,
266 	BATCH_MMIO,
267 };
268 
269 /*
270  * A simple datastructure to hold a vector of PFNs, optimized for contiguous
271  * PFNs. This is used as a temporary holding memory for shuttling pfns from one
272  * place to another. Generally everything is made more efficient if operations
273  * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
274  * better cache locality, etc
275  */
276 struct pfn_batch {
277 	unsigned long *pfns;
278 	u32 *npfns;
279 	unsigned int array_size;
280 	unsigned int end;
281 	unsigned int total_pfns;
282 	enum batch_kind kind;
283 };
284 enum { MAX_NPFNS = type_max(typeof(((struct pfn_batch *)0)->npfns[0])) };
285 
286 static void batch_clear(struct pfn_batch *batch)
287 {
288 	batch->total_pfns = 0;
289 	batch->end = 0;
290 	batch->pfns[0] = 0;
291 	batch->npfns[0] = 0;
292 	batch->kind = 0;
293 }
294 
295 /*
296  * Carry means we carry a portion of the final hugepage over to the front of the
297  * batch
298  */
299 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
300 {
301 	if (!keep_pfns)
302 		return batch_clear(batch);
303 
304 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
305 		WARN_ON(!batch->end ||
306 			batch->npfns[batch->end - 1] < keep_pfns);
307 
308 	batch->total_pfns = keep_pfns;
309 	batch->pfns[0] = batch->pfns[batch->end - 1] +
310 			 (batch->npfns[batch->end - 1] - keep_pfns);
311 	batch->npfns[0] = keep_pfns;
312 	batch->end = 1;
313 }
314 
315 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
316 {
317 	if (!batch->total_pfns)
318 		return;
319 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
320 		WARN_ON(batch->total_pfns != batch->npfns[0]);
321 	skip_pfns = min(batch->total_pfns, skip_pfns);
322 	batch->pfns[0] += skip_pfns;
323 	batch->npfns[0] -= skip_pfns;
324 	batch->total_pfns -= skip_pfns;
325 }
326 
327 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
328 			size_t backup_len)
329 {
330 	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
331 	size_t size = max_pages * elmsz;
332 
333 	batch->pfns = temp_kmalloc(&size, backup, backup_len);
334 	if (!batch->pfns)
335 		return -ENOMEM;
336 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
337 		return -EINVAL;
338 	batch->array_size = size / elmsz;
339 	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
340 	batch_clear(batch);
341 	return 0;
342 }
343 
344 static int batch_init(struct pfn_batch *batch, size_t max_pages)
345 {
346 	return __batch_init(batch, max_pages, NULL, 0);
347 }
348 
349 static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
350 			      void *backup, size_t backup_len)
351 {
352 	__batch_init(batch, max_pages, backup, backup_len);
353 }
354 
355 static void batch_destroy(struct pfn_batch *batch, void *backup)
356 {
357 	if (batch->pfns != backup)
358 		kfree(batch->pfns);
359 }
360 
361 static bool batch_add_pfn_num(struct pfn_batch *batch, unsigned long pfn,
362 			      u32 nr, enum batch_kind kind)
363 {
364 	unsigned int end = batch->end;
365 
366 	if (batch->kind != kind) {
367 		/* One kind per batch */
368 		if (batch->end != 0)
369 			return false;
370 		batch->kind = kind;
371 	}
372 
373 	if (end && pfn == batch->pfns[end - 1] + batch->npfns[end - 1] &&
374 	    nr <= MAX_NPFNS - batch->npfns[end - 1]) {
375 		batch->npfns[end - 1] += nr;
376 	} else if (end < batch->array_size) {
377 		batch->pfns[end] = pfn;
378 		batch->npfns[end] = nr;
379 		batch->end++;
380 	} else {
381 		return false;
382 	}
383 
384 	batch->total_pfns += nr;
385 	return true;
386 }
387 
388 static void batch_remove_pfn_num(struct pfn_batch *batch, unsigned long nr)
389 {
390 	batch->npfns[batch->end - 1] -= nr;
391 	if (batch->npfns[batch->end - 1] == 0)
392 		batch->end--;
393 	batch->total_pfns -= nr;
394 }
395 
396 /* true if the pfn was added, false otherwise */
397 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
398 {
399 	return batch_add_pfn_num(batch, pfn, 1, BATCH_CPU_MEMORY);
400 }
401 
402 /*
403  * Fill the batch with pfns from the domain. When the batch is full, or it
404  * reaches last_index, the function will return. The caller should use
405  * batch->total_pfns to determine the starting point for the next iteration.
406  */
407 static void batch_from_domain(struct pfn_batch *batch,
408 			      struct iommu_domain *domain,
409 			      struct iopt_area *area, unsigned long start_index,
410 			      unsigned long last_index)
411 {
412 	unsigned int page_offset = 0;
413 	unsigned long iova;
414 	phys_addr_t phys;
415 
416 	iova = iopt_area_index_to_iova(area, start_index);
417 	if (start_index == iopt_area_index(area))
418 		page_offset = area->page_offset;
419 	while (start_index <= last_index) {
420 		/*
421 		 * This is pretty slow, it would be nice to get the page size
422 		 * back from the driver, or have the driver directly fill the
423 		 * batch.
424 		 */
425 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
426 		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
427 			return;
428 		iova += PAGE_SIZE - page_offset;
429 		page_offset = 0;
430 		start_index++;
431 	}
432 }
433 
434 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
435 					   struct iopt_area *area,
436 					   unsigned long start_index,
437 					   unsigned long last_index,
438 					   struct page **out_pages)
439 {
440 	unsigned int page_offset = 0;
441 	unsigned long iova;
442 	phys_addr_t phys;
443 
444 	iova = iopt_area_index_to_iova(area, start_index);
445 	if (start_index == iopt_area_index(area))
446 		page_offset = area->page_offset;
447 	while (start_index <= last_index) {
448 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
449 		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
450 		iova += PAGE_SIZE - page_offset;
451 		page_offset = 0;
452 		start_index++;
453 	}
454 	return out_pages;
455 }
456 
457 /* Continues reading a domain until we reach a discontinuity in the pfns. */
458 static void batch_from_domain_continue(struct pfn_batch *batch,
459 				       struct iommu_domain *domain,
460 				       struct iopt_area *area,
461 				       unsigned long start_index,
462 				       unsigned long last_index)
463 {
464 	unsigned int array_size = batch->array_size;
465 
466 	batch->array_size = batch->end;
467 	batch_from_domain(batch, domain, area, start_index, last_index);
468 	batch->array_size = array_size;
469 }
470 
471 /*
472  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
473  * mode permits splitting a mapped area up, and then one of the splits is
474  * unmapped. Doing this normally would cause us to violate our invariant of
475  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
476  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
477  * PAGE_SIZE units, not larger or smaller.
478  */
479 static int batch_iommu_map_small(struct iommu_domain *domain,
480 				 unsigned long iova, phys_addr_t paddr,
481 				 size_t size, int prot)
482 {
483 	unsigned long start_iova = iova;
484 	int rc;
485 
486 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
487 		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
488 			size % PAGE_SIZE);
489 
490 	while (size) {
491 		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
492 			       GFP_KERNEL_ACCOUNT);
493 		if (rc)
494 			goto err_unmap;
495 		iova += PAGE_SIZE;
496 		paddr += PAGE_SIZE;
497 		size -= PAGE_SIZE;
498 	}
499 	return 0;
500 
501 err_unmap:
502 	if (start_iova != iova)
503 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
504 	return rc;
505 }
506 
507 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
508 			   struct iopt_area *area, unsigned long start_index)
509 {
510 	bool disable_large_pages = area->iopt->disable_large_pages;
511 	unsigned long last_iova = iopt_area_last_iova(area);
512 	int iommu_prot = area->iommu_prot;
513 	unsigned int page_offset = 0;
514 	unsigned long start_iova;
515 	unsigned long next_iova;
516 	unsigned int cur = 0;
517 	unsigned long iova;
518 	int rc;
519 
520 	if (batch->kind == BATCH_MMIO) {
521 		iommu_prot &= ~IOMMU_CACHE;
522 		iommu_prot |= IOMMU_MMIO;
523 	}
524 
525 	/* The first index might be a partial page */
526 	if (start_index == iopt_area_index(area))
527 		page_offset = area->page_offset;
528 	next_iova = iova = start_iova =
529 		iopt_area_index_to_iova(area, start_index);
530 	while (cur < batch->end) {
531 		next_iova = min(last_iova + 1,
532 				next_iova + batch->npfns[cur] * PAGE_SIZE -
533 					page_offset);
534 		if (disable_large_pages)
535 			rc = batch_iommu_map_small(
536 				domain, iova,
537 				PFN_PHYS(batch->pfns[cur]) + page_offset,
538 				next_iova - iova, iommu_prot);
539 		else
540 			rc = iommu_map(domain, iova,
541 				       PFN_PHYS(batch->pfns[cur]) + page_offset,
542 				       next_iova - iova, iommu_prot,
543 				       GFP_KERNEL_ACCOUNT);
544 		if (rc)
545 			goto err_unmap;
546 		iova = next_iova;
547 		page_offset = 0;
548 		cur++;
549 	}
550 	return 0;
551 err_unmap:
552 	if (start_iova != iova)
553 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
554 	return rc;
555 }
556 
557 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
558 			      unsigned long start_index,
559 			      unsigned long last_index)
560 {
561 	XA_STATE(xas, xa, start_index);
562 	void *entry;
563 
564 	rcu_read_lock();
565 	while (true) {
566 		entry = xas_next(&xas);
567 		if (xas_retry(&xas, entry))
568 			continue;
569 		WARN_ON(!xa_is_value(entry));
570 		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
571 		    start_index == last_index)
572 			break;
573 		start_index++;
574 	}
575 	rcu_read_unlock();
576 }
577 
578 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
579 				    unsigned long start_index,
580 				    unsigned long last_index)
581 {
582 	XA_STATE(xas, xa, start_index);
583 	void *entry;
584 
585 	xas_lock(&xas);
586 	while (true) {
587 		entry = xas_next(&xas);
588 		if (xas_retry(&xas, entry))
589 			continue;
590 		WARN_ON(!xa_is_value(entry));
591 		if (!batch_add_pfn(batch, xa_to_value(entry)))
592 			break;
593 		xas_store(&xas, NULL);
594 		if (start_index == last_index)
595 			break;
596 		start_index++;
597 	}
598 	xas_unlock(&xas);
599 }
600 
601 static void clear_xarray(struct xarray *xa, unsigned long start_index,
602 			 unsigned long last_index)
603 {
604 	XA_STATE(xas, xa, start_index);
605 	void *entry;
606 
607 	xas_lock(&xas);
608 	xas_for_each(&xas, entry, last_index)
609 		xas_store(&xas, NULL);
610 	xas_unlock(&xas);
611 }
612 
613 static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
614 			   unsigned long last_index, struct page **pages)
615 {
616 	struct page **end_pages = pages + (last_index - start_index) + 1;
617 	struct page **half_pages = pages + (end_pages - pages) / 2;
618 	XA_STATE(xas, xa, start_index);
619 
620 	do {
621 		void *old;
622 
623 		xas_lock(&xas);
624 		while (pages != end_pages) {
625 			/* xarray does not participate in fault injection */
626 			if (pages == half_pages && iommufd_should_fail()) {
627 				xas_set_err(&xas, -EINVAL);
628 				xas_unlock(&xas);
629 				/* aka xas_destroy() */
630 				xas_nomem(&xas, GFP_KERNEL);
631 				goto err_clear;
632 			}
633 
634 			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
635 			if (xas_error(&xas))
636 				break;
637 			WARN_ON(old);
638 			pages++;
639 			xas_next(&xas);
640 		}
641 		xas_unlock(&xas);
642 	} while (xas_nomem(&xas, GFP_KERNEL));
643 
644 err_clear:
645 	if (xas_error(&xas)) {
646 		if (xas.xa_index != start_index)
647 			clear_xarray(xa, start_index, xas.xa_index - 1);
648 		return xas_error(&xas);
649 	}
650 	return 0;
651 }
652 
653 static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
654 			     size_t npages)
655 {
656 	struct page **end = pages + npages;
657 
658 	for (; pages != end; pages++)
659 		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
660 			break;
661 }
662 
663 static int batch_from_folios(struct pfn_batch *batch, struct folio ***folios_p,
664 			     unsigned long *offset_p, unsigned long npages)
665 {
666 	int rc = 0;
667 	struct folio **folios = *folios_p;
668 	unsigned long offset = *offset_p;
669 
670 	while (npages) {
671 		struct folio *folio = *folios;
672 		unsigned long nr = folio_nr_pages(folio) - offset;
673 		unsigned long pfn = page_to_pfn(folio_page(folio, offset));
674 
675 		nr = min(nr, npages);
676 		npages -= nr;
677 
678 		if (!batch_add_pfn_num(batch, pfn, nr, BATCH_CPU_MEMORY))
679 			break;
680 		if (nr > 1) {
681 			rc = folio_add_pins(folio, nr - 1);
682 			if (rc) {
683 				batch_remove_pfn_num(batch, nr);
684 				goto out;
685 			}
686 		}
687 
688 		folios++;
689 		offset = 0;
690 	}
691 
692 out:
693 	*folios_p = folios;
694 	*offset_p = offset;
695 	return rc;
696 }
697 
698 static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
699 			unsigned int first_page_off, size_t npages)
700 {
701 	unsigned int cur = 0;
702 
703 	while (first_page_off) {
704 		if (batch->npfns[cur] > first_page_off)
705 			break;
706 		first_page_off -= batch->npfns[cur];
707 		cur++;
708 	}
709 
710 	while (npages) {
711 		size_t to_unpin = min_t(size_t, npages,
712 					batch->npfns[cur] - first_page_off);
713 
714 		unpin_user_page_range_dirty_lock(
715 			pfn_to_page(batch->pfns[cur] + first_page_off),
716 			to_unpin, pages->writable);
717 		iopt_pages_sub_npinned(pages, to_unpin);
718 		cur++;
719 		first_page_off = 0;
720 		npages -= to_unpin;
721 	}
722 }
723 
724 static void copy_data_page(struct page *page, void *data, unsigned long offset,
725 			   size_t length, unsigned int flags)
726 {
727 	void *mem;
728 
729 	mem = kmap_local_page(page);
730 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
731 		memcpy(mem + offset, data, length);
732 		set_page_dirty_lock(page);
733 	} else {
734 		memcpy(data, mem + offset, length);
735 	}
736 	kunmap_local(mem);
737 }
738 
739 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
740 			      unsigned long offset, unsigned long length,
741 			      unsigned int flags)
742 {
743 	unsigned long copied = 0;
744 	unsigned int npage = 0;
745 	unsigned int cur = 0;
746 
747 	while (cur < batch->end) {
748 		unsigned long bytes = min(length, PAGE_SIZE - offset);
749 
750 		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
751 			       offset, bytes, flags);
752 		offset = 0;
753 		length -= bytes;
754 		data += bytes;
755 		copied += bytes;
756 		npage++;
757 		if (npage == batch->npfns[cur]) {
758 			npage = 0;
759 			cur++;
760 		}
761 		if (!length)
762 			break;
763 	}
764 	return copied;
765 }
766 
767 /* pfn_reader_user is just the pin_user_pages() path */
768 struct pfn_reader_user {
769 	struct page **upages;
770 	size_t upages_len;
771 	unsigned long upages_start;
772 	unsigned long upages_end;
773 	unsigned int gup_flags;
774 	/*
775 	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
776 	 * neither
777 	 */
778 	int locked;
779 
780 	/* The following are only valid if file != NULL. */
781 	struct file *file;
782 	struct folio **ufolios;
783 	size_t ufolios_len;
784 	unsigned long ufolios_offset;
785 	struct folio **ufolios_next;
786 };
787 
788 static void pfn_reader_user_init(struct pfn_reader_user *user,
789 				 struct iopt_pages *pages)
790 {
791 	user->upages = NULL;
792 	user->upages_len = 0;
793 	user->upages_start = 0;
794 	user->upages_end = 0;
795 	user->locked = -1;
796 	user->gup_flags = FOLL_LONGTERM;
797 	if (pages->writable)
798 		user->gup_flags |= FOLL_WRITE;
799 
800 	user->file = (pages->type == IOPT_ADDRESS_FILE) ? pages->file : NULL;
801 	user->ufolios = NULL;
802 	user->ufolios_len = 0;
803 	user->ufolios_next = NULL;
804 	user->ufolios_offset = 0;
805 }
806 
807 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
808 				    struct iopt_pages *pages)
809 {
810 	if (user->locked != -1) {
811 		if (user->locked)
812 			mmap_read_unlock(pages->source_mm);
813 		if (!user->file && pages->source_mm != current->mm)
814 			mmput(pages->source_mm);
815 		user->locked = -1;
816 	}
817 
818 	kfree(user->upages);
819 	user->upages = NULL;
820 	kfree(user->ufolios);
821 	user->ufolios = NULL;
822 }
823 
824 static long pin_memfd_pages(struct pfn_reader_user *user, unsigned long start,
825 			    unsigned long npages)
826 {
827 	unsigned long i;
828 	unsigned long offset;
829 	unsigned long npages_out = 0;
830 	struct page **upages = user->upages;
831 	unsigned long end = start + (npages << PAGE_SHIFT) - 1;
832 	long nfolios = user->ufolios_len / sizeof(*user->ufolios);
833 
834 	/*
835 	 * todo: memfd_pin_folios should return the last pinned offset so
836 	 * we can compute npages pinned, and avoid looping over folios here
837 	 * if upages == NULL.
838 	 */
839 	nfolios = memfd_pin_folios(user->file, start, end, user->ufolios,
840 				   nfolios, &offset);
841 	if (nfolios <= 0)
842 		return nfolios;
843 
844 	offset >>= PAGE_SHIFT;
845 	user->ufolios_next = user->ufolios;
846 	user->ufolios_offset = offset;
847 
848 	for (i = 0; i < nfolios; i++) {
849 		struct folio *folio = user->ufolios[i];
850 		unsigned long nr = folio_nr_pages(folio);
851 		unsigned long npin = min(nr - offset, npages);
852 
853 		npages -= npin;
854 		npages_out += npin;
855 
856 		if (upages) {
857 			if (npin == 1) {
858 				*upages++ = folio_page(folio, offset);
859 			} else {
860 				int rc = folio_add_pins(folio, npin - 1);
861 
862 				if (rc)
863 					return rc;
864 
865 				while (npin--)
866 					*upages++ = folio_page(folio, offset++);
867 			}
868 		}
869 
870 		offset = 0;
871 	}
872 
873 	return npages_out;
874 }
875 
876 static int pfn_reader_user_pin(struct pfn_reader_user *user,
877 			       struct iopt_pages *pages,
878 			       unsigned long start_index,
879 			       unsigned long last_index)
880 {
881 	bool remote_mm = pages->source_mm != current->mm;
882 	unsigned long npages = last_index - start_index + 1;
883 	unsigned long start;
884 	unsigned long unum;
885 	uintptr_t uptr;
886 	long rc;
887 
888 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
889 	    WARN_ON(last_index < start_index))
890 		return -EINVAL;
891 
892 	if (!user->file && !user->upages) {
893 		/* All undone in pfn_reader_destroy() */
894 		user->upages_len = npages * sizeof(*user->upages);
895 		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
896 		if (!user->upages)
897 			return -ENOMEM;
898 	}
899 
900 	if (user->file && !user->ufolios) {
901 		user->ufolios_len = npages * sizeof(*user->ufolios);
902 		user->ufolios = temp_kmalloc(&user->ufolios_len, NULL, 0);
903 		if (!user->ufolios)
904 			return -ENOMEM;
905 	}
906 
907 	if (user->locked == -1) {
908 		/*
909 		 * The majority of usages will run the map task within the mm
910 		 * providing the pages, so we can optimize into
911 		 * get_user_pages_fast()
912 		 */
913 		if (!user->file && remote_mm) {
914 			if (!mmget_not_zero(pages->source_mm))
915 				return -EFAULT;
916 		}
917 		user->locked = 0;
918 	}
919 
920 	unum = user->file ? user->ufolios_len / sizeof(*user->ufolios) :
921 			    user->upages_len / sizeof(*user->upages);
922 	npages = min_t(unsigned long, npages, unum);
923 
924 	if (iommufd_should_fail())
925 		return -EFAULT;
926 
927 	if (user->file) {
928 		start = pages->start + (start_index * PAGE_SIZE);
929 		rc = pin_memfd_pages(user, start, npages);
930 	} else if (!remote_mm) {
931 		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
932 		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
933 					 user->upages);
934 	} else {
935 		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
936 		if (!user->locked) {
937 			mmap_read_lock(pages->source_mm);
938 			user->locked = 1;
939 		}
940 		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
941 					   user->gup_flags, user->upages,
942 					   &user->locked);
943 	}
944 	if (rc <= 0) {
945 		if (WARN_ON(!rc))
946 			return -EFAULT;
947 		return rc;
948 	}
949 	iopt_pages_add_npinned(pages, rc);
950 	user->upages_start = start_index;
951 	user->upages_end = start_index + rc;
952 	return 0;
953 }
954 
955 /* This is the "modern" and faster accounting method used by io_uring */
956 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
957 {
958 	unsigned long lock_limit;
959 	unsigned long cur_pages;
960 	unsigned long new_pages;
961 
962 	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
963 		     PAGE_SHIFT;
964 
965 	cur_pages = atomic_long_read(&pages->source_user->locked_vm);
966 	do {
967 		new_pages = cur_pages + npages;
968 		if (new_pages > lock_limit)
969 			return -ENOMEM;
970 	} while (!atomic_long_try_cmpxchg(&pages->source_user->locked_vm,
971 					  &cur_pages, new_pages));
972 	return 0;
973 }
974 
975 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
976 {
977 	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
978 		return;
979 	atomic_long_sub(npages, &pages->source_user->locked_vm);
980 }
981 
982 /* This is the accounting method used for compatibility with VFIO */
983 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
984 			       bool inc, struct pfn_reader_user *user)
985 {
986 	bool do_put = false;
987 	int rc;
988 
989 	if (user && user->locked) {
990 		mmap_read_unlock(pages->source_mm);
991 		user->locked = 0;
992 		/* If we had the lock then we also have a get */
993 
994 	} else if ((!user || (!user->upages && !user->ufolios)) &&
995 		   pages->source_mm != current->mm) {
996 		if (!mmget_not_zero(pages->source_mm))
997 			return -EINVAL;
998 		do_put = true;
999 	}
1000 
1001 	mmap_write_lock(pages->source_mm);
1002 	rc = __account_locked_vm(pages->source_mm, npages, inc,
1003 				 pages->source_task, false);
1004 	mmap_write_unlock(pages->source_mm);
1005 
1006 	if (do_put)
1007 		mmput(pages->source_mm);
1008 	return rc;
1009 }
1010 
1011 int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
1012 			     bool inc, struct pfn_reader_user *user)
1013 {
1014 	int rc = 0;
1015 
1016 	switch (pages->account_mode) {
1017 	case IOPT_PAGES_ACCOUNT_NONE:
1018 		break;
1019 	case IOPT_PAGES_ACCOUNT_USER:
1020 		if (inc)
1021 			rc = incr_user_locked_vm(pages, npages);
1022 		else
1023 			decr_user_locked_vm(pages, npages);
1024 		break;
1025 	case IOPT_PAGES_ACCOUNT_MM:
1026 		rc = update_mm_locked_vm(pages, npages, inc, user);
1027 		break;
1028 	}
1029 	if (rc)
1030 		return rc;
1031 
1032 	pages->last_npinned = pages->npinned;
1033 	if (inc)
1034 		atomic64_add(npages, &pages->source_mm->pinned_vm);
1035 	else
1036 		atomic64_sub(npages, &pages->source_mm->pinned_vm);
1037 	return 0;
1038 }
1039 
1040 static void update_unpinned(struct iopt_pages *pages)
1041 {
1042 	if (WARN_ON(pages->npinned > pages->last_npinned))
1043 		return;
1044 	if (pages->npinned == pages->last_npinned)
1045 		return;
1046 	iopt_pages_update_pinned(pages, pages->last_npinned - pages->npinned,
1047 				 false, NULL);
1048 }
1049 
1050 /*
1051  * Changes in the number of pages pinned is done after the pages have been read
1052  * and processed. If the user lacked the limit then the error unwind will unpin
1053  * everything that was just pinned. This is because it is expensive to calculate
1054  * how many pages we have already pinned within a range to generate an accurate
1055  * prediction in advance of doing the work to actually pin them.
1056  */
1057 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
1058 					 struct iopt_pages *pages)
1059 {
1060 	unsigned long npages;
1061 	bool inc;
1062 
1063 	lockdep_assert_held(&pages->mutex);
1064 
1065 	if (pages->npinned == pages->last_npinned)
1066 		return 0;
1067 
1068 	if (pages->npinned < pages->last_npinned) {
1069 		npages = pages->last_npinned - pages->npinned;
1070 		inc = false;
1071 	} else {
1072 		if (iommufd_should_fail())
1073 			return -ENOMEM;
1074 		npages = pages->npinned - pages->last_npinned;
1075 		inc = true;
1076 	}
1077 	return iopt_pages_update_pinned(pages, npages, inc, user);
1078 }
1079 
1080 struct pfn_reader_dmabuf {
1081 	struct phys_vec phys;
1082 	unsigned long start_offset;
1083 };
1084 
1085 static int pfn_reader_dmabuf_init(struct pfn_reader_dmabuf *dmabuf,
1086 				  struct iopt_pages *pages)
1087 {
1088 	/* Callers must not get here if the dmabuf was already revoked */
1089 	if (WARN_ON(iopt_dmabuf_revoked(pages)))
1090 		return -EINVAL;
1091 
1092 	dmabuf->phys = pages->dmabuf.phys;
1093 	dmabuf->start_offset = pages->dmabuf.start;
1094 	return 0;
1095 }
1096 
1097 static int pfn_reader_fill_dmabuf(struct pfn_reader_dmabuf *dmabuf,
1098 				  struct pfn_batch *batch,
1099 				  unsigned long start_index,
1100 				  unsigned long last_index)
1101 {
1102 	unsigned long start = dmabuf->start_offset + start_index * PAGE_SIZE;
1103 
1104 	/*
1105 	 * start/last_index and start are all PAGE_SIZE aligned, the batch is
1106 	 * always filled using page size aligned PFNs just like the other types.
1107 	 * If the dmabuf has been sliced on a sub page offset then the common
1108 	 * batch to domain code will adjust it before mapping to the domain.
1109 	 */
1110 	batch_add_pfn_num(batch, PHYS_PFN(dmabuf->phys.paddr + start),
1111 			  last_index - start_index + 1, BATCH_MMIO);
1112 	return 0;
1113 }
1114 
1115 /*
1116  * PFNs are stored in three places, in order of preference:
1117  * - The iopt_pages xarray. This is only populated if there is a
1118  *   iopt_pages_access
1119  * - The iommu_domain under an area
1120  * - The original PFN source, ie pages->source_mm
1121  *
1122  * This iterator reads the pfns optimizing to load according to the
1123  * above order.
1124  */
1125 struct pfn_reader {
1126 	struct iopt_pages *pages;
1127 	struct interval_tree_double_span_iter span;
1128 	struct pfn_batch batch;
1129 	unsigned long batch_start_index;
1130 	unsigned long batch_end_index;
1131 	unsigned long last_index;
1132 
1133 	union {
1134 		struct pfn_reader_user user;
1135 		struct pfn_reader_dmabuf dmabuf;
1136 	};
1137 };
1138 
1139 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
1140 {
1141 	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
1142 }
1143 
1144 /*
1145  * The batch can contain a mixture of pages that are still in use and pages that
1146  * need to be unpinned. Unpin only pages that are not held anywhere else.
1147  */
1148 static void pfn_reader_unpin(struct pfn_reader *pfns)
1149 {
1150 	unsigned long last = pfns->batch_end_index - 1;
1151 	unsigned long start = pfns->batch_start_index;
1152 	struct interval_tree_double_span_iter span;
1153 	struct iopt_pages *pages = pfns->pages;
1154 
1155 	lockdep_assert_held(&pages->mutex);
1156 
1157 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1158 					   &pages->domains_itree, start, last) {
1159 		if (span.is_used)
1160 			continue;
1161 
1162 		batch_unpin(&pfns->batch, pages, span.start_hole - start,
1163 			    span.last_hole - span.start_hole + 1);
1164 	}
1165 }
1166 
1167 /* Process a single span to load it from the proper storage */
1168 static int pfn_reader_fill_span(struct pfn_reader *pfns)
1169 {
1170 	struct interval_tree_double_span_iter *span = &pfns->span;
1171 	unsigned long start_index = pfns->batch_end_index;
1172 	struct pfn_reader_user *user;
1173 	unsigned long npages;
1174 	struct iopt_area *area;
1175 	int rc;
1176 
1177 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1178 	    WARN_ON(span->last_used < start_index))
1179 		return -EINVAL;
1180 
1181 	if (span->is_used == 1) {
1182 		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
1183 				  start_index, span->last_used);
1184 		return 0;
1185 	}
1186 
1187 	if (span->is_used == 2) {
1188 		/*
1189 		 * Pull as many pages from the first domain we find in the
1190 		 * target span. If it is too small then we will be called again
1191 		 * and we'll find another area.
1192 		 */
1193 		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1194 		if (WARN_ON(!area))
1195 			return -EINVAL;
1196 
1197 		/* The storage_domain cannot change without the pages mutex */
1198 		batch_from_domain(
1199 			&pfns->batch, area->storage_domain, area, start_index,
1200 			min(iopt_area_last_index(area), span->last_used));
1201 		return 0;
1202 	}
1203 
1204 	if (iopt_is_dmabuf(pfns->pages))
1205 		return pfn_reader_fill_dmabuf(&pfns->dmabuf, &pfns->batch,
1206 					      start_index, span->last_hole);
1207 
1208 	user = &pfns->user;
1209 	if (start_index >= user->upages_end) {
1210 		rc = pfn_reader_user_pin(user, pfns->pages, start_index,
1211 					 span->last_hole);
1212 		if (rc)
1213 			return rc;
1214 	}
1215 
1216 	npages = user->upages_end - start_index;
1217 	start_index -= user->upages_start;
1218 	rc = 0;
1219 
1220 	if (!user->file)
1221 		batch_from_pages(&pfns->batch, user->upages + start_index,
1222 				 npages);
1223 	else
1224 		rc = batch_from_folios(&pfns->batch, &user->ufolios_next,
1225 				       &user->ufolios_offset, npages);
1226 	return rc;
1227 }
1228 
1229 static bool pfn_reader_done(struct pfn_reader *pfns)
1230 {
1231 	return pfns->batch_start_index == pfns->last_index + 1;
1232 }
1233 
1234 static int pfn_reader_next(struct pfn_reader *pfns)
1235 {
1236 	int rc;
1237 
1238 	batch_clear(&pfns->batch);
1239 	pfns->batch_start_index = pfns->batch_end_index;
1240 
1241 	while (pfns->batch_end_index != pfns->last_index + 1) {
1242 		unsigned int npfns = pfns->batch.total_pfns;
1243 
1244 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1245 		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1246 			return -EINVAL;
1247 
1248 		rc = pfn_reader_fill_span(pfns);
1249 		if (rc)
1250 			return rc;
1251 
1252 		if (WARN_ON(!pfns->batch.total_pfns))
1253 			return -EINVAL;
1254 
1255 		pfns->batch_end_index =
1256 			pfns->batch_start_index + pfns->batch.total_pfns;
1257 		if (pfns->batch_end_index == pfns->span.last_used + 1)
1258 			interval_tree_double_span_iter_next(&pfns->span);
1259 
1260 		/* Batch is full */
1261 		if (npfns == pfns->batch.total_pfns)
1262 			return 0;
1263 	}
1264 	return 0;
1265 }
1266 
1267 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1268 			   unsigned long start_index, unsigned long last_index)
1269 {
1270 	int rc;
1271 
1272 	lockdep_assert_held(&pages->mutex);
1273 
1274 	pfns->pages = pages;
1275 	pfns->batch_start_index = start_index;
1276 	pfns->batch_end_index = start_index;
1277 	pfns->last_index = last_index;
1278 	if (iopt_is_dmabuf(pages))
1279 		pfn_reader_dmabuf_init(&pfns->dmabuf, pages);
1280 	else
1281 		pfn_reader_user_init(&pfns->user, pages);
1282 	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1283 	if (rc)
1284 		return rc;
1285 	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1286 					     &pages->domains_itree, start_index,
1287 					     last_index);
1288 	return 0;
1289 }
1290 
1291 /*
1292  * There are many assertions regarding the state of pages->npinned vs
1293  * pages->last_pinned, for instance something like unmapping a domain must only
1294  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1295  * the pins are updated. This is fine for success flows, but error flows
1296  * sometimes need to release the pins held inside the pfn_reader before going on
1297  * to complete unmapping and releasing pins held in domains.
1298  */
1299 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1300 {
1301 	struct iopt_pages *pages = pfns->pages;
1302 	struct pfn_reader_user *user;
1303 
1304 	if (iopt_is_dmabuf(pages))
1305 		return;
1306 
1307 	user = &pfns->user;
1308 	if (user->upages_end > pfns->batch_end_index) {
1309 		/* Any pages not transferred to the batch are just unpinned */
1310 
1311 		unsigned long npages = user->upages_end - pfns->batch_end_index;
1312 		unsigned long start_index = pfns->batch_end_index -
1313 					    user->upages_start;
1314 
1315 		if (!user->file) {
1316 			unpin_user_pages(user->upages + start_index, npages);
1317 		} else {
1318 			long n = user->ufolios_len / sizeof(*user->ufolios);
1319 
1320 			unpin_folios(user->ufolios_next,
1321 				     user->ufolios + n - user->ufolios_next);
1322 		}
1323 		iopt_pages_sub_npinned(pages, npages);
1324 		user->upages_end = pfns->batch_end_index;
1325 	}
1326 	if (pfns->batch_start_index != pfns->batch_end_index) {
1327 		pfn_reader_unpin(pfns);
1328 		pfns->batch_start_index = pfns->batch_end_index;
1329 	}
1330 }
1331 
1332 static void pfn_reader_destroy(struct pfn_reader *pfns)
1333 {
1334 	struct iopt_pages *pages = pfns->pages;
1335 
1336 	pfn_reader_release_pins(pfns);
1337 	if (!iopt_is_dmabuf(pfns->pages))
1338 		pfn_reader_user_destroy(&pfns->user, pfns->pages);
1339 	batch_destroy(&pfns->batch, NULL);
1340 	WARN_ON(pages->last_npinned != pages->npinned);
1341 }
1342 
1343 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1344 			    unsigned long start_index, unsigned long last_index)
1345 {
1346 	int rc;
1347 
1348 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1349 	    WARN_ON(last_index < start_index))
1350 		return -EINVAL;
1351 
1352 	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1353 	if (rc)
1354 		return rc;
1355 	rc = pfn_reader_next(pfns);
1356 	if (rc) {
1357 		pfn_reader_destroy(pfns);
1358 		return rc;
1359 	}
1360 	return 0;
1361 }
1362 
1363 static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
1364 					   unsigned long length, bool writable)
1365 {
1366 	struct iopt_pages *pages;
1367 
1368 	/*
1369 	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1370 	 * below from overflow
1371 	 */
1372 	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1373 		return ERR_PTR(-EINVAL);
1374 
1375 	pages = kzalloc_obj(*pages, GFP_KERNEL_ACCOUNT);
1376 	if (!pages)
1377 		return ERR_PTR(-ENOMEM);
1378 
1379 	kref_init(&pages->kref);
1380 	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1381 	mutex_init(&pages->mutex);
1382 	pages->source_mm = current->mm;
1383 	mmgrab(pages->source_mm);
1384 	pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
1385 	pages->access_itree = RB_ROOT_CACHED;
1386 	pages->domains_itree = RB_ROOT_CACHED;
1387 	pages->writable = writable;
1388 	if (capable(CAP_IPC_LOCK))
1389 		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1390 	else
1391 		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1392 	pages->source_task = current->group_leader;
1393 	get_task_struct(current->group_leader);
1394 	pages->source_user = get_uid(current_user());
1395 	return pages;
1396 }
1397 
1398 struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
1399 					 unsigned long length, bool writable)
1400 {
1401 	struct iopt_pages *pages;
1402 	unsigned long end;
1403 	void __user *uptr_down =
1404 		(void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1405 
1406 	if (check_add_overflow((unsigned long)uptr, length, &end))
1407 		return ERR_PTR(-EOVERFLOW);
1408 
1409 	pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
1410 	if (IS_ERR(pages))
1411 		return pages;
1412 	pages->uptr = uptr_down;
1413 	pages->type = IOPT_ADDRESS_USER;
1414 	return pages;
1415 }
1416 
1417 struct iopt_pages *iopt_alloc_file_pages(struct file *file,
1418 					 unsigned long start_byte,
1419 					 unsigned long start,
1420 					 unsigned long length, bool writable)
1421 
1422 {
1423 	struct iopt_pages *pages;
1424 
1425 	pages = iopt_alloc_pages(start_byte, length, writable);
1426 	if (IS_ERR(pages))
1427 		return pages;
1428 	pages->file = get_file(file);
1429 	pages->start = start - start_byte;
1430 	pages->type = IOPT_ADDRESS_FILE;
1431 	return pages;
1432 }
1433 
1434 static void iopt_revoke_notify(struct dma_buf_attachment *attach)
1435 {
1436 	struct iopt_pages *pages = attach->importer_priv;
1437 	struct iopt_pages_dmabuf_track *track;
1438 
1439 	guard(mutex)(&pages->mutex);
1440 	if (iopt_dmabuf_revoked(pages))
1441 		return;
1442 
1443 	list_for_each_entry(track, &pages->dmabuf.tracker, elm) {
1444 		struct iopt_area *area = track->area;
1445 
1446 		iopt_area_unmap_domain_range(area, track->domain,
1447 					     iopt_area_index(area),
1448 					     iopt_area_last_index(area));
1449 	}
1450 	pages->dmabuf.phys.len = 0;
1451 }
1452 
1453 static const struct dma_buf_attach_ops iopt_dmabuf_attach_revoke_ops = {
1454 	.allow_peer2peer = true,
1455 	.invalidate_mappings = iopt_revoke_notify,
1456 };
1457 
1458 /*
1459  * iommufd and vfio have a circular dependency. Future work for a phys
1460  * based private interconnect will remove this.
1461  */
1462 static int
1463 sym_vfio_pci_dma_buf_iommufd_map(struct dma_buf_attachment *attachment,
1464 				 struct phys_vec *phys)
1465 {
1466 	typeof(&vfio_pci_dma_buf_iommufd_map) fn;
1467 	int rc;
1468 
1469 	rc = iommufd_test_dma_buf_iommufd_map(attachment, phys);
1470 	if (rc != -EOPNOTSUPP)
1471 		return rc;
1472 
1473 	if (!IS_ENABLED(CONFIG_VFIO_PCI_DMABUF))
1474 		return -EOPNOTSUPP;
1475 
1476 	fn = symbol_get(vfio_pci_dma_buf_iommufd_map);
1477 	if (!fn)
1478 		return -EOPNOTSUPP;
1479 	rc = fn(attachment, phys);
1480 	symbol_put(vfio_pci_dma_buf_iommufd_map);
1481 	return rc;
1482 }
1483 
1484 static int iopt_map_dmabuf(struct iommufd_ctx *ictx, struct iopt_pages *pages,
1485 			   struct dma_buf *dmabuf)
1486 {
1487 	struct dma_buf_attachment *attach;
1488 	int rc;
1489 
1490 	attach = dma_buf_dynamic_attach(dmabuf, iommufd_global_device(),
1491 					&iopt_dmabuf_attach_revoke_ops, pages);
1492 	if (IS_ERR(attach))
1493 		return PTR_ERR(attach);
1494 
1495 	dma_resv_lock(dmabuf->resv, NULL);
1496 	/*
1497 	 * Lock ordering requires the mutex to be taken inside the reservation,
1498 	 * make sure lockdep sees this.
1499 	 */
1500 	if (IS_ENABLED(CONFIG_LOCKDEP)) {
1501 		mutex_lock(&pages->mutex);
1502 		mutex_unlock(&pages->mutex);
1503 	}
1504 
1505 	rc = dma_buf_pin(attach);
1506 	if (rc)
1507 		goto err_detach;
1508 
1509 	rc = sym_vfio_pci_dma_buf_iommufd_map(attach, &pages->dmabuf.phys);
1510 	if (rc)
1511 		goto err_unpin;
1512 
1513 	dma_resv_unlock(dmabuf->resv);
1514 
1515 	/* On success iopt_release_pages() will detach and put the dmabuf. */
1516 	pages->dmabuf.attach = attach;
1517 	return 0;
1518 
1519 err_unpin:
1520 	dma_buf_unpin(attach);
1521 err_detach:
1522 	dma_resv_unlock(dmabuf->resv);
1523 	dma_buf_detach(dmabuf, attach);
1524 	return rc;
1525 }
1526 
1527 struct iopt_pages *iopt_alloc_dmabuf_pages(struct iommufd_ctx *ictx,
1528 					   struct dma_buf *dmabuf,
1529 					   unsigned long start_byte,
1530 					   unsigned long start,
1531 					   unsigned long length, bool writable)
1532 {
1533 	static struct lock_class_key pages_dmabuf_mutex_key;
1534 	struct iopt_pages *pages;
1535 	int rc;
1536 
1537 	if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
1538 		return ERR_PTR(-EOPNOTSUPP);
1539 
1540 	if (dmabuf->size <= (start + length - 1) ||
1541 	    length / PAGE_SIZE >= MAX_NPFNS)
1542 		return ERR_PTR(-EINVAL);
1543 
1544 	pages = iopt_alloc_pages(start_byte, length, writable);
1545 	if (IS_ERR(pages))
1546 		return pages;
1547 
1548 	/*
1549 	 * The mmap_lock can be held when obtaining the dmabuf reservation lock
1550 	 * which creates a locking cycle with the pages mutex which is held
1551 	 * while obtaining the mmap_lock. This locking path is not present for
1552 	 * IOPT_ADDRESS_DMABUF so split the lock class.
1553 	 */
1554 	lockdep_set_class(&pages->mutex, &pages_dmabuf_mutex_key);
1555 
1556 	/* dmabuf does not use pinned page accounting. */
1557 	pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1558 	pages->type = IOPT_ADDRESS_DMABUF;
1559 	pages->dmabuf.start = start - start_byte;
1560 	INIT_LIST_HEAD(&pages->dmabuf.tracker);
1561 
1562 	rc = iopt_map_dmabuf(ictx, pages, dmabuf);
1563 	if (rc) {
1564 		iopt_put_pages(pages);
1565 		return ERR_PTR(rc);
1566 	}
1567 
1568 	return pages;
1569 }
1570 
1571 int iopt_dmabuf_track_domain(struct iopt_pages *pages, struct iopt_area *area,
1572 			     struct iommu_domain *domain)
1573 {
1574 	struct iopt_pages_dmabuf_track *track;
1575 
1576 	lockdep_assert_held(&pages->mutex);
1577 	if (WARN_ON(!iopt_is_dmabuf(pages)))
1578 		return -EINVAL;
1579 
1580 	list_for_each_entry(track, &pages->dmabuf.tracker, elm)
1581 		if (WARN_ON(track->domain == domain && track->area == area))
1582 			return -EINVAL;
1583 
1584 	track = kzalloc_obj(*track);
1585 	if (!track)
1586 		return -ENOMEM;
1587 	track->domain = domain;
1588 	track->area = area;
1589 	list_add_tail(&track->elm, &pages->dmabuf.tracker);
1590 
1591 	return 0;
1592 }
1593 
1594 void iopt_dmabuf_untrack_domain(struct iopt_pages *pages,
1595 				struct iopt_area *area,
1596 				struct iommu_domain *domain)
1597 {
1598 	struct iopt_pages_dmabuf_track *track;
1599 
1600 	lockdep_assert_held(&pages->mutex);
1601 	WARN_ON(!iopt_is_dmabuf(pages));
1602 
1603 	list_for_each_entry(track, &pages->dmabuf.tracker, elm) {
1604 		if (track->domain == domain && track->area == area) {
1605 			list_del(&track->elm);
1606 			kfree(track);
1607 			return;
1608 		}
1609 	}
1610 	WARN_ON(true);
1611 }
1612 
1613 int iopt_dmabuf_track_all_domains(struct iopt_area *area,
1614 				  struct iopt_pages *pages)
1615 {
1616 	struct iopt_pages_dmabuf_track *track;
1617 	struct iommu_domain *domain;
1618 	unsigned long index;
1619 	int rc;
1620 
1621 	list_for_each_entry(track, &pages->dmabuf.tracker, elm)
1622 		if (WARN_ON(track->area == area))
1623 			return -EINVAL;
1624 
1625 	xa_for_each(&area->iopt->domains, index, domain) {
1626 		rc = iopt_dmabuf_track_domain(pages, area, domain);
1627 		if (rc)
1628 			goto err_untrack;
1629 	}
1630 	return 0;
1631 err_untrack:
1632 	iopt_dmabuf_untrack_all_domains(area, pages);
1633 	return rc;
1634 }
1635 
1636 void iopt_dmabuf_untrack_all_domains(struct iopt_area *area,
1637 				     struct iopt_pages *pages)
1638 {
1639 	struct iopt_pages_dmabuf_track *track;
1640 	struct iopt_pages_dmabuf_track *tmp;
1641 
1642 	list_for_each_entry_safe(track, tmp, &pages->dmabuf.tracker,
1643 				 elm) {
1644 		if (track->area == area) {
1645 			list_del(&track->elm);
1646 			kfree(track);
1647 		}
1648 	}
1649 }
1650 
1651 void iopt_release_pages(struct kref *kref)
1652 {
1653 	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1654 
1655 	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1656 	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1657 	WARN_ON(pages->npinned);
1658 	WARN_ON(!xa_empty(&pages->pinned_pfns));
1659 	mmdrop(pages->source_mm);
1660 	mutex_destroy(&pages->mutex);
1661 	put_task_struct(pages->source_task);
1662 	free_uid(pages->source_user);
1663 	if (iopt_is_dmabuf(pages) && pages->dmabuf.attach) {
1664 		struct dma_buf *dmabuf = pages->dmabuf.attach->dmabuf;
1665 
1666 		dma_buf_unpin(pages->dmabuf.attach);
1667 		dma_buf_detach(dmabuf, pages->dmabuf.attach);
1668 		dma_buf_put(dmabuf);
1669 		WARN_ON(!list_empty(&pages->dmabuf.tracker));
1670 	} else if (pages->type == IOPT_ADDRESS_FILE) {
1671 		fput(pages->file);
1672 	}
1673 	kfree(pages);
1674 }
1675 
1676 static void
1677 iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1678 		       struct iopt_pages *pages, struct iommu_domain *domain,
1679 		       unsigned long start_index, unsigned long last_index,
1680 		       unsigned long *unmapped_end_index,
1681 		       unsigned long real_last_index)
1682 {
1683 	while (start_index <= last_index) {
1684 		unsigned long batch_last_index;
1685 
1686 		if (*unmapped_end_index <= last_index) {
1687 			unsigned long start =
1688 				max(start_index, *unmapped_end_index);
1689 
1690 			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1691 			    batch->total_pfns)
1692 				WARN_ON(*unmapped_end_index -
1693 						batch->total_pfns !=
1694 					start_index);
1695 			batch_from_domain(batch, domain, area, start,
1696 					  last_index);
1697 			batch_last_index = start_index + batch->total_pfns - 1;
1698 		} else {
1699 			batch_last_index = last_index;
1700 		}
1701 
1702 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1703 			WARN_ON(batch_last_index > real_last_index);
1704 
1705 		/*
1706 		 * unmaps must always 'cut' at a place where the pfns are not
1707 		 * contiguous to pair with the maps that always install
1708 		 * contiguous pages. Thus, if we have to stop unpinning in the
1709 		 * middle of the domains we need to keep reading pfns until we
1710 		 * find a cut point to do the unmap. The pfns we read are
1711 		 * carried over and either skipped or integrated into the next
1712 		 * batch.
1713 		 */
1714 		if (batch_last_index == last_index &&
1715 		    last_index != real_last_index)
1716 			batch_from_domain_continue(batch, domain, area,
1717 						   last_index + 1,
1718 						   real_last_index);
1719 
1720 		if (*unmapped_end_index <= batch_last_index) {
1721 			iopt_area_unmap_domain_range(
1722 				area, domain, *unmapped_end_index,
1723 				start_index + batch->total_pfns - 1);
1724 			*unmapped_end_index = start_index + batch->total_pfns;
1725 		}
1726 
1727 		/* unpin must follow unmap */
1728 		batch_unpin(batch, pages, 0,
1729 			    batch_last_index - start_index + 1);
1730 		start_index = batch_last_index + 1;
1731 
1732 		batch_clear_carry(batch,
1733 				  *unmapped_end_index - batch_last_index - 1);
1734 	}
1735 }
1736 
1737 static void __iopt_area_unfill_domain(struct iopt_area *area,
1738 				      struct iopt_pages *pages,
1739 				      struct iommu_domain *domain,
1740 				      unsigned long last_index)
1741 {
1742 	struct interval_tree_double_span_iter span;
1743 	unsigned long start_index = iopt_area_index(area);
1744 	unsigned long unmapped_end_index = start_index;
1745 	u64 backup[BATCH_BACKUP_SIZE];
1746 	struct pfn_batch batch;
1747 
1748 	lockdep_assert_held(&pages->mutex);
1749 
1750 	if (iopt_is_dmabuf(pages)) {
1751 		if (WARN_ON(iopt_dmabuf_revoked(pages)))
1752 			return;
1753 		iopt_area_unmap_domain_range(area, domain, start_index,
1754 					     last_index);
1755 		return;
1756 	}
1757 
1758 	/*
1759 	 * For security we must not unpin something that is still DMA mapped,
1760 	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1761 	 * This creates a complexity where we need to skip over unpinning pages
1762 	 * held in the xarray, but continue to unmap from the domain.
1763 	 *
1764 	 * The domain unmap cannot stop in the middle of a contiguous range of
1765 	 * PFNs. To solve this problem the unpinning step will read ahead to the
1766 	 * end of any contiguous span, unmap that whole span, and then only
1767 	 * unpin the leading part that does not have any accesses. The residual
1768 	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1769 	 * batch as they are moved to the front of the PFN list and continue on
1770 	 * to the next iteration(s).
1771 	 */
1772 	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1773 	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1774 					   &pages->access_itree, start_index,
1775 					   last_index) {
1776 		if (span.is_used) {
1777 			batch_skip_carry(&batch,
1778 					 span.last_used - span.start_used + 1);
1779 			continue;
1780 		}
1781 		iopt_area_unpin_domain(&batch, area, pages, domain,
1782 				       span.start_hole, span.last_hole,
1783 				       &unmapped_end_index, last_index);
1784 	}
1785 	/*
1786 	 * If the range ends in a access then we do the residual unmap without
1787 	 * any unpins.
1788 	 */
1789 	if (unmapped_end_index != last_index + 1)
1790 		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1791 					     last_index);
1792 	WARN_ON(batch.total_pfns);
1793 	batch_destroy(&batch, backup);
1794 	update_unpinned(pages);
1795 }
1796 
1797 static void iopt_area_unfill_partial_domain(struct iopt_area *area,
1798 					    struct iopt_pages *pages,
1799 					    struct iommu_domain *domain,
1800 					    unsigned long end_index)
1801 {
1802 	if (end_index != iopt_area_index(area))
1803 		__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
1804 }
1805 
1806 /**
1807  * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1808  * @area: The IOVA range to unmap
1809  * @domain: The domain to unmap
1810  *
1811  * The caller must know that unpinning is not required, usually because there
1812  * are other domains in the iopt.
1813  */
1814 void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1815 {
1816 	iommu_unmap_nofail(domain, iopt_area_iova(area),
1817 			   iopt_area_length(area));
1818 }
1819 
1820 /**
1821  * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
1822  * @area: IOVA area to use
1823  * @pages: page supplier for the area (area->pages is NULL)
1824  * @domain: Domain to unmap from
1825  *
1826  * The domain should be removed from the domains_itree before calling. The
1827  * domain will always be unmapped, but the PFNs may not be unpinned if there are
1828  * still accesses.
1829  */
1830 void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
1831 			     struct iommu_domain *domain)
1832 {
1833 	if (iopt_dmabuf_revoked(pages))
1834 		return;
1835 
1836 	__iopt_area_unfill_domain(area, pages, domain,
1837 				  iopt_area_last_index(area));
1838 }
1839 
1840 /**
1841  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1842  * @area: IOVA area to use
1843  * @domain: Domain to load PFNs into
1844  *
1845  * Read the pfns from the area's underlying iopt_pages and map them into the
1846  * given domain. Called when attaching a new domain to an io_pagetable.
1847  */
1848 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1849 {
1850 	unsigned long done_end_index;
1851 	struct pfn_reader pfns;
1852 	int rc;
1853 
1854 	lockdep_assert_held(&area->pages->mutex);
1855 
1856 	if (iopt_dmabuf_revoked(area->pages))
1857 		return 0;
1858 
1859 	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1860 			      iopt_area_last_index(area));
1861 	if (rc)
1862 		return rc;
1863 
1864 	while (!pfn_reader_done(&pfns)) {
1865 		done_end_index = pfns.batch_start_index;
1866 		rc = batch_to_domain(&pfns.batch, domain, area,
1867 				     pfns.batch_start_index);
1868 		if (rc)
1869 			goto out_unmap;
1870 		done_end_index = pfns.batch_end_index;
1871 
1872 		rc = pfn_reader_next(&pfns);
1873 		if (rc)
1874 			goto out_unmap;
1875 	}
1876 
1877 	rc = pfn_reader_update_pinned(&pfns);
1878 	if (rc)
1879 		goto out_unmap;
1880 	goto out_destroy;
1881 
1882 out_unmap:
1883 	pfn_reader_release_pins(&pfns);
1884 	iopt_area_unfill_partial_domain(area, area->pages, domain,
1885 					done_end_index);
1886 out_destroy:
1887 	pfn_reader_destroy(&pfns);
1888 	return rc;
1889 }
1890 
1891 /**
1892  * iopt_area_fill_domains() - Install PFNs into the area's domains
1893  * @area: The area to act on
1894  * @pages: The pages associated with the area (area->pages is NULL)
1895  *
1896  * Called during area creation. The area is freshly created and not inserted in
1897  * the domains_itree yet. PFNs are read and loaded into every domain held in the
1898  * area's io_pagetable and the area is installed in the domains_itree.
1899  *
1900  * On failure all domains are left unchanged.
1901  */
1902 int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1903 {
1904 	unsigned long done_first_end_index;
1905 	unsigned long done_all_end_index;
1906 	struct iommu_domain *domain;
1907 	unsigned long unmap_index;
1908 	struct pfn_reader pfns;
1909 	unsigned long index;
1910 	int rc;
1911 
1912 	lockdep_assert_held(&area->iopt->domains_rwsem);
1913 
1914 	if (xa_empty(&area->iopt->domains))
1915 		return 0;
1916 
1917 	mutex_lock(&pages->mutex);
1918 	if (iopt_is_dmabuf(pages)) {
1919 		rc = iopt_dmabuf_track_all_domains(area, pages);
1920 		if (rc)
1921 			goto out_unlock;
1922 	}
1923 
1924 	if (!iopt_dmabuf_revoked(pages)) {
1925 		rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1926 				      iopt_area_last_index(area));
1927 		if (rc)
1928 			goto out_untrack;
1929 
1930 		while (!pfn_reader_done(&pfns)) {
1931 			done_first_end_index = pfns.batch_end_index;
1932 			done_all_end_index = pfns.batch_start_index;
1933 			xa_for_each(&area->iopt->domains, index, domain) {
1934 				rc = batch_to_domain(&pfns.batch, domain, area,
1935 						     pfns.batch_start_index);
1936 				if (rc)
1937 					goto out_unmap;
1938 			}
1939 			done_all_end_index = done_first_end_index;
1940 
1941 			rc = pfn_reader_next(&pfns);
1942 			if (rc)
1943 				goto out_unmap;
1944 		}
1945 		rc = pfn_reader_update_pinned(&pfns);
1946 		if (rc)
1947 			goto out_unmap;
1948 
1949 		pfn_reader_destroy(&pfns);
1950 	}
1951 
1952 	area->storage_domain = xa_load(&area->iopt->domains, 0);
1953 	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1954 	mutex_unlock(&pages->mutex);
1955 	return 0;
1956 
1957 out_unmap:
1958 	pfn_reader_release_pins(&pfns);
1959 	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1960 		unsigned long end_index;
1961 
1962 		if (unmap_index < index)
1963 			end_index = done_first_end_index;
1964 		else
1965 			end_index = done_all_end_index;
1966 
1967 		/*
1968 		 * The area is not yet part of the domains_itree so we have to
1969 		 * manage the unpinning specially. The last domain does the
1970 		 * unpin, every other domain is just unmapped.
1971 		 */
1972 		if (unmap_index != area->iopt->next_domain_id - 1) {
1973 			if (end_index != iopt_area_index(area))
1974 				iopt_area_unmap_domain_range(
1975 					area, domain, iopt_area_index(area),
1976 					end_index - 1);
1977 		} else {
1978 			iopt_area_unfill_partial_domain(area, pages, domain,
1979 							end_index);
1980 		}
1981 	}
1982 	pfn_reader_destroy(&pfns);
1983 out_untrack:
1984 	if (iopt_is_dmabuf(pages))
1985 		iopt_dmabuf_untrack_all_domains(area, pages);
1986 out_unlock:
1987 	mutex_unlock(&pages->mutex);
1988 	return rc;
1989 }
1990 
1991 /**
1992  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1993  * @area: The area to act on
1994  * @pages: The pages associated with the area (area->pages is NULL)
1995  *
1996  * Called during area destruction. This unmaps the iova's covered by all the
1997  * area's domains and releases the PFNs.
1998  */
1999 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
2000 {
2001 	struct io_pagetable *iopt = area->iopt;
2002 	struct iommu_domain *domain;
2003 	unsigned long index;
2004 
2005 	lockdep_assert_held(&iopt->domains_rwsem);
2006 
2007 	mutex_lock(&pages->mutex);
2008 	if (!area->storage_domain)
2009 		goto out_unlock;
2010 
2011 	xa_for_each(&iopt->domains, index, domain) {
2012 		if (domain == area->storage_domain)
2013 			continue;
2014 
2015 		if (!iopt_dmabuf_revoked(pages))
2016 			iopt_area_unmap_domain_range(
2017 				area, domain, iopt_area_index(area),
2018 				iopt_area_last_index(area));
2019 	}
2020 
2021 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
2022 		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
2023 	interval_tree_remove(&area->pages_node, &pages->domains_itree);
2024 	iopt_area_unfill_domain(area, pages, area->storage_domain);
2025 	if (iopt_is_dmabuf(pages))
2026 		iopt_dmabuf_untrack_all_domains(area, pages);
2027 	area->storage_domain = NULL;
2028 out_unlock:
2029 	mutex_unlock(&pages->mutex);
2030 }
2031 
2032 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
2033 				    struct iopt_pages *pages,
2034 				    unsigned long start_index,
2035 				    unsigned long end_index)
2036 {
2037 	while (start_index <= end_index) {
2038 		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
2039 					end_index);
2040 		batch_unpin(batch, pages, 0, batch->total_pfns);
2041 		start_index += batch->total_pfns;
2042 		batch_clear(batch);
2043 	}
2044 }
2045 
2046 /**
2047  * iopt_pages_unfill_xarray() - Update the xarry after removing an access
2048  * @pages: The pages to act on
2049  * @start_index: Starting PFN index
2050  * @last_index: Last PFN index
2051  *
2052  * Called when an iopt_pages_access is removed, removes pages from the itree.
2053  * The access should already be removed from the access_itree.
2054  */
2055 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
2056 			      unsigned long start_index,
2057 			      unsigned long last_index)
2058 {
2059 	struct interval_tree_double_span_iter span;
2060 	u64 backup[BATCH_BACKUP_SIZE];
2061 	struct pfn_batch batch;
2062 	bool batch_inited = false;
2063 
2064 	lockdep_assert_held(&pages->mutex);
2065 
2066 	interval_tree_for_each_double_span(&span, &pages->access_itree,
2067 					   &pages->domains_itree, start_index,
2068 					   last_index) {
2069 		if (!span.is_used) {
2070 			if (!batch_inited) {
2071 				batch_init_backup(&batch,
2072 						  last_index - start_index + 1,
2073 						  backup, sizeof(backup));
2074 				batch_inited = true;
2075 			}
2076 			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
2077 						span.last_hole);
2078 		} else if (span.is_used == 2) {
2079 			/* Covered by a domain */
2080 			clear_xarray(&pages->pinned_pfns, span.start_used,
2081 				     span.last_used);
2082 		}
2083 		/* Otherwise covered by an existing access */
2084 	}
2085 	if (batch_inited)
2086 		batch_destroy(&batch, backup);
2087 	update_unpinned(pages);
2088 }
2089 
2090 /**
2091  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
2092  * @pages: The pages to act on
2093  * @start_index: The first page index in the range
2094  * @last_index: The last page index in the range
2095  * @out_pages: The output array to return the pages
2096  *
2097  * This can be called if the caller is holding a refcount on an
2098  * iopt_pages_access that is known to have already been filled. It quickly reads
2099  * the pages directly from the xarray.
2100  *
2101  * This is part of the SW iommu interface to read pages for in-kernel use.
2102  */
2103 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
2104 				 unsigned long start_index,
2105 				 unsigned long last_index,
2106 				 struct page **out_pages)
2107 {
2108 	XA_STATE(xas, &pages->pinned_pfns, start_index);
2109 	void *entry;
2110 
2111 	rcu_read_lock();
2112 	while (start_index <= last_index) {
2113 		entry = xas_next(&xas);
2114 		if (xas_retry(&xas, entry))
2115 			continue;
2116 		WARN_ON(!xa_is_value(entry));
2117 		*(out_pages++) = pfn_to_page(xa_to_value(entry));
2118 		start_index++;
2119 	}
2120 	rcu_read_unlock();
2121 }
2122 
2123 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
2124 				       unsigned long start_index,
2125 				       unsigned long last_index,
2126 				       struct page **out_pages)
2127 {
2128 	while (start_index != last_index + 1) {
2129 		unsigned long domain_last;
2130 		struct iopt_area *area;
2131 
2132 		area = iopt_pages_find_domain_area(pages, start_index);
2133 		if (WARN_ON(!area))
2134 			return -EINVAL;
2135 
2136 		domain_last = min(iopt_area_last_index(area), last_index);
2137 		out_pages = raw_pages_from_domain(area->storage_domain, area,
2138 						  start_index, domain_last,
2139 						  out_pages);
2140 		start_index = domain_last + 1;
2141 	}
2142 	return 0;
2143 }
2144 
2145 static int iopt_pages_fill(struct iopt_pages *pages,
2146 			   struct pfn_reader_user *user,
2147 			   unsigned long start_index,
2148 			   unsigned long last_index,
2149 			   struct page **out_pages)
2150 {
2151 	unsigned long cur_index = start_index;
2152 	int rc;
2153 
2154 	while (cur_index != last_index + 1) {
2155 		user->upages = out_pages + (cur_index - start_index);
2156 		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
2157 		if (rc)
2158 			goto out_unpin;
2159 		cur_index = user->upages_end;
2160 	}
2161 	return 0;
2162 
2163 out_unpin:
2164 	if (start_index != cur_index)
2165 		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
2166 				     out_pages);
2167 	return rc;
2168 }
2169 
2170 /**
2171  * iopt_pages_fill_xarray() - Read PFNs
2172  * @pages: The pages to act on
2173  * @start_index: The first page index in the range
2174  * @last_index: The last page index in the range
2175  * @out_pages: The output array to return the pages, may be NULL
2176  *
2177  * This populates the xarray and returns the pages in out_pages. As the slow
2178  * path this is able to copy pages from other storage tiers into the xarray.
2179  *
2180  * On failure the xarray is left unchanged.
2181  *
2182  * This is part of the SW iommu interface to read pages for in-kernel use.
2183  */
2184 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
2185 			   unsigned long last_index, struct page **out_pages)
2186 {
2187 	struct interval_tree_double_span_iter span;
2188 	unsigned long xa_end = start_index;
2189 	struct pfn_reader_user user;
2190 	int rc;
2191 
2192 	lockdep_assert_held(&pages->mutex);
2193 
2194 	pfn_reader_user_init(&user, pages);
2195 	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
2196 	interval_tree_for_each_double_span(&span, &pages->access_itree,
2197 					   &pages->domains_itree, start_index,
2198 					   last_index) {
2199 		struct page **cur_pages;
2200 
2201 		if (span.is_used == 1) {
2202 			cur_pages = out_pages + (span.start_used - start_index);
2203 			iopt_pages_fill_from_xarray(pages, span.start_used,
2204 						    span.last_used, cur_pages);
2205 			continue;
2206 		}
2207 
2208 		if (span.is_used == 2) {
2209 			cur_pages = out_pages + (span.start_used - start_index);
2210 			iopt_pages_fill_from_domain(pages, span.start_used,
2211 						    span.last_used, cur_pages);
2212 			rc = pages_to_xarray(&pages->pinned_pfns,
2213 					     span.start_used, span.last_used,
2214 					     cur_pages);
2215 			if (rc)
2216 				goto out_clean_xa;
2217 			xa_end = span.last_used + 1;
2218 			continue;
2219 		}
2220 
2221 		/* hole */
2222 		cur_pages = out_pages + (span.start_hole - start_index);
2223 		rc = iopt_pages_fill(pages, &user, span.start_hole,
2224 				     span.last_hole, cur_pages);
2225 		if (rc)
2226 			goto out_clean_xa;
2227 		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
2228 				     span.last_hole, cur_pages);
2229 		if (rc) {
2230 			iopt_pages_err_unpin(pages, span.start_hole,
2231 					     span.last_hole, cur_pages);
2232 			goto out_clean_xa;
2233 		}
2234 		xa_end = span.last_hole + 1;
2235 	}
2236 	rc = pfn_reader_user_update_pinned(&user, pages);
2237 	if (rc)
2238 		goto out_clean_xa;
2239 	user.upages = NULL;
2240 	pfn_reader_user_destroy(&user, pages);
2241 	return 0;
2242 
2243 out_clean_xa:
2244 	if (start_index != xa_end)
2245 		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
2246 	user.upages = NULL;
2247 	pfn_reader_user_destroy(&user, pages);
2248 	return rc;
2249 }
2250 
2251 /*
2252  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
2253  * do every scenario and is fully consistent with what an iommu_domain would
2254  * see.
2255  */
2256 static int iopt_pages_rw_slow(struct iopt_pages *pages,
2257 			      unsigned long start_index,
2258 			      unsigned long last_index, unsigned long offset,
2259 			      void *data, unsigned long length,
2260 			      unsigned int flags)
2261 {
2262 	struct pfn_reader pfns;
2263 	int rc;
2264 
2265 	mutex_lock(&pages->mutex);
2266 
2267 	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
2268 	if (rc)
2269 		goto out_unlock;
2270 
2271 	while (!pfn_reader_done(&pfns)) {
2272 		unsigned long done;
2273 
2274 		done = batch_rw(&pfns.batch, data, offset, length, flags);
2275 		data += done;
2276 		length -= done;
2277 		offset = 0;
2278 		pfn_reader_unpin(&pfns);
2279 
2280 		rc = pfn_reader_next(&pfns);
2281 		if (rc)
2282 			goto out_destroy;
2283 	}
2284 	if (WARN_ON(length != 0))
2285 		rc = -EINVAL;
2286 out_destroy:
2287 	pfn_reader_destroy(&pfns);
2288 out_unlock:
2289 	mutex_unlock(&pages->mutex);
2290 	return rc;
2291 }
2292 
2293 /*
2294  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
2295  * memory allocations or interval tree searches.
2296  */
2297 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
2298 			      unsigned long offset, void *data,
2299 			      unsigned long length, unsigned int flags)
2300 {
2301 	struct page *page = NULL;
2302 	int rc;
2303 
2304 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2305 	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
2306 		return -EINVAL;
2307 
2308 	if (!mmget_not_zero(pages->source_mm))
2309 		return iopt_pages_rw_slow(pages, index, index, offset, data,
2310 					  length, flags);
2311 
2312 	if (iommufd_should_fail()) {
2313 		rc = -EINVAL;
2314 		goto out_mmput;
2315 	}
2316 
2317 	mmap_read_lock(pages->source_mm);
2318 	rc = pin_user_pages_remote(
2319 		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
2320 		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
2321 		NULL);
2322 	mmap_read_unlock(pages->source_mm);
2323 	if (rc != 1) {
2324 		if (WARN_ON(rc >= 0))
2325 			rc = -EINVAL;
2326 		goto out_mmput;
2327 	}
2328 	copy_data_page(page, data, offset, length, flags);
2329 	unpin_user_page(page);
2330 	rc = 0;
2331 
2332 out_mmput:
2333 	mmput(pages->source_mm);
2334 	return rc;
2335 }
2336 
2337 /**
2338  * iopt_pages_rw_access - Copy to/from a linear slice of the pages
2339  * @pages: pages to act on
2340  * @start_byte: First byte of pages to copy to/from
2341  * @data: Kernel buffer to get/put the data
2342  * @length: Number of bytes to copy
2343  * @flags: IOMMUFD_ACCESS_RW_* flags
2344  *
2345  * This will find each page in the range, kmap it and then memcpy to/from
2346  * the given kernel buffer.
2347  */
2348 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
2349 			 void *data, unsigned long length, unsigned int flags)
2350 {
2351 	unsigned long start_index = start_byte / PAGE_SIZE;
2352 	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
2353 	bool change_mm = current->mm != pages->source_mm;
2354 	int rc = 0;
2355 
2356 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2357 	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
2358 		change_mm = true;
2359 
2360 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2361 		return -EPERM;
2362 
2363 	if (iopt_is_dmabuf(pages))
2364 		return -EINVAL;
2365 
2366 	if (pages->type != IOPT_ADDRESS_USER)
2367 		return iopt_pages_rw_slow(pages, start_index, last_index,
2368 					  start_byte % PAGE_SIZE, data, length,
2369 					  flags);
2370 
2371 	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
2372 		if (start_index == last_index)
2373 			return iopt_pages_rw_page(pages, start_index,
2374 						  start_byte % PAGE_SIZE, data,
2375 						  length, flags);
2376 		return iopt_pages_rw_slow(pages, start_index, last_index,
2377 					  start_byte % PAGE_SIZE, data, length,
2378 					  flags);
2379 	}
2380 
2381 	/*
2382 	 * Try to copy using copy_to_user(). We do this as a fast path and
2383 	 * ignore any pinning inconsistencies, unlike a real DMA path.
2384 	 */
2385 	if (change_mm) {
2386 		if (!mmget_not_zero(pages->source_mm))
2387 			return iopt_pages_rw_slow(pages, start_index,
2388 						  last_index,
2389 						  start_byte % PAGE_SIZE, data,
2390 						  length, flags);
2391 		kthread_use_mm(pages->source_mm);
2392 	}
2393 
2394 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
2395 		if (copy_to_user(pages->uptr + start_byte, data, length))
2396 			rc = -EFAULT;
2397 	} else {
2398 		if (copy_from_user(data, pages->uptr + start_byte, length))
2399 			rc = -EFAULT;
2400 	}
2401 
2402 	if (change_mm) {
2403 		kthread_unuse_mm(pages->source_mm);
2404 		mmput(pages->source_mm);
2405 	}
2406 
2407 	return rc;
2408 }
2409 
2410 static struct iopt_pages_access *
2411 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
2412 			    unsigned long last)
2413 {
2414 	struct interval_tree_node *node;
2415 
2416 	lockdep_assert_held(&pages->mutex);
2417 
2418 	/* There can be overlapping ranges in this interval tree */
2419 	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
2420 	     node; node = interval_tree_iter_next(node, index, last))
2421 		if (node->start == index && node->last == last)
2422 			return container_of(node, struct iopt_pages_access,
2423 					    node);
2424 	return NULL;
2425 }
2426 
2427 /**
2428  * iopt_area_add_access() - Record an in-knerel access for PFNs
2429  * @area: The source of PFNs
2430  * @start_index: First page index
2431  * @last_index: Inclusive last page index
2432  * @out_pages: Output list of struct page's representing the PFNs
2433  * @flags: IOMMUFD_ACCESS_RW_* flags
2434  * @lock_area: Fail userspace munmap on this area
2435  *
2436  * Record that an in-kernel access will be accessing the pages, ensure they are
2437  * pinned, and return the PFNs as a simple list of 'struct page *'.
2438  *
2439  * This should be undone through a matching call to iopt_area_remove_access()
2440  */
2441 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
2442 			 unsigned long last_index, struct page **out_pages,
2443 			 unsigned int flags, bool lock_area)
2444 {
2445 	struct iopt_pages *pages = area->pages;
2446 	struct iopt_pages_access *access;
2447 	int rc;
2448 
2449 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2450 		return -EPERM;
2451 
2452 	mutex_lock(&pages->mutex);
2453 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2454 	if (access) {
2455 		area->num_accesses++;
2456 		if (lock_area)
2457 			area->num_locks++;
2458 		access->users++;
2459 		iopt_pages_fill_from_xarray(pages, start_index, last_index,
2460 					    out_pages);
2461 		mutex_unlock(&pages->mutex);
2462 		return 0;
2463 	}
2464 
2465 	access = kzalloc_obj(*access, GFP_KERNEL_ACCOUNT);
2466 	if (!access) {
2467 		rc = -ENOMEM;
2468 		goto err_unlock;
2469 	}
2470 
2471 	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
2472 	if (rc)
2473 		goto err_free;
2474 
2475 	access->node.start = start_index;
2476 	access->node.last = last_index;
2477 	access->users = 1;
2478 	area->num_accesses++;
2479 	if (lock_area)
2480 		area->num_locks++;
2481 	interval_tree_insert(&access->node, &pages->access_itree);
2482 	mutex_unlock(&pages->mutex);
2483 	return 0;
2484 
2485 err_free:
2486 	kfree(access);
2487 err_unlock:
2488 	mutex_unlock(&pages->mutex);
2489 	return rc;
2490 }
2491 
2492 /**
2493  * iopt_area_remove_access() - Release an in-kernel access for PFNs
2494  * @area: The source of PFNs
2495  * @start_index: First page index
2496  * @last_index: Inclusive last page index
2497  * @unlock_area: Must match the matching iopt_area_add_access()'s lock_area
2498  *
2499  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
2500  * must stop using the PFNs before calling this.
2501  */
2502 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
2503 			     unsigned long last_index, bool unlock_area)
2504 {
2505 	struct iopt_pages *pages = area->pages;
2506 	struct iopt_pages_access *access;
2507 
2508 	mutex_lock(&pages->mutex);
2509 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2510 	if (WARN_ON(!access))
2511 		goto out_unlock;
2512 
2513 	WARN_ON(area->num_accesses == 0 || access->users == 0);
2514 	if (unlock_area) {
2515 		WARN_ON(area->num_locks == 0);
2516 		area->num_locks--;
2517 	}
2518 	area->num_accesses--;
2519 	access->users--;
2520 	if (access->users)
2521 		goto out_unlock;
2522 
2523 	interval_tree_remove(&access->node, &pages->access_itree);
2524 	iopt_pages_unfill_xarray(pages, start_index, last_index);
2525 	kfree(access);
2526 out_unlock:
2527 	mutex_unlock(&pages->mutex);
2528 }
2529