xref: /linux/drivers/iommu/iommufd/pages.c (revision 3fd6c59042dbba50391e30862beac979491145fe)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFN have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing a in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/file.h>
49 #include <linux/highmem.h>
50 #include <linux/iommu.h>
51 #include <linux/iommufd.h>
52 #include <linux/kthread.h>
53 #include <linux/overflow.h>
54 #include <linux/slab.h>
55 #include <linux/sched/mm.h>
56 
57 #include "double_span.h"
58 #include "io_pagetable.h"
59 
60 #ifndef CONFIG_IOMMUFD_TEST
61 #define TEMP_MEMORY_LIMIT 65536
62 #else
63 #define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
64 #endif
65 #define BATCH_BACKUP_SIZE 32
66 
67 /*
68  * More memory makes pin_user_pages() and the batching more efficient, but as
69  * this is only a performance optimization don't try too hard to get it. A 64k
70  * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
71  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
72  * stack memory as a backup contingency. If backup_len is given this cannot
73  * fail.
74  */
temp_kmalloc(size_t * size,void * backup,size_t backup_len)75 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
76 {
77 	void *res;
78 
79 	if (WARN_ON(*size == 0))
80 		return NULL;
81 
82 	if (*size < backup_len)
83 		return backup;
84 
85 	if (!backup && iommufd_should_fail())
86 		return NULL;
87 
88 	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
89 	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
90 	if (res)
91 		return res;
92 	*size = PAGE_SIZE;
93 	if (backup_len) {
94 		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
95 		if (res)
96 			return res;
97 		*size = backup_len;
98 		return backup;
99 	}
100 	return kmalloc(*size, GFP_KERNEL);
101 }
102 
interval_tree_double_span_iter_update(struct interval_tree_double_span_iter * iter)103 void interval_tree_double_span_iter_update(
104 	struct interval_tree_double_span_iter *iter)
105 {
106 	unsigned long last_hole = ULONG_MAX;
107 	unsigned int i;
108 
109 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
110 		if (interval_tree_span_iter_done(&iter->spans[i])) {
111 			iter->is_used = -1;
112 			return;
113 		}
114 
115 		if (iter->spans[i].is_hole) {
116 			last_hole = min(last_hole, iter->spans[i].last_hole);
117 			continue;
118 		}
119 
120 		iter->is_used = i + 1;
121 		iter->start_used = iter->spans[i].start_used;
122 		iter->last_used = min(iter->spans[i].last_used, last_hole);
123 		return;
124 	}
125 
126 	iter->is_used = 0;
127 	iter->start_hole = iter->spans[0].start_hole;
128 	iter->last_hole =
129 		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
130 }
131 
interval_tree_double_span_iter_first(struct interval_tree_double_span_iter * iter,struct rb_root_cached * itree1,struct rb_root_cached * itree2,unsigned long first_index,unsigned long last_index)132 void interval_tree_double_span_iter_first(
133 	struct interval_tree_double_span_iter *iter,
134 	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
135 	unsigned long first_index, unsigned long last_index)
136 {
137 	unsigned int i;
138 
139 	iter->itrees[0] = itree1;
140 	iter->itrees[1] = itree2;
141 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
142 		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
143 					      first_index, last_index);
144 	interval_tree_double_span_iter_update(iter);
145 }
146 
interval_tree_double_span_iter_next(struct interval_tree_double_span_iter * iter)147 void interval_tree_double_span_iter_next(
148 	struct interval_tree_double_span_iter *iter)
149 {
150 	unsigned int i;
151 
152 	if (iter->is_used == -1 ||
153 	    iter->last_hole == iter->spans[0].last_index) {
154 		iter->is_used = -1;
155 		return;
156 	}
157 
158 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
159 		interval_tree_span_iter_advance(
160 			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
161 	interval_tree_double_span_iter_update(iter);
162 }
163 
iopt_pages_add_npinned(struct iopt_pages * pages,size_t npages)164 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
165 {
166 	int rc;
167 
168 	rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
169 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
170 		WARN_ON(rc || pages->npinned > pages->npages);
171 }
172 
iopt_pages_sub_npinned(struct iopt_pages * pages,size_t npages)173 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
174 {
175 	int rc;
176 
177 	rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
178 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
179 		WARN_ON(rc || pages->npinned > pages->npages);
180 }
181 
iopt_pages_err_unpin(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** page_list)182 static void iopt_pages_err_unpin(struct iopt_pages *pages,
183 				 unsigned long start_index,
184 				 unsigned long last_index,
185 				 struct page **page_list)
186 {
187 	unsigned long npages = last_index - start_index + 1;
188 
189 	unpin_user_pages(page_list, npages);
190 	iopt_pages_sub_npinned(pages, npages);
191 }
192 
193 /*
194  * index is the number of PAGE_SIZE units from the start of the area's
195  * iopt_pages. If the iova is sub page-size then the area has an iova that
196  * covers a portion of the first and last pages in the range.
197  */
iopt_area_index_to_iova(struct iopt_area * area,unsigned long index)198 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
199 					     unsigned long index)
200 {
201 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
202 		WARN_ON(index < iopt_area_index(area) ||
203 			index > iopt_area_last_index(area));
204 	index -= iopt_area_index(area);
205 	if (index == 0)
206 		return iopt_area_iova(area);
207 	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
208 }
209 
iopt_area_index_to_iova_last(struct iopt_area * area,unsigned long index)210 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
211 						  unsigned long index)
212 {
213 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
214 		WARN_ON(index < iopt_area_index(area) ||
215 			index > iopt_area_last_index(area));
216 	if (index == iopt_area_last_index(area))
217 		return iopt_area_last_iova(area);
218 	return iopt_area_iova(area) - area->page_offset +
219 	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
220 }
221 
iommu_unmap_nofail(struct iommu_domain * domain,unsigned long iova,size_t size)222 static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
223 			       size_t size)
224 {
225 	size_t ret;
226 
227 	ret = iommu_unmap(domain, iova, size);
228 	/*
229 	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
230 	 * something other than exactly as requested. This implies that the
231 	 * iommu driver may not fail unmap for reasons beyond bad agruments.
232 	 * Particularly, the iommu driver may not do a memory allocation on the
233 	 * unmap path.
234 	 */
235 	WARN_ON(ret != size);
236 }
237 
iopt_area_unmap_domain_range(struct iopt_area * area,struct iommu_domain * domain,unsigned long start_index,unsigned long last_index)238 static void iopt_area_unmap_domain_range(struct iopt_area *area,
239 					 struct iommu_domain *domain,
240 					 unsigned long start_index,
241 					 unsigned long last_index)
242 {
243 	unsigned long start_iova = iopt_area_index_to_iova(area, start_index);
244 
245 	iommu_unmap_nofail(domain, start_iova,
246 			   iopt_area_index_to_iova_last(area, last_index) -
247 				   start_iova + 1);
248 }
249 
iopt_pages_find_domain_area(struct iopt_pages * pages,unsigned long index)250 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
251 						     unsigned long index)
252 {
253 	struct interval_tree_node *node;
254 
255 	node = interval_tree_iter_first(&pages->domains_itree, index, index);
256 	if (!node)
257 		return NULL;
258 	return container_of(node, struct iopt_area, pages_node);
259 }
260 
261 /*
262  * A simple datastructure to hold a vector of PFNs, optimized for contiguous
263  * PFNs. This is used as a temporary holding memory for shuttling pfns from one
264  * place to another. Generally everything is made more efficient if operations
265  * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
266  * better cache locality, etc
267  */
268 struct pfn_batch {
269 	unsigned long *pfns;
270 	u32 *npfns;
271 	unsigned int array_size;
272 	unsigned int end;
273 	unsigned int total_pfns;
274 };
275 
batch_clear(struct pfn_batch * batch)276 static void batch_clear(struct pfn_batch *batch)
277 {
278 	batch->total_pfns = 0;
279 	batch->end = 0;
280 	batch->pfns[0] = 0;
281 	batch->npfns[0] = 0;
282 }
283 
284 /*
285  * Carry means we carry a portion of the final hugepage over to the front of the
286  * batch
287  */
batch_clear_carry(struct pfn_batch * batch,unsigned int keep_pfns)288 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
289 {
290 	if (!keep_pfns)
291 		return batch_clear(batch);
292 
293 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
294 		WARN_ON(!batch->end ||
295 			batch->npfns[batch->end - 1] < keep_pfns);
296 
297 	batch->total_pfns = keep_pfns;
298 	batch->pfns[0] = batch->pfns[batch->end - 1] +
299 			 (batch->npfns[batch->end - 1] - keep_pfns);
300 	batch->npfns[0] = keep_pfns;
301 	batch->end = 1;
302 }
303 
batch_skip_carry(struct pfn_batch * batch,unsigned int skip_pfns)304 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
305 {
306 	if (!batch->total_pfns)
307 		return;
308 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
309 		WARN_ON(batch->total_pfns != batch->npfns[0]);
310 	skip_pfns = min(batch->total_pfns, skip_pfns);
311 	batch->pfns[0] += skip_pfns;
312 	batch->npfns[0] -= skip_pfns;
313 	batch->total_pfns -= skip_pfns;
314 }
315 
__batch_init(struct pfn_batch * batch,size_t max_pages,void * backup,size_t backup_len)316 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
317 			size_t backup_len)
318 {
319 	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
320 	size_t size = max_pages * elmsz;
321 
322 	batch->pfns = temp_kmalloc(&size, backup, backup_len);
323 	if (!batch->pfns)
324 		return -ENOMEM;
325 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
326 		return -EINVAL;
327 	batch->array_size = size / elmsz;
328 	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
329 	batch_clear(batch);
330 	return 0;
331 }
332 
batch_init(struct pfn_batch * batch,size_t max_pages)333 static int batch_init(struct pfn_batch *batch, size_t max_pages)
334 {
335 	return __batch_init(batch, max_pages, NULL, 0);
336 }
337 
batch_init_backup(struct pfn_batch * batch,size_t max_pages,void * backup,size_t backup_len)338 static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
339 			      void *backup, size_t backup_len)
340 {
341 	__batch_init(batch, max_pages, backup, backup_len);
342 }
343 
batch_destroy(struct pfn_batch * batch,void * backup)344 static void batch_destroy(struct pfn_batch *batch, void *backup)
345 {
346 	if (batch->pfns != backup)
347 		kfree(batch->pfns);
348 }
349 
batch_add_pfn_num(struct pfn_batch * batch,unsigned long pfn,u32 nr)350 static bool batch_add_pfn_num(struct pfn_batch *batch, unsigned long pfn,
351 			      u32 nr)
352 {
353 	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
354 	unsigned int end = batch->end;
355 
356 	if (end && pfn == batch->pfns[end - 1] + batch->npfns[end - 1] &&
357 	    nr <= MAX_NPFNS - batch->npfns[end - 1]) {
358 		batch->npfns[end - 1] += nr;
359 	} else if (end < batch->array_size) {
360 		batch->pfns[end] = pfn;
361 		batch->npfns[end] = nr;
362 		batch->end++;
363 	} else {
364 		return false;
365 	}
366 
367 	batch->total_pfns += nr;
368 	return true;
369 }
370 
batch_remove_pfn_num(struct pfn_batch * batch,unsigned long nr)371 static void batch_remove_pfn_num(struct pfn_batch *batch, unsigned long nr)
372 {
373 	batch->npfns[batch->end - 1] -= nr;
374 	if (batch->npfns[batch->end - 1] == 0)
375 		batch->end--;
376 	batch->total_pfns -= nr;
377 }
378 
379 /* true if the pfn was added, false otherwise */
batch_add_pfn(struct pfn_batch * batch,unsigned long pfn)380 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
381 {
382 	return batch_add_pfn_num(batch, pfn, 1);
383 }
384 
385 /*
386  * Fill the batch with pfns from the domain. When the batch is full, or it
387  * reaches last_index, the function will return. The caller should use
388  * batch->total_pfns to determine the starting point for the next iteration.
389  */
batch_from_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)390 static void batch_from_domain(struct pfn_batch *batch,
391 			      struct iommu_domain *domain,
392 			      struct iopt_area *area, unsigned long start_index,
393 			      unsigned long last_index)
394 {
395 	unsigned int page_offset = 0;
396 	unsigned long iova;
397 	phys_addr_t phys;
398 
399 	iova = iopt_area_index_to_iova(area, start_index);
400 	if (start_index == iopt_area_index(area))
401 		page_offset = area->page_offset;
402 	while (start_index <= last_index) {
403 		/*
404 		 * This is pretty slow, it would be nice to get the page size
405 		 * back from the driver, or have the driver directly fill the
406 		 * batch.
407 		 */
408 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
409 		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
410 			return;
411 		iova += PAGE_SIZE - page_offset;
412 		page_offset = 0;
413 		start_index++;
414 	}
415 }
416 
raw_pages_from_domain(struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages)417 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
418 					   struct iopt_area *area,
419 					   unsigned long start_index,
420 					   unsigned long last_index,
421 					   struct page **out_pages)
422 {
423 	unsigned int page_offset = 0;
424 	unsigned long iova;
425 	phys_addr_t phys;
426 
427 	iova = iopt_area_index_to_iova(area, start_index);
428 	if (start_index == iopt_area_index(area))
429 		page_offset = area->page_offset;
430 	while (start_index <= last_index) {
431 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
432 		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
433 		iova += PAGE_SIZE - page_offset;
434 		page_offset = 0;
435 		start_index++;
436 	}
437 	return out_pages;
438 }
439 
440 /* Continues reading a domain until we reach a discontinuity in the pfns. */
batch_from_domain_continue(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)441 static void batch_from_domain_continue(struct pfn_batch *batch,
442 				       struct iommu_domain *domain,
443 				       struct iopt_area *area,
444 				       unsigned long start_index,
445 				       unsigned long last_index)
446 {
447 	unsigned int array_size = batch->array_size;
448 
449 	batch->array_size = batch->end;
450 	batch_from_domain(batch, domain, area, start_index, last_index);
451 	batch->array_size = array_size;
452 }
453 
454 /*
455  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
456  * mode permits splitting a mapped area up, and then one of the splits is
457  * unmapped. Doing this normally would cause us to violate our invariant of
458  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
459  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
460  * PAGE_SIZE units, not larger or smaller.
461  */
batch_iommu_map_small(struct iommu_domain * domain,unsigned long iova,phys_addr_t paddr,size_t size,int prot)462 static int batch_iommu_map_small(struct iommu_domain *domain,
463 				 unsigned long iova, phys_addr_t paddr,
464 				 size_t size, int prot)
465 {
466 	unsigned long start_iova = iova;
467 	int rc;
468 
469 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
470 		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
471 			size % PAGE_SIZE);
472 
473 	while (size) {
474 		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
475 			       GFP_KERNEL_ACCOUNT);
476 		if (rc)
477 			goto err_unmap;
478 		iova += PAGE_SIZE;
479 		paddr += PAGE_SIZE;
480 		size -= PAGE_SIZE;
481 	}
482 	return 0;
483 
484 err_unmap:
485 	if (start_iova != iova)
486 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
487 	return rc;
488 }
489 
batch_to_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index)490 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
491 			   struct iopt_area *area, unsigned long start_index)
492 {
493 	bool disable_large_pages = area->iopt->disable_large_pages;
494 	unsigned long last_iova = iopt_area_last_iova(area);
495 	unsigned int page_offset = 0;
496 	unsigned long start_iova;
497 	unsigned long next_iova;
498 	unsigned int cur = 0;
499 	unsigned long iova;
500 	int rc;
501 
502 	/* The first index might be a partial page */
503 	if (start_index == iopt_area_index(area))
504 		page_offset = area->page_offset;
505 	next_iova = iova = start_iova =
506 		iopt_area_index_to_iova(area, start_index);
507 	while (cur < batch->end) {
508 		next_iova = min(last_iova + 1,
509 				next_iova + batch->npfns[cur] * PAGE_SIZE -
510 					page_offset);
511 		if (disable_large_pages)
512 			rc = batch_iommu_map_small(
513 				domain, iova,
514 				PFN_PHYS(batch->pfns[cur]) + page_offset,
515 				next_iova - iova, area->iommu_prot);
516 		else
517 			rc = iommu_map(domain, iova,
518 				       PFN_PHYS(batch->pfns[cur]) + page_offset,
519 				       next_iova - iova, area->iommu_prot,
520 				       GFP_KERNEL_ACCOUNT);
521 		if (rc)
522 			goto err_unmap;
523 		iova = next_iova;
524 		page_offset = 0;
525 		cur++;
526 	}
527 	return 0;
528 err_unmap:
529 	if (start_iova != iova)
530 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
531 	return rc;
532 }
533 
batch_from_xarray(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)534 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
535 			      unsigned long start_index,
536 			      unsigned long last_index)
537 {
538 	XA_STATE(xas, xa, start_index);
539 	void *entry;
540 
541 	rcu_read_lock();
542 	while (true) {
543 		entry = xas_next(&xas);
544 		if (xas_retry(&xas, entry))
545 			continue;
546 		WARN_ON(!xa_is_value(entry));
547 		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
548 		    start_index == last_index)
549 			break;
550 		start_index++;
551 	}
552 	rcu_read_unlock();
553 }
554 
batch_from_xarray_clear(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)555 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
556 				    unsigned long start_index,
557 				    unsigned long last_index)
558 {
559 	XA_STATE(xas, xa, start_index);
560 	void *entry;
561 
562 	xas_lock(&xas);
563 	while (true) {
564 		entry = xas_next(&xas);
565 		if (xas_retry(&xas, entry))
566 			continue;
567 		WARN_ON(!xa_is_value(entry));
568 		if (!batch_add_pfn(batch, xa_to_value(entry)))
569 			break;
570 		xas_store(&xas, NULL);
571 		if (start_index == last_index)
572 			break;
573 		start_index++;
574 	}
575 	xas_unlock(&xas);
576 }
577 
clear_xarray(struct xarray * xa,unsigned long start_index,unsigned long last_index)578 static void clear_xarray(struct xarray *xa, unsigned long start_index,
579 			 unsigned long last_index)
580 {
581 	XA_STATE(xas, xa, start_index);
582 	void *entry;
583 
584 	xas_lock(&xas);
585 	xas_for_each(&xas, entry, last_index)
586 		xas_store(&xas, NULL);
587 	xas_unlock(&xas);
588 }
589 
pages_to_xarray(struct xarray * xa,unsigned long start_index,unsigned long last_index,struct page ** pages)590 static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
591 			   unsigned long last_index, struct page **pages)
592 {
593 	struct page **end_pages = pages + (last_index - start_index) + 1;
594 	struct page **half_pages = pages + (end_pages - pages) / 2;
595 	XA_STATE(xas, xa, start_index);
596 
597 	do {
598 		void *old;
599 
600 		xas_lock(&xas);
601 		while (pages != end_pages) {
602 			/* xarray does not participate in fault injection */
603 			if (pages == half_pages && iommufd_should_fail()) {
604 				xas_set_err(&xas, -EINVAL);
605 				xas_unlock(&xas);
606 				/* aka xas_destroy() */
607 				xas_nomem(&xas, GFP_KERNEL);
608 				goto err_clear;
609 			}
610 
611 			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
612 			if (xas_error(&xas))
613 				break;
614 			WARN_ON(old);
615 			pages++;
616 			xas_next(&xas);
617 		}
618 		xas_unlock(&xas);
619 	} while (xas_nomem(&xas, GFP_KERNEL));
620 
621 err_clear:
622 	if (xas_error(&xas)) {
623 		if (xas.xa_index != start_index)
624 			clear_xarray(xa, start_index, xas.xa_index - 1);
625 		return xas_error(&xas);
626 	}
627 	return 0;
628 }
629 
batch_from_pages(struct pfn_batch * batch,struct page ** pages,size_t npages)630 static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
631 			     size_t npages)
632 {
633 	struct page **end = pages + npages;
634 
635 	for (; pages != end; pages++)
636 		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
637 			break;
638 }
639 
batch_from_folios(struct pfn_batch * batch,struct folio *** folios_p,unsigned long * offset_p,unsigned long npages)640 static int batch_from_folios(struct pfn_batch *batch, struct folio ***folios_p,
641 			     unsigned long *offset_p, unsigned long npages)
642 {
643 	int rc = 0;
644 	struct folio **folios = *folios_p;
645 	unsigned long offset = *offset_p;
646 
647 	while (npages) {
648 		struct folio *folio = *folios;
649 		unsigned long nr = folio_nr_pages(folio) - offset;
650 		unsigned long pfn = page_to_pfn(folio_page(folio, offset));
651 
652 		nr = min(nr, npages);
653 		npages -= nr;
654 
655 		if (!batch_add_pfn_num(batch, pfn, nr))
656 			break;
657 		if (nr > 1) {
658 			rc = folio_add_pins(folio, nr - 1);
659 			if (rc) {
660 				batch_remove_pfn_num(batch, nr);
661 				goto out;
662 			}
663 		}
664 
665 		folios++;
666 		offset = 0;
667 	}
668 
669 out:
670 	*folios_p = folios;
671 	*offset_p = offset;
672 	return rc;
673 }
674 
batch_unpin(struct pfn_batch * batch,struct iopt_pages * pages,unsigned int first_page_off,size_t npages)675 static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
676 			unsigned int first_page_off, size_t npages)
677 {
678 	unsigned int cur = 0;
679 
680 	while (first_page_off) {
681 		if (batch->npfns[cur] > first_page_off)
682 			break;
683 		first_page_off -= batch->npfns[cur];
684 		cur++;
685 	}
686 
687 	while (npages) {
688 		size_t to_unpin = min_t(size_t, npages,
689 					batch->npfns[cur] - first_page_off);
690 
691 		unpin_user_page_range_dirty_lock(
692 			pfn_to_page(batch->pfns[cur] + first_page_off),
693 			to_unpin, pages->writable);
694 		iopt_pages_sub_npinned(pages, to_unpin);
695 		cur++;
696 		first_page_off = 0;
697 		npages -= to_unpin;
698 	}
699 }
700 
copy_data_page(struct page * page,void * data,unsigned long offset,size_t length,unsigned int flags)701 static void copy_data_page(struct page *page, void *data, unsigned long offset,
702 			   size_t length, unsigned int flags)
703 {
704 	void *mem;
705 
706 	mem = kmap_local_page(page);
707 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
708 		memcpy(mem + offset, data, length);
709 		set_page_dirty_lock(page);
710 	} else {
711 		memcpy(data, mem + offset, length);
712 	}
713 	kunmap_local(mem);
714 }
715 
batch_rw(struct pfn_batch * batch,void * data,unsigned long offset,unsigned long length,unsigned int flags)716 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
717 			      unsigned long offset, unsigned long length,
718 			      unsigned int flags)
719 {
720 	unsigned long copied = 0;
721 	unsigned int npage = 0;
722 	unsigned int cur = 0;
723 
724 	while (cur < batch->end) {
725 		unsigned long bytes = min(length, PAGE_SIZE - offset);
726 
727 		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
728 			       offset, bytes, flags);
729 		offset = 0;
730 		length -= bytes;
731 		data += bytes;
732 		copied += bytes;
733 		npage++;
734 		if (npage == batch->npfns[cur]) {
735 			npage = 0;
736 			cur++;
737 		}
738 		if (!length)
739 			break;
740 	}
741 	return copied;
742 }
743 
744 /* pfn_reader_user is just the pin_user_pages() path */
745 struct pfn_reader_user {
746 	struct page **upages;
747 	size_t upages_len;
748 	unsigned long upages_start;
749 	unsigned long upages_end;
750 	unsigned int gup_flags;
751 	/*
752 	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
753 	 * neither
754 	 */
755 	int locked;
756 
757 	/* The following are only valid if file != NULL. */
758 	struct file *file;
759 	struct folio **ufolios;
760 	size_t ufolios_len;
761 	unsigned long ufolios_offset;
762 	struct folio **ufolios_next;
763 };
764 
pfn_reader_user_init(struct pfn_reader_user * user,struct iopt_pages * pages)765 static void pfn_reader_user_init(struct pfn_reader_user *user,
766 				 struct iopt_pages *pages)
767 {
768 	user->upages = NULL;
769 	user->upages_len = 0;
770 	user->upages_start = 0;
771 	user->upages_end = 0;
772 	user->locked = -1;
773 	user->gup_flags = FOLL_LONGTERM;
774 	if (pages->writable)
775 		user->gup_flags |= FOLL_WRITE;
776 
777 	user->file = (pages->type == IOPT_ADDRESS_FILE) ? pages->file : NULL;
778 	user->ufolios = NULL;
779 	user->ufolios_len = 0;
780 	user->ufolios_next = NULL;
781 	user->ufolios_offset = 0;
782 }
783 
pfn_reader_user_destroy(struct pfn_reader_user * user,struct iopt_pages * pages)784 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
785 				    struct iopt_pages *pages)
786 {
787 	if (user->locked != -1) {
788 		if (user->locked)
789 			mmap_read_unlock(pages->source_mm);
790 		if (!user->file && pages->source_mm != current->mm)
791 			mmput(pages->source_mm);
792 		user->locked = -1;
793 	}
794 
795 	kfree(user->upages);
796 	user->upages = NULL;
797 	kfree(user->ufolios);
798 	user->ufolios = NULL;
799 }
800 
pin_memfd_pages(struct pfn_reader_user * user,unsigned long start,unsigned long npages)801 static long pin_memfd_pages(struct pfn_reader_user *user, unsigned long start,
802 			    unsigned long npages)
803 {
804 	unsigned long i;
805 	unsigned long offset;
806 	unsigned long npages_out = 0;
807 	struct page **upages = user->upages;
808 	unsigned long end = start + (npages << PAGE_SHIFT) - 1;
809 	long nfolios = user->ufolios_len / sizeof(*user->ufolios);
810 
811 	/*
812 	 * todo: memfd_pin_folios should return the last pinned offset so
813 	 * we can compute npages pinned, and avoid looping over folios here
814 	 * if upages == NULL.
815 	 */
816 	nfolios = memfd_pin_folios(user->file, start, end, user->ufolios,
817 				   nfolios, &offset);
818 	if (nfolios <= 0)
819 		return nfolios;
820 
821 	offset >>= PAGE_SHIFT;
822 	user->ufolios_next = user->ufolios;
823 	user->ufolios_offset = offset;
824 
825 	for (i = 0; i < nfolios; i++) {
826 		struct folio *folio = user->ufolios[i];
827 		unsigned long nr = folio_nr_pages(folio);
828 		unsigned long npin = min(nr - offset, npages);
829 
830 		npages -= npin;
831 		npages_out += npin;
832 
833 		if (upages) {
834 			if (npin == 1) {
835 				*upages++ = folio_page(folio, offset);
836 			} else {
837 				int rc = folio_add_pins(folio, npin - 1);
838 
839 				if (rc)
840 					return rc;
841 
842 				while (npin--)
843 					*upages++ = folio_page(folio, offset++);
844 			}
845 		}
846 
847 		offset = 0;
848 	}
849 
850 	return npages_out;
851 }
852 
pfn_reader_user_pin(struct pfn_reader_user * user,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)853 static int pfn_reader_user_pin(struct pfn_reader_user *user,
854 			       struct iopt_pages *pages,
855 			       unsigned long start_index,
856 			       unsigned long last_index)
857 {
858 	bool remote_mm = pages->source_mm != current->mm;
859 	unsigned long npages = last_index - start_index + 1;
860 	unsigned long start;
861 	unsigned long unum;
862 	uintptr_t uptr;
863 	long rc;
864 
865 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
866 	    WARN_ON(last_index < start_index))
867 		return -EINVAL;
868 
869 	if (!user->file && !user->upages) {
870 		/* All undone in pfn_reader_destroy() */
871 		user->upages_len = npages * sizeof(*user->upages);
872 		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
873 		if (!user->upages)
874 			return -ENOMEM;
875 	}
876 
877 	if (user->file && !user->ufolios) {
878 		user->ufolios_len = npages * sizeof(*user->ufolios);
879 		user->ufolios = temp_kmalloc(&user->ufolios_len, NULL, 0);
880 		if (!user->ufolios)
881 			return -ENOMEM;
882 	}
883 
884 	if (user->locked == -1) {
885 		/*
886 		 * The majority of usages will run the map task within the mm
887 		 * providing the pages, so we can optimize into
888 		 * get_user_pages_fast()
889 		 */
890 		if (!user->file && remote_mm) {
891 			if (!mmget_not_zero(pages->source_mm))
892 				return -EFAULT;
893 		}
894 		user->locked = 0;
895 	}
896 
897 	unum = user->file ? user->ufolios_len / sizeof(*user->ufolios) :
898 			    user->upages_len / sizeof(*user->upages);
899 	npages = min_t(unsigned long, npages, unum);
900 
901 	if (iommufd_should_fail())
902 		return -EFAULT;
903 
904 	if (user->file) {
905 		start = pages->start + (start_index * PAGE_SIZE);
906 		rc = pin_memfd_pages(user, start, npages);
907 	} else if (!remote_mm) {
908 		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
909 		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
910 					 user->upages);
911 	} else {
912 		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
913 		if (!user->locked) {
914 			mmap_read_lock(pages->source_mm);
915 			user->locked = 1;
916 		}
917 		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
918 					   user->gup_flags, user->upages,
919 					   &user->locked);
920 	}
921 	if (rc <= 0) {
922 		if (WARN_ON(!rc))
923 			return -EFAULT;
924 		return rc;
925 	}
926 	iopt_pages_add_npinned(pages, rc);
927 	user->upages_start = start_index;
928 	user->upages_end = start_index + rc;
929 	return 0;
930 }
931 
932 /* This is the "modern" and faster accounting method used by io_uring */
incr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)933 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
934 {
935 	unsigned long lock_limit;
936 	unsigned long cur_pages;
937 	unsigned long new_pages;
938 
939 	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
940 		     PAGE_SHIFT;
941 
942 	cur_pages = atomic_long_read(&pages->source_user->locked_vm);
943 	do {
944 		new_pages = cur_pages + npages;
945 		if (new_pages > lock_limit)
946 			return -ENOMEM;
947 	} while (!atomic_long_try_cmpxchg(&pages->source_user->locked_vm,
948 					  &cur_pages, new_pages));
949 	return 0;
950 }
951 
decr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)952 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
953 {
954 	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
955 		return;
956 	atomic_long_sub(npages, &pages->source_user->locked_vm);
957 }
958 
959 /* This is the accounting method used for compatibility with VFIO */
update_mm_locked_vm(struct iopt_pages * pages,unsigned long npages,bool inc,struct pfn_reader_user * user)960 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
961 			       bool inc, struct pfn_reader_user *user)
962 {
963 	bool do_put = false;
964 	int rc;
965 
966 	if (user && user->locked) {
967 		mmap_read_unlock(pages->source_mm);
968 		user->locked = 0;
969 		/* If we had the lock then we also have a get */
970 
971 	} else if ((!user || (!user->upages && !user->ufolios)) &&
972 		   pages->source_mm != current->mm) {
973 		if (!mmget_not_zero(pages->source_mm))
974 			return -EINVAL;
975 		do_put = true;
976 	}
977 
978 	mmap_write_lock(pages->source_mm);
979 	rc = __account_locked_vm(pages->source_mm, npages, inc,
980 				 pages->source_task, false);
981 	mmap_write_unlock(pages->source_mm);
982 
983 	if (do_put)
984 		mmput(pages->source_mm);
985 	return rc;
986 }
987 
iopt_pages_update_pinned(struct iopt_pages * pages,unsigned long npages,bool inc,struct pfn_reader_user * user)988 int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
989 			     bool inc, struct pfn_reader_user *user)
990 {
991 	int rc = 0;
992 
993 	switch (pages->account_mode) {
994 	case IOPT_PAGES_ACCOUNT_NONE:
995 		break;
996 	case IOPT_PAGES_ACCOUNT_USER:
997 		if (inc)
998 			rc = incr_user_locked_vm(pages, npages);
999 		else
1000 			decr_user_locked_vm(pages, npages);
1001 		break;
1002 	case IOPT_PAGES_ACCOUNT_MM:
1003 		rc = update_mm_locked_vm(pages, npages, inc, user);
1004 		break;
1005 	}
1006 	if (rc)
1007 		return rc;
1008 
1009 	pages->last_npinned = pages->npinned;
1010 	if (inc)
1011 		atomic64_add(npages, &pages->source_mm->pinned_vm);
1012 	else
1013 		atomic64_sub(npages, &pages->source_mm->pinned_vm);
1014 	return 0;
1015 }
1016 
update_unpinned(struct iopt_pages * pages)1017 static void update_unpinned(struct iopt_pages *pages)
1018 {
1019 	if (WARN_ON(pages->npinned > pages->last_npinned))
1020 		return;
1021 	if (pages->npinned == pages->last_npinned)
1022 		return;
1023 	iopt_pages_update_pinned(pages, pages->last_npinned - pages->npinned,
1024 				 false, NULL);
1025 }
1026 
1027 /*
1028  * Changes in the number of pages pinned is done after the pages have been read
1029  * and processed. If the user lacked the limit then the error unwind will unpin
1030  * everything that was just pinned. This is because it is expensive to calculate
1031  * how many pages we have already pinned within a range to generate an accurate
1032  * prediction in advance of doing the work to actually pin them.
1033  */
pfn_reader_user_update_pinned(struct pfn_reader_user * user,struct iopt_pages * pages)1034 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
1035 					 struct iopt_pages *pages)
1036 {
1037 	unsigned long npages;
1038 	bool inc;
1039 
1040 	lockdep_assert_held(&pages->mutex);
1041 
1042 	if (pages->npinned == pages->last_npinned)
1043 		return 0;
1044 
1045 	if (pages->npinned < pages->last_npinned) {
1046 		npages = pages->last_npinned - pages->npinned;
1047 		inc = false;
1048 	} else {
1049 		if (iommufd_should_fail())
1050 			return -ENOMEM;
1051 		npages = pages->npinned - pages->last_npinned;
1052 		inc = true;
1053 	}
1054 	return iopt_pages_update_pinned(pages, npages, inc, user);
1055 }
1056 
1057 /*
1058  * PFNs are stored in three places, in order of preference:
1059  * - The iopt_pages xarray. This is only populated if there is a
1060  *   iopt_pages_access
1061  * - The iommu_domain under an area
1062  * - The original PFN source, ie pages->source_mm
1063  *
1064  * This iterator reads the pfns optimizing to load according to the
1065  * above order.
1066  */
1067 struct pfn_reader {
1068 	struct iopt_pages *pages;
1069 	struct interval_tree_double_span_iter span;
1070 	struct pfn_batch batch;
1071 	unsigned long batch_start_index;
1072 	unsigned long batch_end_index;
1073 	unsigned long last_index;
1074 
1075 	struct pfn_reader_user user;
1076 };
1077 
pfn_reader_update_pinned(struct pfn_reader * pfns)1078 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
1079 {
1080 	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
1081 }
1082 
1083 /*
1084  * The batch can contain a mixture of pages that are still in use and pages that
1085  * need to be unpinned. Unpin only pages that are not held anywhere else.
1086  */
pfn_reader_unpin(struct pfn_reader * pfns)1087 static void pfn_reader_unpin(struct pfn_reader *pfns)
1088 {
1089 	unsigned long last = pfns->batch_end_index - 1;
1090 	unsigned long start = pfns->batch_start_index;
1091 	struct interval_tree_double_span_iter span;
1092 	struct iopt_pages *pages = pfns->pages;
1093 
1094 	lockdep_assert_held(&pages->mutex);
1095 
1096 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1097 					   &pages->domains_itree, start, last) {
1098 		if (span.is_used)
1099 			continue;
1100 
1101 		batch_unpin(&pfns->batch, pages, span.start_hole - start,
1102 			    span.last_hole - span.start_hole + 1);
1103 	}
1104 }
1105 
1106 /* Process a single span to load it from the proper storage */
pfn_reader_fill_span(struct pfn_reader * pfns)1107 static int pfn_reader_fill_span(struct pfn_reader *pfns)
1108 {
1109 	struct interval_tree_double_span_iter *span = &pfns->span;
1110 	unsigned long start_index = pfns->batch_end_index;
1111 	struct pfn_reader_user *user = &pfns->user;
1112 	unsigned long npages;
1113 	struct iopt_area *area;
1114 	int rc;
1115 
1116 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1117 	    WARN_ON(span->last_used < start_index))
1118 		return -EINVAL;
1119 
1120 	if (span->is_used == 1) {
1121 		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
1122 				  start_index, span->last_used);
1123 		return 0;
1124 	}
1125 
1126 	if (span->is_used == 2) {
1127 		/*
1128 		 * Pull as many pages from the first domain we find in the
1129 		 * target span. If it is too small then we will be called again
1130 		 * and we'll find another area.
1131 		 */
1132 		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1133 		if (WARN_ON(!area))
1134 			return -EINVAL;
1135 
1136 		/* The storage_domain cannot change without the pages mutex */
1137 		batch_from_domain(
1138 			&pfns->batch, area->storage_domain, area, start_index,
1139 			min(iopt_area_last_index(area), span->last_used));
1140 		return 0;
1141 	}
1142 
1143 	if (start_index >= pfns->user.upages_end) {
1144 		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1145 					 span->last_hole);
1146 		if (rc)
1147 			return rc;
1148 	}
1149 
1150 	npages = user->upages_end - start_index;
1151 	start_index -= user->upages_start;
1152 	rc = 0;
1153 
1154 	if (!user->file)
1155 		batch_from_pages(&pfns->batch, user->upages + start_index,
1156 				 npages);
1157 	else
1158 		rc = batch_from_folios(&pfns->batch, &user->ufolios_next,
1159 				       &user->ufolios_offset, npages);
1160 	return rc;
1161 }
1162 
pfn_reader_done(struct pfn_reader * pfns)1163 static bool pfn_reader_done(struct pfn_reader *pfns)
1164 {
1165 	return pfns->batch_start_index == pfns->last_index + 1;
1166 }
1167 
pfn_reader_next(struct pfn_reader * pfns)1168 static int pfn_reader_next(struct pfn_reader *pfns)
1169 {
1170 	int rc;
1171 
1172 	batch_clear(&pfns->batch);
1173 	pfns->batch_start_index = pfns->batch_end_index;
1174 
1175 	while (pfns->batch_end_index != pfns->last_index + 1) {
1176 		unsigned int npfns = pfns->batch.total_pfns;
1177 
1178 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1179 		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1180 			return -EINVAL;
1181 
1182 		rc = pfn_reader_fill_span(pfns);
1183 		if (rc)
1184 			return rc;
1185 
1186 		if (WARN_ON(!pfns->batch.total_pfns))
1187 			return -EINVAL;
1188 
1189 		pfns->batch_end_index =
1190 			pfns->batch_start_index + pfns->batch.total_pfns;
1191 		if (pfns->batch_end_index == pfns->span.last_used + 1)
1192 			interval_tree_double_span_iter_next(&pfns->span);
1193 
1194 		/* Batch is full */
1195 		if (npfns == pfns->batch.total_pfns)
1196 			return 0;
1197 	}
1198 	return 0;
1199 }
1200 
pfn_reader_init(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1201 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1202 			   unsigned long start_index, unsigned long last_index)
1203 {
1204 	int rc;
1205 
1206 	lockdep_assert_held(&pages->mutex);
1207 
1208 	pfns->pages = pages;
1209 	pfns->batch_start_index = start_index;
1210 	pfns->batch_end_index = start_index;
1211 	pfns->last_index = last_index;
1212 	pfn_reader_user_init(&pfns->user, pages);
1213 	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1214 	if (rc)
1215 		return rc;
1216 	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1217 					     &pages->domains_itree, start_index,
1218 					     last_index);
1219 	return 0;
1220 }
1221 
1222 /*
1223  * There are many assertions regarding the state of pages->npinned vs
1224  * pages->last_pinned, for instance something like unmapping a domain must only
1225  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1226  * the pins are updated. This is fine for success flows, but error flows
1227  * sometimes need to release the pins held inside the pfn_reader before going on
1228  * to complete unmapping and releasing pins held in domains.
1229  */
pfn_reader_release_pins(struct pfn_reader * pfns)1230 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1231 {
1232 	struct iopt_pages *pages = pfns->pages;
1233 	struct pfn_reader_user *user = &pfns->user;
1234 
1235 	if (user->upages_end > pfns->batch_end_index) {
1236 		/* Any pages not transferred to the batch are just unpinned */
1237 
1238 		unsigned long npages = user->upages_end - pfns->batch_end_index;
1239 		unsigned long start_index = pfns->batch_end_index -
1240 					    user->upages_start;
1241 
1242 		if (!user->file) {
1243 			unpin_user_pages(user->upages + start_index, npages);
1244 		} else {
1245 			long n = user->ufolios_len / sizeof(*user->ufolios);
1246 
1247 			unpin_folios(user->ufolios_next,
1248 				     user->ufolios + n - user->ufolios_next);
1249 		}
1250 		iopt_pages_sub_npinned(pages, npages);
1251 		user->upages_end = pfns->batch_end_index;
1252 	}
1253 	if (pfns->batch_start_index != pfns->batch_end_index) {
1254 		pfn_reader_unpin(pfns);
1255 		pfns->batch_start_index = pfns->batch_end_index;
1256 	}
1257 }
1258 
pfn_reader_destroy(struct pfn_reader * pfns)1259 static void pfn_reader_destroy(struct pfn_reader *pfns)
1260 {
1261 	struct iopt_pages *pages = pfns->pages;
1262 
1263 	pfn_reader_release_pins(pfns);
1264 	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1265 	batch_destroy(&pfns->batch, NULL);
1266 	WARN_ON(pages->last_npinned != pages->npinned);
1267 }
1268 
pfn_reader_first(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1269 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1270 			    unsigned long start_index, unsigned long last_index)
1271 {
1272 	int rc;
1273 
1274 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1275 	    WARN_ON(last_index < start_index))
1276 		return -EINVAL;
1277 
1278 	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1279 	if (rc)
1280 		return rc;
1281 	rc = pfn_reader_next(pfns);
1282 	if (rc) {
1283 		pfn_reader_destroy(pfns);
1284 		return rc;
1285 	}
1286 	return 0;
1287 }
1288 
iopt_alloc_pages(unsigned long start_byte,unsigned long length,bool writable)1289 static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
1290 					   unsigned long length,
1291 					   bool writable)
1292 {
1293 	struct iopt_pages *pages;
1294 
1295 	/*
1296 	 * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1297 	 * below from overflow
1298 	 */
1299 	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1300 		return ERR_PTR(-EINVAL);
1301 
1302 	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1303 	if (!pages)
1304 		return ERR_PTR(-ENOMEM);
1305 
1306 	kref_init(&pages->kref);
1307 	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1308 	mutex_init(&pages->mutex);
1309 	pages->source_mm = current->mm;
1310 	mmgrab(pages->source_mm);
1311 	pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
1312 	pages->access_itree = RB_ROOT_CACHED;
1313 	pages->domains_itree = RB_ROOT_CACHED;
1314 	pages->writable = writable;
1315 	if (capable(CAP_IPC_LOCK))
1316 		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1317 	else
1318 		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1319 	pages->source_task = current->group_leader;
1320 	get_task_struct(current->group_leader);
1321 	pages->source_user = get_uid(current_user());
1322 	return pages;
1323 }
1324 
iopt_alloc_user_pages(void __user * uptr,unsigned long length,bool writable)1325 struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
1326 					 unsigned long length, bool writable)
1327 {
1328 	struct iopt_pages *pages;
1329 	unsigned long end;
1330 	void __user *uptr_down =
1331 		(void __user *) ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1332 
1333 	if (check_add_overflow((unsigned long)uptr, length, &end))
1334 		return ERR_PTR(-EOVERFLOW);
1335 
1336 	pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
1337 	if (IS_ERR(pages))
1338 		return pages;
1339 	pages->uptr = uptr_down;
1340 	pages->type = IOPT_ADDRESS_USER;
1341 	return pages;
1342 }
1343 
iopt_alloc_file_pages(struct file * file,unsigned long start,unsigned long length,bool writable)1344 struct iopt_pages *iopt_alloc_file_pages(struct file *file, unsigned long start,
1345 					 unsigned long length, bool writable)
1346 
1347 {
1348 	struct iopt_pages *pages;
1349 	unsigned long start_down = ALIGN_DOWN(start, PAGE_SIZE);
1350 	unsigned long end;
1351 
1352 	if (length && check_add_overflow(start, length - 1, &end))
1353 		return ERR_PTR(-EOVERFLOW);
1354 
1355 	pages = iopt_alloc_pages(start - start_down, length, writable);
1356 	if (IS_ERR(pages))
1357 		return pages;
1358 	pages->file = get_file(file);
1359 	pages->start = start_down;
1360 	pages->type = IOPT_ADDRESS_FILE;
1361 	return pages;
1362 }
1363 
iopt_release_pages(struct kref * kref)1364 void iopt_release_pages(struct kref *kref)
1365 {
1366 	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1367 
1368 	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1369 	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1370 	WARN_ON(pages->npinned);
1371 	WARN_ON(!xa_empty(&pages->pinned_pfns));
1372 	mmdrop(pages->source_mm);
1373 	mutex_destroy(&pages->mutex);
1374 	put_task_struct(pages->source_task);
1375 	free_uid(pages->source_user);
1376 	if (pages->type == IOPT_ADDRESS_FILE)
1377 		fput(pages->file);
1378 	kfree(pages);
1379 }
1380 
1381 static void
iopt_area_unpin_domain(struct pfn_batch * batch,struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain,unsigned long start_index,unsigned long last_index,unsigned long * unmapped_end_index,unsigned long real_last_index)1382 iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1383 		       struct iopt_pages *pages, struct iommu_domain *domain,
1384 		       unsigned long start_index, unsigned long last_index,
1385 		       unsigned long *unmapped_end_index,
1386 		       unsigned long real_last_index)
1387 {
1388 	while (start_index <= last_index) {
1389 		unsigned long batch_last_index;
1390 
1391 		if (*unmapped_end_index <= last_index) {
1392 			unsigned long start =
1393 				max(start_index, *unmapped_end_index);
1394 
1395 			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1396 			    batch->total_pfns)
1397 				WARN_ON(*unmapped_end_index -
1398 						batch->total_pfns !=
1399 					start_index);
1400 			batch_from_domain(batch, domain, area, start,
1401 					  last_index);
1402 			batch_last_index = start_index + batch->total_pfns - 1;
1403 		} else {
1404 			batch_last_index = last_index;
1405 		}
1406 
1407 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1408 			WARN_ON(batch_last_index > real_last_index);
1409 
1410 		/*
1411 		 * unmaps must always 'cut' at a place where the pfns are not
1412 		 * contiguous to pair with the maps that always install
1413 		 * contiguous pages. Thus, if we have to stop unpinning in the
1414 		 * middle of the domains we need to keep reading pfns until we
1415 		 * find a cut point to do the unmap. The pfns we read are
1416 		 * carried over and either skipped or integrated into the next
1417 		 * batch.
1418 		 */
1419 		if (batch_last_index == last_index &&
1420 		    last_index != real_last_index)
1421 			batch_from_domain_continue(batch, domain, area,
1422 						   last_index + 1,
1423 						   real_last_index);
1424 
1425 		if (*unmapped_end_index <= batch_last_index) {
1426 			iopt_area_unmap_domain_range(
1427 				area, domain, *unmapped_end_index,
1428 				start_index + batch->total_pfns - 1);
1429 			*unmapped_end_index = start_index + batch->total_pfns;
1430 		}
1431 
1432 		/* unpin must follow unmap */
1433 		batch_unpin(batch, pages, 0,
1434 			    batch_last_index - start_index + 1);
1435 		start_index = batch_last_index + 1;
1436 
1437 		batch_clear_carry(batch,
1438 				  *unmapped_end_index - batch_last_index - 1);
1439 	}
1440 }
1441 
__iopt_area_unfill_domain(struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain,unsigned long last_index)1442 static void __iopt_area_unfill_domain(struct iopt_area *area,
1443 				      struct iopt_pages *pages,
1444 				      struct iommu_domain *domain,
1445 				      unsigned long last_index)
1446 {
1447 	struct interval_tree_double_span_iter span;
1448 	unsigned long start_index = iopt_area_index(area);
1449 	unsigned long unmapped_end_index = start_index;
1450 	u64 backup[BATCH_BACKUP_SIZE];
1451 	struct pfn_batch batch;
1452 
1453 	lockdep_assert_held(&pages->mutex);
1454 
1455 	/*
1456 	 * For security we must not unpin something that is still DMA mapped,
1457 	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1458 	 * This creates a complexity where we need to skip over unpinning pages
1459 	 * held in the xarray, but continue to unmap from the domain.
1460 	 *
1461 	 * The domain unmap cannot stop in the middle of a contiguous range of
1462 	 * PFNs. To solve this problem the unpinning step will read ahead to the
1463 	 * end of any contiguous span, unmap that whole span, and then only
1464 	 * unpin the leading part that does not have any accesses. The residual
1465 	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1466 	 * batch as they are moved to the front of the PFN list and continue on
1467 	 * to the next iteration(s).
1468 	 */
1469 	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1470 	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1471 					   &pages->access_itree, start_index,
1472 					   last_index) {
1473 		if (span.is_used) {
1474 			batch_skip_carry(&batch,
1475 					 span.last_used - span.start_used + 1);
1476 			continue;
1477 		}
1478 		iopt_area_unpin_domain(&batch, area, pages, domain,
1479 				       span.start_hole, span.last_hole,
1480 				       &unmapped_end_index, last_index);
1481 	}
1482 	/*
1483 	 * If the range ends in a access then we do the residual unmap without
1484 	 * any unpins.
1485 	 */
1486 	if (unmapped_end_index != last_index + 1)
1487 		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1488 					     last_index);
1489 	WARN_ON(batch.total_pfns);
1490 	batch_destroy(&batch, backup);
1491 	update_unpinned(pages);
1492 }
1493 
iopt_area_unfill_partial_domain(struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain,unsigned long end_index)1494 static void iopt_area_unfill_partial_domain(struct iopt_area *area,
1495 					    struct iopt_pages *pages,
1496 					    struct iommu_domain *domain,
1497 					    unsigned long end_index)
1498 {
1499 	if (end_index != iopt_area_index(area))
1500 		__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
1501 }
1502 
1503 /**
1504  * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1505  * @area: The IOVA range to unmap
1506  * @domain: The domain to unmap
1507  *
1508  * The caller must know that unpinning is not required, usually because there
1509  * are other domains in the iopt.
1510  */
iopt_area_unmap_domain(struct iopt_area * area,struct iommu_domain * domain)1511 void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1512 {
1513 	iommu_unmap_nofail(domain, iopt_area_iova(area),
1514 			   iopt_area_length(area));
1515 }
1516 
1517 /**
1518  * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
1519  * @area: IOVA area to use
1520  * @pages: page supplier for the area (area->pages is NULL)
1521  * @domain: Domain to unmap from
1522  *
1523  * The domain should be removed from the domains_itree before calling. The
1524  * domain will always be unmapped, but the PFNs may not be unpinned if there are
1525  * still accesses.
1526  */
iopt_area_unfill_domain(struct iopt_area * area,struct iopt_pages * pages,struct iommu_domain * domain)1527 void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
1528 			     struct iommu_domain *domain)
1529 {
1530 	__iopt_area_unfill_domain(area, pages, domain,
1531 				  iopt_area_last_index(area));
1532 }
1533 
1534 /**
1535  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1536  * @area: IOVA area to use
1537  * @domain: Domain to load PFNs into
1538  *
1539  * Read the pfns from the area's underlying iopt_pages and map them into the
1540  * given domain. Called when attaching a new domain to an io_pagetable.
1541  */
iopt_area_fill_domain(struct iopt_area * area,struct iommu_domain * domain)1542 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1543 {
1544 	unsigned long done_end_index;
1545 	struct pfn_reader pfns;
1546 	int rc;
1547 
1548 	lockdep_assert_held(&area->pages->mutex);
1549 
1550 	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1551 			      iopt_area_last_index(area));
1552 	if (rc)
1553 		return rc;
1554 
1555 	while (!pfn_reader_done(&pfns)) {
1556 		done_end_index = pfns.batch_start_index;
1557 		rc = batch_to_domain(&pfns.batch, domain, area,
1558 				     pfns.batch_start_index);
1559 		if (rc)
1560 			goto out_unmap;
1561 		done_end_index = pfns.batch_end_index;
1562 
1563 		rc = pfn_reader_next(&pfns);
1564 		if (rc)
1565 			goto out_unmap;
1566 	}
1567 
1568 	rc = pfn_reader_update_pinned(&pfns);
1569 	if (rc)
1570 		goto out_unmap;
1571 	goto out_destroy;
1572 
1573 out_unmap:
1574 	pfn_reader_release_pins(&pfns);
1575 	iopt_area_unfill_partial_domain(area, area->pages, domain,
1576 					done_end_index);
1577 out_destroy:
1578 	pfn_reader_destroy(&pfns);
1579 	return rc;
1580 }
1581 
1582 /**
1583  * iopt_area_fill_domains() - Install PFNs into the area's domains
1584  * @area: The area to act on
1585  * @pages: The pages associated with the area (area->pages is NULL)
1586  *
1587  * Called during area creation. The area is freshly created and not inserted in
1588  * the domains_itree yet. PFNs are read and loaded into every domain held in the
1589  * area's io_pagetable and the area is installed in the domains_itree.
1590  *
1591  * On failure all domains are left unchanged.
1592  */
iopt_area_fill_domains(struct iopt_area * area,struct iopt_pages * pages)1593 int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1594 {
1595 	unsigned long done_first_end_index;
1596 	unsigned long done_all_end_index;
1597 	struct iommu_domain *domain;
1598 	unsigned long unmap_index;
1599 	struct pfn_reader pfns;
1600 	unsigned long index;
1601 	int rc;
1602 
1603 	lockdep_assert_held(&area->iopt->domains_rwsem);
1604 
1605 	if (xa_empty(&area->iopt->domains))
1606 		return 0;
1607 
1608 	mutex_lock(&pages->mutex);
1609 	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1610 			      iopt_area_last_index(area));
1611 	if (rc)
1612 		goto out_unlock;
1613 
1614 	while (!pfn_reader_done(&pfns)) {
1615 		done_first_end_index = pfns.batch_end_index;
1616 		done_all_end_index = pfns.batch_start_index;
1617 		xa_for_each(&area->iopt->domains, index, domain) {
1618 			rc = batch_to_domain(&pfns.batch, domain, area,
1619 					     pfns.batch_start_index);
1620 			if (rc)
1621 				goto out_unmap;
1622 		}
1623 		done_all_end_index = done_first_end_index;
1624 
1625 		rc = pfn_reader_next(&pfns);
1626 		if (rc)
1627 			goto out_unmap;
1628 	}
1629 	rc = pfn_reader_update_pinned(&pfns);
1630 	if (rc)
1631 		goto out_unmap;
1632 
1633 	area->storage_domain = xa_load(&area->iopt->domains, 0);
1634 	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1635 	goto out_destroy;
1636 
1637 out_unmap:
1638 	pfn_reader_release_pins(&pfns);
1639 	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1640 		unsigned long end_index;
1641 
1642 		if (unmap_index < index)
1643 			end_index = done_first_end_index;
1644 		else
1645 			end_index = done_all_end_index;
1646 
1647 		/*
1648 		 * The area is not yet part of the domains_itree so we have to
1649 		 * manage the unpinning specially. The last domain does the
1650 		 * unpin, every other domain is just unmapped.
1651 		 */
1652 		if (unmap_index != area->iopt->next_domain_id - 1) {
1653 			if (end_index != iopt_area_index(area))
1654 				iopt_area_unmap_domain_range(
1655 					area, domain, iopt_area_index(area),
1656 					end_index - 1);
1657 		} else {
1658 			iopt_area_unfill_partial_domain(area, pages, domain,
1659 							end_index);
1660 		}
1661 	}
1662 out_destroy:
1663 	pfn_reader_destroy(&pfns);
1664 out_unlock:
1665 	mutex_unlock(&pages->mutex);
1666 	return rc;
1667 }
1668 
1669 /**
1670  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1671  * @area: The area to act on
1672  * @pages: The pages associated with the area (area->pages is NULL)
1673  *
1674  * Called during area destruction. This unmaps the iova's covered by all the
1675  * area's domains and releases the PFNs.
1676  */
iopt_area_unfill_domains(struct iopt_area * area,struct iopt_pages * pages)1677 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1678 {
1679 	struct io_pagetable *iopt = area->iopt;
1680 	struct iommu_domain *domain;
1681 	unsigned long index;
1682 
1683 	lockdep_assert_held(&iopt->domains_rwsem);
1684 
1685 	mutex_lock(&pages->mutex);
1686 	if (!area->storage_domain)
1687 		goto out_unlock;
1688 
1689 	xa_for_each(&iopt->domains, index, domain)
1690 		if (domain != area->storage_domain)
1691 			iopt_area_unmap_domain_range(
1692 				area, domain, iopt_area_index(area),
1693 				iopt_area_last_index(area));
1694 
1695 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1696 		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1697 	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1698 	iopt_area_unfill_domain(area, pages, area->storage_domain);
1699 	area->storage_domain = NULL;
1700 out_unlock:
1701 	mutex_unlock(&pages->mutex);
1702 }
1703 
iopt_pages_unpin_xarray(struct pfn_batch * batch,struct iopt_pages * pages,unsigned long start_index,unsigned long end_index)1704 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1705 				    struct iopt_pages *pages,
1706 				    unsigned long start_index,
1707 				    unsigned long end_index)
1708 {
1709 	while (start_index <= end_index) {
1710 		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1711 					end_index);
1712 		batch_unpin(batch, pages, 0, batch->total_pfns);
1713 		start_index += batch->total_pfns;
1714 		batch_clear(batch);
1715 	}
1716 }
1717 
1718 /**
1719  * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1720  * @pages: The pages to act on
1721  * @start_index: Starting PFN index
1722  * @last_index: Last PFN index
1723  *
1724  * Called when an iopt_pages_access is removed, removes pages from the itree.
1725  * The access should already be removed from the access_itree.
1726  */
iopt_pages_unfill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1727 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1728 			      unsigned long start_index,
1729 			      unsigned long last_index)
1730 {
1731 	struct interval_tree_double_span_iter span;
1732 	u64 backup[BATCH_BACKUP_SIZE];
1733 	struct pfn_batch batch;
1734 	bool batch_inited = false;
1735 
1736 	lockdep_assert_held(&pages->mutex);
1737 
1738 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1739 					   &pages->domains_itree, start_index,
1740 					   last_index) {
1741 		if (!span.is_used) {
1742 			if (!batch_inited) {
1743 				batch_init_backup(&batch,
1744 						  last_index - start_index + 1,
1745 						  backup, sizeof(backup));
1746 				batch_inited = true;
1747 			}
1748 			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1749 						span.last_hole);
1750 		} else if (span.is_used == 2) {
1751 			/* Covered by a domain */
1752 			clear_xarray(&pages->pinned_pfns, span.start_used,
1753 				     span.last_used);
1754 		}
1755 		/* Otherwise covered by an existing access */
1756 	}
1757 	if (batch_inited)
1758 		batch_destroy(&batch, backup);
1759 	update_unpinned(pages);
1760 }
1761 
1762 /**
1763  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1764  * @pages: The pages to act on
1765  * @start_index: The first page index in the range
1766  * @last_index: The last page index in the range
1767  * @out_pages: The output array to return the pages
1768  *
1769  * This can be called if the caller is holding a refcount on an
1770  * iopt_pages_access that is known to have already been filled. It quickly reads
1771  * the pages directly from the xarray.
1772  *
1773  * This is part of the SW iommu interface to read pages for in-kernel use.
1774  */
iopt_pages_fill_from_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1775 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1776 				 unsigned long start_index,
1777 				 unsigned long last_index,
1778 				 struct page **out_pages)
1779 {
1780 	XA_STATE(xas, &pages->pinned_pfns, start_index);
1781 	void *entry;
1782 
1783 	rcu_read_lock();
1784 	while (start_index <= last_index) {
1785 		entry = xas_next(&xas);
1786 		if (xas_retry(&xas, entry))
1787 			continue;
1788 		WARN_ON(!xa_is_value(entry));
1789 		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1790 		start_index++;
1791 	}
1792 	rcu_read_unlock();
1793 }
1794 
iopt_pages_fill_from_domain(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1795 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1796 				       unsigned long start_index,
1797 				       unsigned long last_index,
1798 				       struct page **out_pages)
1799 {
1800 	while (start_index != last_index + 1) {
1801 		unsigned long domain_last;
1802 		struct iopt_area *area;
1803 
1804 		area = iopt_pages_find_domain_area(pages, start_index);
1805 		if (WARN_ON(!area))
1806 			return -EINVAL;
1807 
1808 		domain_last = min(iopt_area_last_index(area), last_index);
1809 		out_pages = raw_pages_from_domain(area->storage_domain, area,
1810 						  start_index, domain_last,
1811 						  out_pages);
1812 		start_index = domain_last + 1;
1813 	}
1814 	return 0;
1815 }
1816 
iopt_pages_fill(struct iopt_pages * pages,struct pfn_reader_user * user,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1817 static int iopt_pages_fill(struct iopt_pages *pages,
1818 			   struct pfn_reader_user *user,
1819 			   unsigned long start_index,
1820 			   unsigned long last_index,
1821 			   struct page **out_pages)
1822 {
1823 	unsigned long cur_index = start_index;
1824 	int rc;
1825 
1826 	while (cur_index != last_index + 1) {
1827 		user->upages = out_pages + (cur_index - start_index);
1828 		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1829 		if (rc)
1830 			goto out_unpin;
1831 		cur_index = user->upages_end;
1832 	}
1833 	return 0;
1834 
1835 out_unpin:
1836 	if (start_index != cur_index)
1837 		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1838 				     out_pages);
1839 	return rc;
1840 }
1841 
1842 /**
1843  * iopt_pages_fill_xarray() - Read PFNs
1844  * @pages: The pages to act on
1845  * @start_index: The first page index in the range
1846  * @last_index: The last page index in the range
1847  * @out_pages: The output array to return the pages, may be NULL
1848  *
1849  * This populates the xarray and returns the pages in out_pages. As the slow
1850  * path this is able to copy pages from other storage tiers into the xarray.
1851  *
1852  * On failure the xarray is left unchanged.
1853  *
1854  * This is part of the SW iommu interface to read pages for in-kernel use.
1855  */
iopt_pages_fill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1856 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1857 			   unsigned long last_index, struct page **out_pages)
1858 {
1859 	struct interval_tree_double_span_iter span;
1860 	unsigned long xa_end = start_index;
1861 	struct pfn_reader_user user;
1862 	int rc;
1863 
1864 	lockdep_assert_held(&pages->mutex);
1865 
1866 	pfn_reader_user_init(&user, pages);
1867 	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1868 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1869 					   &pages->domains_itree, start_index,
1870 					   last_index) {
1871 		struct page **cur_pages;
1872 
1873 		if (span.is_used == 1) {
1874 			cur_pages = out_pages + (span.start_used - start_index);
1875 			iopt_pages_fill_from_xarray(pages, span.start_used,
1876 						    span.last_used, cur_pages);
1877 			continue;
1878 		}
1879 
1880 		if (span.is_used == 2) {
1881 			cur_pages = out_pages + (span.start_used - start_index);
1882 			iopt_pages_fill_from_domain(pages, span.start_used,
1883 						    span.last_used, cur_pages);
1884 			rc = pages_to_xarray(&pages->pinned_pfns,
1885 					     span.start_used, span.last_used,
1886 					     cur_pages);
1887 			if (rc)
1888 				goto out_clean_xa;
1889 			xa_end = span.last_used + 1;
1890 			continue;
1891 		}
1892 
1893 		/* hole */
1894 		cur_pages = out_pages + (span.start_hole - start_index);
1895 		rc = iopt_pages_fill(pages, &user, span.start_hole,
1896 				     span.last_hole, cur_pages);
1897 		if (rc)
1898 			goto out_clean_xa;
1899 		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1900 				     span.last_hole, cur_pages);
1901 		if (rc) {
1902 			iopt_pages_err_unpin(pages, span.start_hole,
1903 					     span.last_hole, cur_pages);
1904 			goto out_clean_xa;
1905 		}
1906 		xa_end = span.last_hole + 1;
1907 	}
1908 	rc = pfn_reader_user_update_pinned(&user, pages);
1909 	if (rc)
1910 		goto out_clean_xa;
1911 	user.upages = NULL;
1912 	pfn_reader_user_destroy(&user, pages);
1913 	return 0;
1914 
1915 out_clean_xa:
1916 	if (start_index != xa_end)
1917 		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1918 	user.upages = NULL;
1919 	pfn_reader_user_destroy(&user, pages);
1920 	return rc;
1921 }
1922 
1923 /*
1924  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1925  * do every scenario and is fully consistent with what an iommu_domain would
1926  * see.
1927  */
iopt_pages_rw_slow(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1928 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1929 			      unsigned long start_index,
1930 			      unsigned long last_index, unsigned long offset,
1931 			      void *data, unsigned long length,
1932 			      unsigned int flags)
1933 {
1934 	struct pfn_reader pfns;
1935 	int rc;
1936 
1937 	mutex_lock(&pages->mutex);
1938 
1939 	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1940 	if (rc)
1941 		goto out_unlock;
1942 
1943 	while (!pfn_reader_done(&pfns)) {
1944 		unsigned long done;
1945 
1946 		done = batch_rw(&pfns.batch, data, offset, length, flags);
1947 		data += done;
1948 		length -= done;
1949 		offset = 0;
1950 		pfn_reader_unpin(&pfns);
1951 
1952 		rc = pfn_reader_next(&pfns);
1953 		if (rc)
1954 			goto out_destroy;
1955 	}
1956 	if (WARN_ON(length != 0))
1957 		rc = -EINVAL;
1958 out_destroy:
1959 	pfn_reader_destroy(&pfns);
1960 out_unlock:
1961 	mutex_unlock(&pages->mutex);
1962 	return rc;
1963 }
1964 
1965 /*
1966  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1967  * memory allocations or interval tree searches.
1968  */
iopt_pages_rw_page(struct iopt_pages * pages,unsigned long index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1969 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1970 			      unsigned long offset, void *data,
1971 			      unsigned long length, unsigned int flags)
1972 {
1973 	struct page *page = NULL;
1974 	int rc;
1975 
1976 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1977 	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
1978 		return -EINVAL;
1979 
1980 	if (!mmget_not_zero(pages->source_mm))
1981 		return iopt_pages_rw_slow(pages, index, index, offset, data,
1982 					  length, flags);
1983 
1984 	if (iommufd_should_fail()) {
1985 		rc = -EINVAL;
1986 		goto out_mmput;
1987 	}
1988 
1989 	mmap_read_lock(pages->source_mm);
1990 	rc = pin_user_pages_remote(
1991 		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1992 		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1993 		NULL);
1994 	mmap_read_unlock(pages->source_mm);
1995 	if (rc != 1) {
1996 		if (WARN_ON(rc >= 0))
1997 			rc = -EINVAL;
1998 		goto out_mmput;
1999 	}
2000 	copy_data_page(page, data, offset, length, flags);
2001 	unpin_user_page(page);
2002 	rc = 0;
2003 
2004 out_mmput:
2005 	mmput(pages->source_mm);
2006 	return rc;
2007 }
2008 
2009 /**
2010  * iopt_pages_rw_access - Copy to/from a linear slice of the pages
2011  * @pages: pages to act on
2012  * @start_byte: First byte of pages to copy to/from
2013  * @data: Kernel buffer to get/put the data
2014  * @length: Number of bytes to copy
2015  * @flags: IOMMUFD_ACCESS_RW_* flags
2016  *
2017  * This will find each page in the range, kmap it and then memcpy to/from
2018  * the given kernel buffer.
2019  */
iopt_pages_rw_access(struct iopt_pages * pages,unsigned long start_byte,void * data,unsigned long length,unsigned int flags)2020 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
2021 			 void *data, unsigned long length, unsigned int flags)
2022 {
2023 	unsigned long start_index = start_byte / PAGE_SIZE;
2024 	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
2025 	bool change_mm = current->mm != pages->source_mm;
2026 	int rc = 0;
2027 
2028 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2029 	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
2030 		change_mm = true;
2031 
2032 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2033 		return -EPERM;
2034 
2035 	if (pages->type == IOPT_ADDRESS_FILE)
2036 		return iopt_pages_rw_slow(pages, start_index, last_index,
2037 					  start_byte % PAGE_SIZE, data, length,
2038 					  flags);
2039 
2040 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2041 	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
2042 		return -EINVAL;
2043 
2044 	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
2045 		if (start_index == last_index)
2046 			return iopt_pages_rw_page(pages, start_index,
2047 						  start_byte % PAGE_SIZE, data,
2048 						  length, flags);
2049 		return iopt_pages_rw_slow(pages, start_index, last_index,
2050 					  start_byte % PAGE_SIZE, data, length,
2051 					  flags);
2052 	}
2053 
2054 	/*
2055 	 * Try to copy using copy_to_user(). We do this as a fast path and
2056 	 * ignore any pinning inconsistencies, unlike a real DMA path.
2057 	 */
2058 	if (change_mm) {
2059 		if (!mmget_not_zero(pages->source_mm))
2060 			return iopt_pages_rw_slow(pages, start_index,
2061 						  last_index,
2062 						  start_byte % PAGE_SIZE, data,
2063 						  length, flags);
2064 		kthread_use_mm(pages->source_mm);
2065 	}
2066 
2067 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
2068 		if (copy_to_user(pages->uptr + start_byte, data, length))
2069 			rc = -EFAULT;
2070 	} else {
2071 		if (copy_from_user(data, pages->uptr + start_byte, length))
2072 			rc = -EFAULT;
2073 	}
2074 
2075 	if (change_mm) {
2076 		kthread_unuse_mm(pages->source_mm);
2077 		mmput(pages->source_mm);
2078 	}
2079 
2080 	return rc;
2081 }
2082 
2083 static struct iopt_pages_access *
iopt_pages_get_exact_access(struct iopt_pages * pages,unsigned long index,unsigned long last)2084 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
2085 			    unsigned long last)
2086 {
2087 	struct interval_tree_node *node;
2088 
2089 	lockdep_assert_held(&pages->mutex);
2090 
2091 	/* There can be overlapping ranges in this interval tree */
2092 	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
2093 	     node; node = interval_tree_iter_next(node, index, last))
2094 		if (node->start == index && node->last == last)
2095 			return container_of(node, struct iopt_pages_access,
2096 					    node);
2097 	return NULL;
2098 }
2099 
2100 /**
2101  * iopt_area_add_access() - Record an in-knerel access for PFNs
2102  * @area: The source of PFNs
2103  * @start_index: First page index
2104  * @last_index: Inclusive last page index
2105  * @out_pages: Output list of struct page's representing the PFNs
2106  * @flags: IOMMUFD_ACCESS_RW_* flags
2107  *
2108  * Record that an in-kernel access will be accessing the pages, ensure they are
2109  * pinned, and return the PFNs as a simple list of 'struct page *'.
2110  *
2111  * This should be undone through a matching call to iopt_area_remove_access()
2112  */
iopt_area_add_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages,unsigned int flags)2113 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
2114 			  unsigned long last_index, struct page **out_pages,
2115 			  unsigned int flags)
2116 {
2117 	struct iopt_pages *pages = area->pages;
2118 	struct iopt_pages_access *access;
2119 	int rc;
2120 
2121 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2122 		return -EPERM;
2123 
2124 	mutex_lock(&pages->mutex);
2125 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2126 	if (access) {
2127 		area->num_accesses++;
2128 		access->users++;
2129 		iopt_pages_fill_from_xarray(pages, start_index, last_index,
2130 					    out_pages);
2131 		mutex_unlock(&pages->mutex);
2132 		return 0;
2133 	}
2134 
2135 	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
2136 	if (!access) {
2137 		rc = -ENOMEM;
2138 		goto err_unlock;
2139 	}
2140 
2141 	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
2142 	if (rc)
2143 		goto err_free;
2144 
2145 	access->node.start = start_index;
2146 	access->node.last = last_index;
2147 	access->users = 1;
2148 	area->num_accesses++;
2149 	interval_tree_insert(&access->node, &pages->access_itree);
2150 	mutex_unlock(&pages->mutex);
2151 	return 0;
2152 
2153 err_free:
2154 	kfree(access);
2155 err_unlock:
2156 	mutex_unlock(&pages->mutex);
2157 	return rc;
2158 }
2159 
2160 /**
2161  * iopt_area_remove_access() - Release an in-kernel access for PFNs
2162  * @area: The source of PFNs
2163  * @start_index: First page index
2164  * @last_index: Inclusive last page index
2165  *
2166  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
2167  * must stop using the PFNs before calling this.
2168  */
iopt_area_remove_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index)2169 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
2170 			     unsigned long last_index)
2171 {
2172 	struct iopt_pages *pages = area->pages;
2173 	struct iopt_pages_access *access;
2174 
2175 	mutex_lock(&pages->mutex);
2176 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2177 	if (WARN_ON(!access))
2178 		goto out_unlock;
2179 
2180 	WARN_ON(area->num_accesses == 0 || access->users == 0);
2181 	area->num_accesses--;
2182 	access->users--;
2183 	if (access->users)
2184 		goto out_unlock;
2185 
2186 	interval_tree_remove(&access->node, &pages->access_itree);
2187 	iopt_pages_unfill_xarray(pages, start_index, last_index);
2188 	kfree(access);
2189 out_unlock:
2190 	mutex_unlock(&pages->mutex);
2191 }
2192