xref: /linux/drivers/iommu/iommufd/pages.c (revision 056daec2925dc200b22c30419bc7b9e01f7843c4)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFNs have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing an in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/dma-buf.h>
49 #include <linux/dma-resv.h>
50 #include <linux/file.h>
51 #include <linux/highmem.h>
52 #include <linux/iommu.h>
53 #include <linux/iommufd.h>
54 #include <linux/kthread.h>
55 #include <linux/overflow.h>
56 #include <linux/slab.h>
57 #include <linux/sched/mm.h>
58 #include <linux/vfio_pci_core.h>
59 
60 #include "double_span.h"
61 #include "io_pagetable.h"
62 
63 #ifndef CONFIG_IOMMUFD_TEST
64 #define TEMP_MEMORY_LIMIT 65536
65 #else
66 #define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
67 #endif
68 #define BATCH_BACKUP_SIZE 32
69 
70 /*
71  * More memory makes pin_user_pages() and the batching more efficient, but as
72  * this is only a performance optimization don't try too hard to get it. A 64k
73  * allocation can hold about 26M of 4k pages and 13G of 2M pages in a
74  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
75  * stack memory as a backup contingency. If backup_len is given this cannot
76  * fail.
77  */
78 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
79 {
80 	void *res;
81 
82 	if (WARN_ON(*size == 0))
83 		return NULL;
84 
85 	if (*size < backup_len)
86 		return backup;
87 
88 	if (!backup && iommufd_should_fail())
89 		return NULL;
90 
91 	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
92 	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
93 	if (res)
94 		return res;
95 	*size = PAGE_SIZE;
96 	if (backup_len) {
97 		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
98 		if (res)
99 			return res;
100 		*size = backup_len;
101 		return backup;
102 	}
103 	return kmalloc(*size, GFP_KERNEL);
104 }
105 
106 void interval_tree_double_span_iter_update(
107 	struct interval_tree_double_span_iter *iter)
108 {
109 	unsigned long last_hole = ULONG_MAX;
110 	unsigned int i;
111 
112 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
113 		if (interval_tree_span_iter_done(&iter->spans[i])) {
114 			iter->is_used = -1;
115 			return;
116 		}
117 
118 		if (iter->spans[i].is_hole) {
119 			last_hole = min(last_hole, iter->spans[i].last_hole);
120 			continue;
121 		}
122 
123 		iter->is_used = i + 1;
124 		iter->start_used = iter->spans[i].start_used;
125 		iter->last_used = min(iter->spans[i].last_used, last_hole);
126 		return;
127 	}
128 
129 	iter->is_used = 0;
130 	iter->start_hole = iter->spans[0].start_hole;
131 	iter->last_hole =
132 		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
133 }
134 
135 void interval_tree_double_span_iter_first(
136 	struct interval_tree_double_span_iter *iter,
137 	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
138 	unsigned long first_index, unsigned long last_index)
139 {
140 	unsigned int i;
141 
142 	iter->itrees[0] = itree1;
143 	iter->itrees[1] = itree2;
144 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
145 		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
146 					      first_index, last_index);
147 	interval_tree_double_span_iter_update(iter);
148 }
149 
150 void interval_tree_double_span_iter_next(
151 	struct interval_tree_double_span_iter *iter)
152 {
153 	unsigned int i;
154 
155 	if (iter->is_used == -1 ||
156 	    iter->last_hole == iter->spans[0].last_index) {
157 		iter->is_used = -1;
158 		return;
159 	}
160 
161 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
162 		interval_tree_span_iter_advance(
163 			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
164 	interval_tree_double_span_iter_update(iter);
165 }
166 
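/*
 * iopt_pages_add_npinned()/iopt_pages_sub_npinned() adjust pages->npinned as
 * PFNs are pinned and unpinned. Under CONFIG_IOMMUFD_TEST the counter is also
 * checked for overflow/underflow and for exceeding pages->npages.
 */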
167 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
168 {
169 	int rc;
170 
171 	rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
172 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
173 		WARN_ON(rc || pages->npinned > pages->npages);
174 }
175 
176 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
177 {
178 	int rc;
179 
180 	rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
181 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
182 		WARN_ON(rc || pages->npinned > pages->npages);
183 }
184 
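/* Undo a pin of the pages covering [start_index, last_index] on an error path */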
185 static void iopt_pages_err_unpin(struct iopt_pages *pages,
186 				 unsigned long start_index,
187 				 unsigned long last_index,
188 				 struct page **page_list)
189 {
190 	unsigned long npages = last_index - start_index + 1;
191 
192 	unpin_user_pages(page_list, npages);
193 	iopt_pages_sub_npinned(pages, npages);
194 }
195 
196 /*
197  * index is the number of PAGE_SIZE units from the start of the area's
198  * iopt_pages. If the iova is sub page-size then the area has an iova that
199  * covers a portion of the first and last pages in the range.
200  */
201 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
202 					     unsigned long index)
203 {
204 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
205 		WARN_ON(index < iopt_area_index(area) ||
206 			index > iopt_area_last_index(area));
207 	index -= iopt_area_index(area);
208 	if (index == 0)
209 		return iopt_area_iova(area);
210 	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
211 }
212 
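/* Return the last IOVA (inclusive) covered by the page at index */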
213 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
214 						  unsigned long index)
215 {
216 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
217 		WARN_ON(index < iopt_area_index(area) ||
218 			index > iopt_area_last_index(area));
219 	if (index == iopt_area_last_index(area))
220 		return iopt_area_last_iova(area);
221 	return iopt_area_iova(area) - area->page_offset +
222 	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
223 }
224 
225 static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
226 			       size_t size)
227 {
228 	size_t ret;
229 
230 	ret = iommu_unmap(domain, iova, size);
231 	/*
232 	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
233 	 * something other than exactly as requested. This implies that the
234 	 * iommu driver may not fail unmap for reasons beyond bad agruments.
235 	 * Particularly, the iommu driver may not do a memory allocation on the
236 	 * unmap path.
237 	 */
238 	WARN_ON(ret != size);
239 }
240 
241 static void iopt_area_unmap_domain_range(struct iopt_area *area,
242 					 struct iommu_domain *domain,
243 					 unsigned long start_index,
244 					 unsigned long last_index)
245 {
246 	unsigned long start_iova = iopt_area_index_to_iova(area, start_index);
247 
248 	iommu_unmap_nofail(domain, start_iova,
249 			   iopt_area_index_to_iova_last(area, last_index) -
250 				   start_iova + 1);
251 }
252 
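/* Return any area in the domains_itree whose span includes index, or NULL */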
253 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
254 						     unsigned long index)
255 {
256 	struct interval_tree_node *node;
257 
258 	node = interval_tree_iter_first(&pages->domains_itree, index, index);
259 	if (!node)
260 		return NULL;
261 	return container_of(node, struct iopt_area, pages_node);
262 }
263 
264 enum batch_kind {
265 	BATCH_CPU_MEMORY = 0,
266 	BATCH_MMIO,
267 };
268 
269 /*
270  * A simple data structure to hold a vector of PFNs, optimized for contiguous
271  * PFNs. This is used as a temporary holding memory for shuttling pfns from one
272  * place to another. Generally everything is made more efficient if operations
273  * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
274  * better cache locality, etc
275  */
276 struct pfn_batch {
277 	unsigned long *pfns;
278 	u32 *npfns;
279 	unsigned int array_size;
280 	unsigned int end;
281 	unsigned int total_pfns;
282 	enum batch_kind kind;
283 };
284 enum { MAX_NPFNS = type_max(typeof(((struct pfn_batch *)0)->npfns[0])) };
285 
286 static void batch_clear(struct pfn_batch *batch)
287 {
288 	batch->total_pfns = 0;
289 	batch->end = 0;
290 	batch->pfns[0] = 0;
291 	batch->npfns[0] = 0;
292 }
293 
294 /*
295  * Carry means we carry a portion of the final hugepage over to the front of the
296  * batch
297  */
298 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
299 {
300 	if (!keep_pfns)
301 		return batch_clear(batch);
302 
303 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
304 		WARN_ON(!batch->end ||
305 			batch->npfns[batch->end - 1] < keep_pfns);
306 
307 	batch->total_pfns = keep_pfns;
308 	batch->pfns[0] = batch->pfns[batch->end - 1] +
309 			 (batch->npfns[batch->end - 1] - keep_pfns);
310 	batch->npfns[0] = keep_pfns;
311 	batch->end = 1;
312 }
313 
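/* Discard the first skip_pfns of the carried-over entry at the front of the batch */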
314 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
315 {
316 	if (!batch->total_pfns)
317 		return;
318 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
319 		WARN_ON(batch->total_pfns != batch->npfns[0]);
320 	skip_pfns = min(batch->total_pfns, skip_pfns);
321 	batch->pfns[0] += skip_pfns;
322 	batch->npfns[0] -= skip_pfns;
323 	batch->total_pfns -= skip_pfns;
324 }
325 
326 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
327 			size_t backup_len)
328 {
329 	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
330 	size_t size = max_pages * elmsz;
331 
332 	batch->pfns = temp_kmalloc(&size, backup, backup_len);
333 	if (!batch->pfns)
334 		return -ENOMEM;
335 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
336 		return -EINVAL;
337 	batch->array_size = size / elmsz;
338 	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
339 	batch_clear(batch);
340 	return 0;
341 }
342 
343 static int batch_init(struct pfn_batch *batch, size_t max_pages)
344 {
345 	return __batch_init(batch, max_pages, NULL, 0);
346 }
347 
348 static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
349 			      void *backup, size_t backup_len)
350 {
351 	__batch_init(batch, max_pages, backup, backup_len);
352 }
353 
354 static void batch_destroy(struct pfn_batch *batch, void *backup)
355 {
356 	if (batch->pfns != backup)
357 		kfree(batch->pfns);
358 }
359 
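/*
 * Add nr contiguous pfns of the given kind to the batch, merging with the last
 * entry when possible. Returns false if the batch is full or already holds a
 * different kind.
 */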
360 static bool batch_add_pfn_num(struct pfn_batch *batch, unsigned long pfn,
361 			      u32 nr, enum batch_kind kind)
362 {
363 	unsigned int end = batch->end;
364 
365 	if (batch->kind != kind) {
366 		/* One kind per batch */
367 		if (batch->end != 0)
368 			return false;
369 		batch->kind = kind;
370 	}
371 
372 	if (end && pfn == batch->pfns[end - 1] + batch->npfns[end - 1] &&
373 	    nr <= MAX_NPFNS - batch->npfns[end - 1]) {
374 		batch->npfns[end - 1] += nr;
375 	} else if (end < batch->array_size) {
376 		batch->pfns[end] = pfn;
377 		batch->npfns[end] = nr;
378 		batch->end++;
379 	} else {
380 		return false;
381 	}
382 
383 	batch->total_pfns += nr;
384 	return true;
385 }
386 
387 static void batch_remove_pfn_num(struct pfn_batch *batch, unsigned long nr)
388 {
389 	batch->npfns[batch->end - 1] -= nr;
390 	if (batch->npfns[batch->end - 1] == 0)
391 		batch->end--;
392 	batch->total_pfns -= nr;
393 }
394 
395 /* true if the pfn was added, false otherwise */
396 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
397 {
398 	return batch_add_pfn_num(batch, pfn, 1, BATCH_CPU_MEMORY);
399 }
400 
401 /*
402  * Fill the batch with pfns from the domain. When the batch is full, or it
403  * reaches last_index, the function will return. The caller should use
404  * batch->total_pfns to determine the starting point for the next iteration.
405  */
406 static void batch_from_domain(struct pfn_batch *batch,
407 			      struct iommu_domain *domain,
408 			      struct iopt_area *area, unsigned long start_index,
409 			      unsigned long last_index)
410 {
411 	unsigned int page_offset = 0;
412 	unsigned long iova;
413 	phys_addr_t phys;
414 
415 	iova = iopt_area_index_to_iova(area, start_index);
416 	if (start_index == iopt_area_index(area))
417 		page_offset = area->page_offset;
418 	while (start_index <= last_index) {
419 		/*
420 		 * This is pretty slow, it would be nice to get the page size
421 		 * back from the driver, or have the driver directly fill the
422 		 * batch.
423 		 */
424 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
425 		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
426 			return;
427 		iova += PAGE_SIZE - page_offset;
428 		page_offset = 0;
429 		start_index++;
430 	}
431 }
432 
433 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
434 					   struct iopt_area *area,
435 					   unsigned long start_index,
436 					   unsigned long last_index,
437 					   struct page **out_pages)
438 {
439 	unsigned int page_offset = 0;
440 	unsigned long iova;
441 	phys_addr_t phys;
442 
443 	iova = iopt_area_index_to_iova(area, start_index);
444 	if (start_index == iopt_area_index(area))
445 		page_offset = area->page_offset;
446 	while (start_index <= last_index) {
447 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
448 		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
449 		iova += PAGE_SIZE - page_offset;
450 		page_offset = 0;
451 		start_index++;
452 	}
453 	return out_pages;
454 }
455 
456 /* Continues reading a domain until we reach a discontinuity in the pfns. */
457 static void batch_from_domain_continue(struct pfn_batch *batch,
458 				       struct iommu_domain *domain,
459 				       struct iopt_area *area,
460 				       unsigned long start_index,
461 				       unsigned long last_index)
462 {
463 	unsigned int array_size = batch->array_size;
464 
465 	batch->array_size = batch->end;
466 	batch_from_domain(batch, domain, area, start_index, last_index);
467 	batch->array_size = array_size;
468 }
469 
470 /*
471  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
472  * mode permits splitting a mapped area up, and then one of the splits is
473  * unmapped. Doing this normally would cause us to violate our invariant of
474  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
475  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
476  * PAGE_SIZE units, not larger or smaller.
477  */
478 static int batch_iommu_map_small(struct iommu_domain *domain,
479 				 unsigned long iova, phys_addr_t paddr,
480 				 size_t size, int prot)
481 {
482 	unsigned long start_iova = iova;
483 	int rc;
484 
485 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
486 		WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
487 			size % PAGE_SIZE);
488 
489 	while (size) {
490 		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
491 			       GFP_KERNEL_ACCOUNT);
492 		if (rc)
493 			goto err_unmap;
494 		iova += PAGE_SIZE;
495 		paddr += PAGE_SIZE;
496 		size -= PAGE_SIZE;
497 	}
498 	return 0;
499 
500 err_unmap:
501 	if (start_iova != iova)
502 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
503 	return rc;
504 }
505 
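/*
 * Map the whole batch into the domain starting at start_index, unwinding any
 * partially mapped range on failure.
 */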
506 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
507 			   struct iopt_area *area, unsigned long start_index)
508 {
509 	bool disable_large_pages = area->iopt->disable_large_pages;
510 	unsigned long last_iova = iopt_area_last_iova(area);
511 	int iommu_prot = area->iommu_prot;
512 	unsigned int page_offset = 0;
513 	unsigned long start_iova;
514 	unsigned long next_iova;
515 	unsigned int cur = 0;
516 	unsigned long iova;
517 	int rc;
518 
519 	if (batch->kind == BATCH_MMIO) {
520 		iommu_prot &= ~IOMMU_CACHE;
521 		iommu_prot |= IOMMU_MMIO;
522 	}
523 
524 	/* The first index might be a partial page */
525 	if (start_index == iopt_area_index(area))
526 		page_offset = area->page_offset;
527 	next_iova = iova = start_iova =
528 		iopt_area_index_to_iova(area, start_index);
529 	while (cur < batch->end) {
530 		next_iova = min(last_iova + 1,
531 				next_iova + batch->npfns[cur] * PAGE_SIZE -
532 					page_offset);
533 		if (disable_large_pages)
534 			rc = batch_iommu_map_small(
535 				domain, iova,
536 				PFN_PHYS(batch->pfns[cur]) + page_offset,
537 				next_iova - iova, iommu_prot);
538 		else
539 			rc = iommu_map(domain, iova,
540 				       PFN_PHYS(batch->pfns[cur]) + page_offset,
541 				       next_iova - iova, iommu_prot,
542 				       GFP_KERNEL_ACCOUNT);
543 		if (rc)
544 			goto err_unmap;
545 		iova = next_iova;
546 		page_offset = 0;
547 		cur++;
548 	}
549 	return 0;
550 err_unmap:
551 	if (start_iova != iova)
552 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
553 	return rc;
554 }
555 
556 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
557 			      unsigned long start_index,
558 			      unsigned long last_index)
559 {
560 	XA_STATE(xas, xa, start_index);
561 	void *entry;
562 
563 	rcu_read_lock();
564 	while (true) {
565 		entry = xas_next(&xas);
566 		if (xas_retry(&xas, entry))
567 			continue;
568 		WARN_ON(!xa_is_value(entry));
569 		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
570 		    start_index == last_index)
571 			break;
572 		start_index++;
573 	}
574 	rcu_read_unlock();
575 }
576 
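/* Like batch_from_xarray() but also erases each entry moved into the batch */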
577 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
578 				    unsigned long start_index,
579 				    unsigned long last_index)
580 {
581 	XA_STATE(xas, xa, start_index);
582 	void *entry;
583 
584 	xas_lock(&xas);
585 	while (true) {
586 		entry = xas_next(&xas);
587 		if (xas_retry(&xas, entry))
588 			continue;
589 		WARN_ON(!xa_is_value(entry));
590 		if (!batch_add_pfn(batch, xa_to_value(entry)))
591 			break;
592 		xas_store(&xas, NULL);
593 		if (start_index == last_index)
594 			break;
595 		start_index++;
596 	}
597 	xas_unlock(&xas);
598 }
599 
600 static void clear_xarray(struct xarray *xa, unsigned long start_index,
601 			 unsigned long last_index)
602 {
603 	XA_STATE(xas, xa, start_index);
604 	void *entry;
605 
606 	xas_lock(&xas);
607 	xas_for_each(&xas, entry, last_index)
608 		xas_store(&xas, NULL);
609 	xas_unlock(&xas);
610 }
611 
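/*
 * Store the PFN of each page as an xarray value over [start_index, last_index].
 * On failure any entries already stored are cleared again.
 */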
612 static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
613 			   unsigned long last_index, struct page **pages)
614 {
615 	struct page **end_pages = pages + (last_index - start_index) + 1;
616 	struct page **half_pages = pages + (end_pages - pages) / 2;
617 	XA_STATE(xas, xa, start_index);
618 
619 	do {
620 		void *old;
621 
622 		xas_lock(&xas);
623 		while (pages != end_pages) {
624 			/* xarray does not participate in fault injection */
625 			if (pages == half_pages && iommufd_should_fail()) {
626 				xas_set_err(&xas, -EINVAL);
627 				xas_unlock(&xas);
628 				/* aka xas_destroy() */
629 				xas_nomem(&xas, GFP_KERNEL);
630 				goto err_clear;
631 			}
632 
633 			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
634 			if (xas_error(&xas))
635 				break;
636 			WARN_ON(old);
637 			pages++;
638 			xas_next(&xas);
639 		}
640 		xas_unlock(&xas);
641 	} while (xas_nomem(&xas, GFP_KERNEL));
642 
643 err_clear:
644 	if (xas_error(&xas)) {
645 		if (xas.xa_index != start_index)
646 			clear_xarray(xa, start_index, xas.xa_index - 1);
647 		return xas_error(&xas);
648 	}
649 	return 0;
650 }
651 
652 static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
653 			     size_t npages)
654 {
655 	struct page **end = pages + npages;
656 
657 	for (; pages != end; pages++)
658 		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
659 			break;
660 }
661 
662 static int batch_from_folios(struct pfn_batch *batch, struct folio ***folios_p,
663 			     unsigned long *offset_p, unsigned long npages)
664 {
665 	int rc = 0;
666 	struct folio **folios = *folios_p;
667 	unsigned long offset = *offset_p;
668 
669 	while (npages) {
670 		struct folio *folio = *folios;
671 		unsigned long nr = folio_nr_pages(folio) - offset;
672 		unsigned long pfn = page_to_pfn(folio_page(folio, offset));
673 
674 		nr = min(nr, npages);
675 		npages -= nr;
676 
677 		if (!batch_add_pfn_num(batch, pfn, nr, BATCH_CPU_MEMORY))
678 			break;
679 		if (nr > 1) {
680 			rc = folio_add_pins(folio, nr - 1);
681 			if (rc) {
682 				batch_remove_pfn_num(batch, nr);
683 				goto out;
684 			}
685 		}
686 
687 		folios++;
688 		offset = 0;
689 	}
690 
691 out:
692 	*folios_p = folios;
693 	*offset_p = offset;
694 	return rc;
695 }
696 
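/* Unpin npages from the batch, starting first_page_off pages into it */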
697 static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
698 			unsigned int first_page_off, size_t npages)
699 {
700 	unsigned int cur = 0;
701 
702 	while (first_page_off) {
703 		if (batch->npfns[cur] > first_page_off)
704 			break;
705 		first_page_off -= batch->npfns[cur];
706 		cur++;
707 	}
708 
709 	while (npages) {
710 		size_t to_unpin = min_t(size_t, npages,
711 					batch->npfns[cur] - first_page_off);
712 
713 		unpin_user_page_range_dirty_lock(
714 			pfn_to_page(batch->pfns[cur] + first_page_off),
715 			to_unpin, pages->writable);
716 		iopt_pages_sub_npinned(pages, to_unpin);
717 		cur++;
718 		first_page_off = 0;
719 		npages -= to_unpin;
720 	}
721 }
722 
723 static void copy_data_page(struct page *page, void *data, unsigned long offset,
724 			   size_t length, unsigned int flags)
725 {
726 	void *mem;
727 
728 	mem = kmap_local_page(page);
729 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
730 		memcpy(mem + offset, data, length);
731 		set_page_dirty_lock(page);
732 	} else {
733 		memcpy(data, mem + offset, length);
734 	}
735 	kunmap_local(mem);
736 }
737 
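/*
 * Copy length bytes between data and the pages held in the batch, starting at
 * offset into the first page. Returns the number of bytes copied.
 */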
738 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
739 			      unsigned long offset, unsigned long length,
740 			      unsigned int flags)
741 {
742 	unsigned long copied = 0;
743 	unsigned int npage = 0;
744 	unsigned int cur = 0;
745 
746 	while (cur < batch->end) {
747 		unsigned long bytes = min(length, PAGE_SIZE - offset);
748 
749 		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
750 			       offset, bytes, flags);
751 		offset = 0;
752 		length -= bytes;
753 		data += bytes;
754 		copied += bytes;
755 		npage++;
756 		if (npage == batch->npfns[cur]) {
757 			npage = 0;
758 			cur++;
759 		}
760 		if (!length)
761 			break;
762 	}
763 	return copied;
764 }
765 
766 /* pfn_reader_user is just the pin_user_pages() path */
767 struct pfn_reader_user {
768 	struct page **upages;
769 	size_t upages_len;
770 	unsigned long upages_start;
771 	unsigned long upages_end;
772 	unsigned int gup_flags;
773 	/*
774 	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
775 	 * neither
776 	 */
777 	int locked;
778 
779 	/* The following are only valid if file != NULL. */
780 	struct file *file;
781 	struct folio **ufolios;
782 	size_t ufolios_len;
783 	unsigned long ufolios_offset;
784 	struct folio **ufolios_next;
785 };
786 
787 static void pfn_reader_user_init(struct pfn_reader_user *user,
788 				 struct iopt_pages *pages)
789 {
790 	user->upages = NULL;
791 	user->upages_len = 0;
792 	user->upages_start = 0;
793 	user->upages_end = 0;
794 	user->locked = -1;
795 	user->gup_flags = FOLL_LONGTERM;
796 	if (pages->writable)
797 		user->gup_flags |= FOLL_WRITE;
798 
799 	user->file = (pages->type == IOPT_ADDRESS_FILE) ? pages->file : NULL;
800 	user->ufolios = NULL;
801 	user->ufolios_len = 0;
802 	user->ufolios_next = NULL;
803 	user->ufolios_offset = 0;
804 }
805 
806 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
807 				    struct iopt_pages *pages)
808 {
809 	if (user->locked != -1) {
810 		if (user->locked)
811 			mmap_read_unlock(pages->source_mm);
812 		if (!user->file && pages->source_mm != current->mm)
813 			mmput(pages->source_mm);
814 		user->locked = -1;
815 	}
816 
817 	kfree(user->upages);
818 	user->upages = NULL;
819 	kfree(user->ufolios);
820 	user->ufolios = NULL;
821 }
822 
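/*
 * Pin up to npages of the memfd starting at byte offset start, filling
 * user->upages when it is provided. Returns the number of pages pinned, or
 * <= 0 on failure.
 */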
823 static long pin_memfd_pages(struct pfn_reader_user *user, unsigned long start,
824 			    unsigned long npages)
825 {
826 	unsigned long i;
827 	unsigned long offset;
828 	unsigned long npages_out = 0;
829 	struct page **upages = user->upages;
830 	unsigned long end = start + (npages << PAGE_SHIFT) - 1;
831 	long nfolios = user->ufolios_len / sizeof(*user->ufolios);
832 
833 	/*
834 	 * todo: memfd_pin_folios should return the last pinned offset so
835 	 * we can compute npages pinned, and avoid looping over folios here
836 	 * if upages == NULL.
837 	 */
838 	nfolios = memfd_pin_folios(user->file, start, end, user->ufolios,
839 				   nfolios, &offset);
840 	if (nfolios <= 0)
841 		return nfolios;
842 
843 	offset >>= PAGE_SHIFT;
844 	user->ufolios_next = user->ufolios;
845 	user->ufolios_offset = offset;
846 
847 	for (i = 0; i < nfolios; i++) {
848 		struct folio *folio = user->ufolios[i];
849 		unsigned long nr = folio_nr_pages(folio);
850 		unsigned long npin = min(nr - offset, npages);
851 
852 		npages -= npin;
853 		npages_out += npin;
854 
855 		if (upages) {
856 			if (npin == 1) {
857 				*upages++ = folio_page(folio, offset);
858 			} else {
859 				int rc = folio_add_pins(folio, npin - 1);
860 
861 				if (rc)
862 					return rc;
863 
864 				while (npin--)
865 					*upages++ = folio_page(folio, offset++);
866 			}
867 		}
868 
869 		offset = 0;
870 	}
871 
872 	return npages_out;
873 }
874 
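/*
 * Pin as many pages as possible from [start_index, last_index] out of the
 * memfd or the source mm and record the pinned range in upages_start/upages_end.
 */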
875 static int pfn_reader_user_pin(struct pfn_reader_user *user,
876 			       struct iopt_pages *pages,
877 			       unsigned long start_index,
878 			       unsigned long last_index)
879 {
880 	bool remote_mm = pages->source_mm != current->mm;
881 	unsigned long npages = last_index - start_index + 1;
882 	unsigned long start;
883 	unsigned long unum;
884 	uintptr_t uptr;
885 	long rc;
886 
887 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
888 	    WARN_ON(last_index < start_index))
889 		return -EINVAL;
890 
891 	if (!user->file && !user->upages) {
892 		/* All undone in pfn_reader_destroy() */
893 		user->upages_len = npages * sizeof(*user->upages);
894 		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
895 		if (!user->upages)
896 			return -ENOMEM;
897 	}
898 
899 	if (user->file && !user->ufolios) {
900 		user->ufolios_len = npages * sizeof(*user->ufolios);
901 		user->ufolios = temp_kmalloc(&user->ufolios_len, NULL, 0);
902 		if (!user->ufolios)
903 			return -ENOMEM;
904 	}
905 
906 	if (user->locked == -1) {
907 		/*
908 		 * The majority of usages will run the map task within the mm
909 		 * providing the pages, so we can optimize into
910 		 * get_user_pages_fast()
911 		 */
912 		if (!user->file && remote_mm) {
913 			if (!mmget_not_zero(pages->source_mm))
914 				return -EFAULT;
915 		}
916 		user->locked = 0;
917 	}
918 
919 	unum = user->file ? user->ufolios_len / sizeof(*user->ufolios) :
920 			    user->upages_len / sizeof(*user->upages);
921 	npages = min_t(unsigned long, npages, unum);
922 
923 	if (iommufd_should_fail())
924 		return -EFAULT;
925 
926 	if (user->file) {
927 		start = pages->start + (start_index * PAGE_SIZE);
928 		rc = pin_memfd_pages(user, start, npages);
929 	} else if (!remote_mm) {
930 		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
931 		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
932 					 user->upages);
933 	} else {
934 		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
935 		if (!user->locked) {
936 			mmap_read_lock(pages->source_mm);
937 			user->locked = 1;
938 		}
939 		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
940 					   user->gup_flags, user->upages,
941 					   &user->locked);
942 	}
943 	if (rc <= 0) {
944 		if (WARN_ON(!rc))
945 			return -EFAULT;
946 		return rc;
947 	}
948 	iopt_pages_add_npinned(pages, rc);
949 	user->upages_start = start_index;
950 	user->upages_end = start_index + rc;
951 	return 0;
952 }
953 
954 /* This is the "modern" and faster accounting method used by io_uring */
955 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
956 {
957 	unsigned long lock_limit;
958 	unsigned long cur_pages;
959 	unsigned long new_pages;
960 
961 	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
962 		     PAGE_SHIFT;
963 
964 	cur_pages = atomic_long_read(&pages->source_user->locked_vm);
965 	do {
966 		new_pages = cur_pages + npages;
967 		if (new_pages > lock_limit)
968 			return -ENOMEM;
969 	} while (!atomic_long_try_cmpxchg(&pages->source_user->locked_vm,
970 					  &cur_pages, new_pages));
971 	return 0;
972 }
973 
974 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
975 {
976 	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
977 		return;
978 	atomic_long_sub(npages, &pages->source_user->locked_vm);
979 }
980 
981 /* This is the accounting method used for compatibility with VFIO */
982 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
983 			       bool inc, struct pfn_reader_user *user)
984 {
985 	bool do_put = false;
986 	int rc;
987 
988 	if (user && user->locked) {
989 		mmap_read_unlock(pages->source_mm);
990 		user->locked = 0;
991 		/* If we had the lock then we also have a get */
992 
993 	} else if ((!user || (!user->upages && !user->ufolios)) &&
994 		   pages->source_mm != current->mm) {
995 		if (!mmget_not_zero(pages->source_mm))
996 			return -EINVAL;
997 		do_put = true;
998 	}
999 
1000 	mmap_write_lock(pages->source_mm);
1001 	rc = __account_locked_vm(pages->source_mm, npages, inc,
1002 				 pages->source_task, false);
1003 	mmap_write_unlock(pages->source_mm);
1004 
1005 	if (do_put)
1006 		mmput(pages->source_mm);
1007 	return rc;
1008 }
1009 
1010 int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
1011 			     bool inc, struct pfn_reader_user *user)
1012 {
1013 	int rc = 0;
1014 
1015 	switch (pages->account_mode) {
1016 	case IOPT_PAGES_ACCOUNT_NONE:
1017 		break;
1018 	case IOPT_PAGES_ACCOUNT_USER:
1019 		if (inc)
1020 			rc = incr_user_locked_vm(pages, npages);
1021 		else
1022 			decr_user_locked_vm(pages, npages);
1023 		break;
1024 	case IOPT_PAGES_ACCOUNT_MM:
1025 		rc = update_mm_locked_vm(pages, npages, inc, user);
1026 		break;
1027 	}
1028 	if (rc)
1029 		return rc;
1030 
1031 	pages->last_npinned = pages->npinned;
1032 	if (inc)
1033 		atomic64_add(npages, &pages->source_mm->pinned_vm);
1034 	else
1035 		atomic64_sub(npages, &pages->source_mm->pinned_vm);
1036 	return 0;
1037 }
1038 
1039 static void update_unpinned(struct iopt_pages *pages)
1040 {
1041 	if (WARN_ON(pages->npinned > pages->last_npinned))
1042 		return;
1043 	if (pages->npinned == pages->last_npinned)
1044 		return;
1045 	iopt_pages_update_pinned(pages, pages->last_npinned - pages->npinned,
1046 				 false, NULL);
1047 }
1048 
1049 /*
1050  * Changes in the number of pages pinned are done after the pages have been read
1051  * and processed. If the user is over the limit then the error unwind will unpin
1052  * everything that was just pinned. This is because it is expensive to calculate
1053  * how many pages we have already pinned within a range to generate an accurate
1054  * prediction in advance of doing the work to actually pin them.
1055  */
1056 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
1057 					 struct iopt_pages *pages)
1058 {
1059 	unsigned long npages;
1060 	bool inc;
1061 
1062 	lockdep_assert_held(&pages->mutex);
1063 
1064 	if (pages->npinned == pages->last_npinned)
1065 		return 0;
1066 
1067 	if (pages->npinned < pages->last_npinned) {
1068 		npages = pages->last_npinned - pages->npinned;
1069 		inc = false;
1070 	} else {
1071 		if (iommufd_should_fail())
1072 			return -ENOMEM;
1073 		npages = pages->npinned - pages->last_npinned;
1074 		inc = true;
1075 	}
1076 	return iopt_pages_update_pinned(pages, npages, inc, user);
1077 }
1078 
1079 struct pfn_reader_dmabuf {
1080 	struct dma_buf_phys_vec phys;
1081 	unsigned long start_offset;
1082 };
1083 
1084 static int pfn_reader_dmabuf_init(struct pfn_reader_dmabuf *dmabuf,
1085 				  struct iopt_pages *pages)
1086 {
1087 	/* Callers must not get here if the dmabuf was already revoked */
1088 	if (WARN_ON(iopt_dmabuf_revoked(pages)))
1089 		return -EINVAL;
1090 
1091 	dmabuf->phys = pages->dmabuf.phys;
1092 	dmabuf->start_offset = pages->dmabuf.start;
1093 	return 0;
1094 }
1095 
1096 static int pfn_reader_fill_dmabuf(struct pfn_reader_dmabuf *dmabuf,
1097 				  struct pfn_batch *batch,
1098 				  unsigned long start_index,
1099 				  unsigned long last_index)
1100 {
1101 	unsigned long start = dmabuf->start_offset + start_index * PAGE_SIZE;
1102 
1103 	/*
1104 	 * start_index/last_index and start are all PAGE_SIZE aligned; the batch is
1105 	 * always filled using page size aligned PFNs just like the other types.
1106 	 * If the dmabuf has been sliced on a sub page offset then the common
1107 	 * batch to domain code will adjust it before mapping to the domain.
1108 	 */
1109 	batch_add_pfn_num(batch, PHYS_PFN(dmabuf->phys.paddr + start),
1110 			  last_index - start_index + 1, BATCH_MMIO);
1111 	return 0;
1112 }
1113 
1114 /*
1115  * PFNs are stored in three places, in order of preference:
1116  * - The iopt_pages xarray. This is only populated if there is a
1117  *   iopt_pages_access
1118  * - The iommu_domain under an area
1119  * - The original PFN source, ie pages->source_mm
1120  *
1121  * This iterator reads the pfns optimizing to load according to the
1122  * above order.
1123  */
1124 struct pfn_reader {
1125 	struct iopt_pages *pages;
1126 	struct interval_tree_double_span_iter span;
1127 	struct pfn_batch batch;
1128 	unsigned long batch_start_index;
1129 	unsigned long batch_end_index;
1130 	unsigned long last_index;
1131 
1132 	union {
1133 		struct pfn_reader_user user;
1134 		struct pfn_reader_dmabuf dmabuf;
1135 	};
1136 };
1137 
1138 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
1139 {
1140 	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
1141 }
1142 
1143 /*
1144  * The batch can contain a mixture of pages that are still in use and pages that
1145  * need to be unpinned. Unpin only pages that are not held anywhere else.
1146  */
1147 static void pfn_reader_unpin(struct pfn_reader *pfns)
1148 {
1149 	unsigned long last = pfns->batch_end_index - 1;
1150 	unsigned long start = pfns->batch_start_index;
1151 	struct interval_tree_double_span_iter span;
1152 	struct iopt_pages *pages = pfns->pages;
1153 
1154 	lockdep_assert_held(&pages->mutex);
1155 
1156 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1157 					   &pages->domains_itree, start, last) {
1158 		if (span.is_used)
1159 			continue;
1160 
1161 		batch_unpin(&pfns->batch, pages, span.start_hole - start,
1162 			    span.last_hole - span.start_hole + 1);
1163 	}
1164 }
1165 
1166 /* Process a single span to load it from the proper storage */
1167 static int pfn_reader_fill_span(struct pfn_reader *pfns)
1168 {
1169 	struct interval_tree_double_span_iter *span = &pfns->span;
1170 	unsigned long start_index = pfns->batch_end_index;
1171 	struct pfn_reader_user *user;
1172 	unsigned long npages;
1173 	struct iopt_area *area;
1174 	int rc;
1175 
1176 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1177 	    WARN_ON(span->last_used < start_index))
1178 		return -EINVAL;
1179 
1180 	if (span->is_used == 1) {
1181 		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
1182 				  start_index, span->last_used);
1183 		return 0;
1184 	}
1185 
1186 	if (span->is_used == 2) {
1187 		/*
1188 		 * Pull as many pages from the first domain we find in the
1189 		 * target span. If it is too small then we will be called again
1190 		 * and we'll find another area.
1191 		 */
1192 		area = iopt_pages_find_domain_area(pfns->pages, start_index);
1193 		if (WARN_ON(!area))
1194 			return -EINVAL;
1195 
1196 		/* The storage_domain cannot change without the pages mutex */
1197 		batch_from_domain(
1198 			&pfns->batch, area->storage_domain, area, start_index,
1199 			min(iopt_area_last_index(area), span->last_used));
1200 		return 0;
1201 	}
1202 
1203 	if (iopt_is_dmabuf(pfns->pages))
1204 		return pfn_reader_fill_dmabuf(&pfns->dmabuf, &pfns->batch,
1205 					      start_index, span->last_hole);
1206 
1207 	user = &pfns->user;
1208 	if (start_index >= user->upages_end) {
1209 		rc = pfn_reader_user_pin(user, pfns->pages, start_index,
1210 					 span->last_hole);
1211 		if (rc)
1212 			return rc;
1213 	}
1214 
1215 	npages = user->upages_end - start_index;
1216 	start_index -= user->upages_start;
1217 	rc = 0;
1218 
1219 	if (!user->file)
1220 		batch_from_pages(&pfns->batch, user->upages + start_index,
1221 				 npages);
1222 	else
1223 		rc = batch_from_folios(&pfns->batch, &user->ufolios_next,
1224 				       &user->ufolios_offset, npages);
1225 	return rc;
1226 }
1227 
1228 static bool pfn_reader_done(struct pfn_reader *pfns)
1229 {
1230 	return pfns->batch_start_index == pfns->last_index + 1;
1231 }
1232 
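/*
 * Advance the reader to the next batch, filling it from the spans until the
 * batch is full or last_index is reached.
 */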
1233 static int pfn_reader_next(struct pfn_reader *pfns)
1234 {
1235 	int rc;
1236 
1237 	batch_clear(&pfns->batch);
1238 	pfns->batch_start_index = pfns->batch_end_index;
1239 
1240 	while (pfns->batch_end_index != pfns->last_index + 1) {
1241 		unsigned int npfns = pfns->batch.total_pfns;
1242 
1243 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1244 		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1245 			return -EINVAL;
1246 
1247 		rc = pfn_reader_fill_span(pfns);
1248 		if (rc)
1249 			return rc;
1250 
1251 		if (WARN_ON(!pfns->batch.total_pfns))
1252 			return -EINVAL;
1253 
1254 		pfns->batch_end_index =
1255 			pfns->batch_start_index + pfns->batch.total_pfns;
1256 		if (pfns->batch_end_index == pfns->span.last_used + 1)
1257 			interval_tree_double_span_iter_next(&pfns->span);
1258 
1259 		/* Batch is full */
1260 		if (npfns == pfns->batch.total_pfns)
1261 			return 0;
1262 	}
1263 	return 0;
1264 }
1265 
1266 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1267 			   unsigned long start_index, unsigned long last_index)
1268 {
1269 	int rc;
1270 
1271 	lockdep_assert_held(&pages->mutex);
1272 
1273 	pfns->pages = pages;
1274 	pfns->batch_start_index = start_index;
1275 	pfns->batch_end_index = start_index;
1276 	pfns->last_index = last_index;
1277 	if (iopt_is_dmabuf(pages))
1278 		pfn_reader_dmabuf_init(&pfns->dmabuf, pages);
1279 	else
1280 		pfn_reader_user_init(&pfns->user, pages);
1281 	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1282 	if (rc)
1283 		return rc;
1284 	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1285 					     &pages->domains_itree, start_index,
1286 					     last_index);
1287 	return 0;
1288 }
1289 
1290 /*
1291  * There are many assertions regarding the state of pages->npinned vs
1292  * pages->last_npinned, for instance something like unmapping a domain must only
1293  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1294  * the pins are updated. This is fine for success flows, but error flows
1295  * sometimes need to release the pins held inside the pfn_reader before going on
1296  * to complete unmapping and releasing pins held in domains.
1297  */
1298 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1299 {
1300 	struct iopt_pages *pages = pfns->pages;
1301 	struct pfn_reader_user *user;
1302 
1303 	if (iopt_is_dmabuf(pages))
1304 		return;
1305 
1306 	user = &pfns->user;
1307 	if (user->upages_end > pfns->batch_end_index) {
1308 		/* Any pages not transferred to the batch are just unpinned */
1309 
1310 		unsigned long npages = user->upages_end - pfns->batch_end_index;
1311 		unsigned long start_index = pfns->batch_end_index -
1312 					    user->upages_start;
1313 
1314 		if (!user->file) {
1315 			unpin_user_pages(user->upages + start_index, npages);
1316 		} else {
1317 			long n = user->ufolios_len / sizeof(*user->ufolios);
1318 
1319 			unpin_folios(user->ufolios_next,
1320 				     user->ufolios + n - user->ufolios_next);
1321 		}
1322 		iopt_pages_sub_npinned(pages, npages);
1323 		user->upages_end = pfns->batch_end_index;
1324 	}
1325 	if (pfns->batch_start_index != pfns->batch_end_index) {
1326 		pfn_reader_unpin(pfns);
1327 		pfns->batch_start_index = pfns->batch_end_index;
1328 	}
1329 }
1330 
1331 static void pfn_reader_destroy(struct pfn_reader *pfns)
1332 {
1333 	struct iopt_pages *pages = pfns->pages;
1334 
1335 	pfn_reader_release_pins(pfns);
1336 	if (!iopt_is_dmabuf(pfns->pages))
1337 		pfn_reader_user_destroy(&pfns->user, pfns->pages);
1338 	batch_destroy(&pfns->batch, NULL);
1339 	WARN_ON(pages->last_npinned != pages->npinned);
1340 }
1341 
1342 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1343 			    unsigned long start_index, unsigned long last_index)
1344 {
1345 	int rc;
1346 
1347 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1348 	    WARN_ON(last_index < start_index))
1349 		return -EINVAL;
1350 
1351 	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1352 	if (rc)
1353 		return rc;
1354 	rc = pfn_reader_next(pfns);
1355 	if (rc) {
1356 		pfn_reader_destroy(pfns);
1357 		return rc;
1358 	}
1359 	return 0;
1360 }
1361 
1362 static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
1363 					   unsigned long length, bool writable)
1364 {
1365 	struct iopt_pages *pages;
1366 
1367 	/*
1368 	 * The iommu API uses size_t as the length; this check also protects the
1369 	 * DIV_ROUND_UP below from overflow
1370 	 */
1371 	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1372 		return ERR_PTR(-EINVAL);
1373 
1374 	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1375 	if (!pages)
1376 		return ERR_PTR(-ENOMEM);
1377 
1378 	kref_init(&pages->kref);
1379 	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1380 	mutex_init(&pages->mutex);
1381 	pages->source_mm = current->mm;
1382 	mmgrab(pages->source_mm);
1383 	pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
1384 	pages->access_itree = RB_ROOT_CACHED;
1385 	pages->domains_itree = RB_ROOT_CACHED;
1386 	pages->writable = writable;
1387 	if (capable(CAP_IPC_LOCK))
1388 		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1389 	else
1390 		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1391 	pages->source_task = current->group_leader;
1392 	get_task_struct(current->group_leader);
1393 	pages->source_user = get_uid(current_user());
1394 	return pages;
1395 }
1396 
1397 struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
1398 					 unsigned long length, bool writable)
1399 {
1400 	struct iopt_pages *pages;
1401 	unsigned long end;
1402 	void __user *uptr_down =
1403 		(void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1404 
1405 	if (check_add_overflow((unsigned long)uptr, length, &end))
1406 		return ERR_PTR(-EOVERFLOW);
1407 
1408 	pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
1409 	if (IS_ERR(pages))
1410 		return pages;
1411 	pages->uptr = uptr_down;
1412 	pages->type = IOPT_ADDRESS_USER;
1413 	return pages;
1414 }
1415 
1416 struct iopt_pages *iopt_alloc_file_pages(struct file *file,
1417 					 unsigned long start_byte,
1418 					 unsigned long start,
1419 					 unsigned long length, bool writable)
1420 
1421 {
1422 	struct iopt_pages *pages;
1423 
1424 	pages = iopt_alloc_pages(start_byte, length, writable);
1425 	if (IS_ERR(pages))
1426 		return pages;
1427 	pages->file = get_file(file);
1428 	pages->start = start - start_byte;
1429 	pages->type = IOPT_ADDRESS_FILE;
1430 	return pages;
1431 }
1432 
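/*
 * dmabuf move_notify callback. Unmap the dmabuf from every tracked
 * area/domain and mark the physical range as revoked.
 */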
1433 static void iopt_revoke_notify(struct dma_buf_attachment *attach)
1434 {
1435 	struct iopt_pages *pages = attach->importer_priv;
1436 	struct iopt_pages_dmabuf_track *track;
1437 
1438 	guard(mutex)(&pages->mutex);
1439 	if (iopt_dmabuf_revoked(pages))
1440 		return;
1441 
1442 	list_for_each_entry(track, &pages->dmabuf.tracker, elm) {
1443 		struct iopt_area *area = track->area;
1444 
1445 		iopt_area_unmap_domain_range(area, track->domain,
1446 					     iopt_area_index(area),
1447 					     iopt_area_last_index(area));
1448 	}
1449 	pages->dmabuf.phys.len = 0;
1450 }
1451 
1452 static struct dma_buf_attach_ops iopt_dmabuf_attach_revoke_ops = {
1453 	.allow_peer2peer = true,
1454 	.move_notify = iopt_revoke_notify,
1455 };
1456 
1457 /*
1458  * iommufd and vfio have a circular dependency. Future work for a phys
1459  * based private interconnect will remove this.
1460  */
1461 static int
1462 sym_vfio_pci_dma_buf_iommufd_map(struct dma_buf_attachment *attachment,
1463 				 struct dma_buf_phys_vec *phys)
1464 {
1465 	typeof(&vfio_pci_dma_buf_iommufd_map) fn;
1466 	int rc;
1467 
1468 	rc = iommufd_test_dma_buf_iommufd_map(attachment, phys);
1469 	if (rc != -EOPNOTSUPP)
1470 		return rc;
1471 
1472 	if (!IS_ENABLED(CONFIG_VFIO_PCI_DMABUF))
1473 		return -EOPNOTSUPP;
1474 
1475 	fn = symbol_get(vfio_pci_dma_buf_iommufd_map);
1476 	if (!fn)
1477 		return -EOPNOTSUPP;
1478 	rc = fn(attachment, phys);
1479 	symbol_put(vfio_pci_dma_buf_iommufd_map);
1480 	return rc;
1481 }
1482 
1483 static int iopt_map_dmabuf(struct iommufd_ctx *ictx, struct iopt_pages *pages,
1484 			   struct dma_buf *dmabuf)
1485 {
1486 	struct dma_buf_attachment *attach;
1487 	int rc;
1488 
1489 	attach = dma_buf_dynamic_attach(dmabuf, iommufd_global_device(),
1490 					&iopt_dmabuf_attach_revoke_ops, pages);
1491 	if (IS_ERR(attach))
1492 		return PTR_ERR(attach);
1493 
1494 	dma_resv_lock(dmabuf->resv, NULL);
1495 	/*
1496 	 * Lock ordering requires the mutex to be taken inside the reservation
1497 	 * lock, so make sure lockdep sees this.
1498 	 */
1499 	if (IS_ENABLED(CONFIG_LOCKDEP)) {
1500 		mutex_lock(&pages->mutex);
1501 		mutex_unlock(&pages->mutex);
1502 	}
1503 
1504 	rc = sym_vfio_pci_dma_buf_iommufd_map(attach, &pages->dmabuf.phys);
1505 	if (rc)
1506 		goto err_detach;
1507 
1508 	dma_resv_unlock(dmabuf->resv);
1509 
1510 	/* On success iopt_release_pages() will detach and put the dmabuf. */
1511 	pages->dmabuf.attach = attach;
1512 	return 0;
1513 
1514 err_detach:
1515 	dma_resv_unlock(dmabuf->resv);
1516 	dma_buf_detach(dmabuf, attach);
1517 	return rc;
1518 }
1519 
1520 struct iopt_pages *iopt_alloc_dmabuf_pages(struct iommufd_ctx *ictx,
1521 					   struct dma_buf *dmabuf,
1522 					   unsigned long start_byte,
1523 					   unsigned long start,
1524 					   unsigned long length, bool writable)
1525 {
1526 	static struct lock_class_key pages_dmabuf_mutex_key;
1527 	struct iopt_pages *pages;
1528 	int rc;
1529 
1530 	if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
1531 		return ERR_PTR(-EOPNOTSUPP);
1532 
1533 	if (dmabuf->size <= (start + length - 1) ||
1534 	    length / PAGE_SIZE >= MAX_NPFNS)
1535 		return ERR_PTR(-EINVAL);
1536 
1537 	pages = iopt_alloc_pages(start_byte, length, writable);
1538 	if (IS_ERR(pages))
1539 		return pages;
1540 
1541 	/*
1542 	 * The mmap_lock can be held when obtaining the dmabuf reservation lock
1543 	 * which creates a locking cycle with the pages mutex which is held
1544 	 * while obtaining the mmap_lock. This locking path is not present for
1545 	 * IOPT_ADDRESS_DMABUF so split the lock class.
1546 	 */
1547 	lockdep_set_class(&pages->mutex, &pages_dmabuf_mutex_key);
1548 
1549 	/* dmabuf does not use pinned page accounting. */
1550 	pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1551 	pages->type = IOPT_ADDRESS_DMABUF;
1552 	pages->dmabuf.start = start - start_byte;
1553 	INIT_LIST_HEAD(&pages->dmabuf.tracker);
1554 
1555 	rc = iopt_map_dmabuf(ictx, pages, dmabuf);
1556 	if (rc) {
1557 		iopt_put_pages(pages);
1558 		return ERR_PTR(rc);
1559 	}
1560 
1561 	return pages;
1562 }
1563 
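/*
 * Record that area has mapped the dmabuf into domain so that a later
 * move_notify can unmap it.
 */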
1564 int iopt_dmabuf_track_domain(struct iopt_pages *pages, struct iopt_area *area,
1565 			     struct iommu_domain *domain)
1566 {
1567 	struct iopt_pages_dmabuf_track *track;
1568 
1569 	lockdep_assert_held(&pages->mutex);
1570 	if (WARN_ON(!iopt_is_dmabuf(pages)))
1571 		return -EINVAL;
1572 
1573 	list_for_each_entry(track, &pages->dmabuf.tracker, elm)
1574 		if (WARN_ON(track->domain == domain && track->area == area))
1575 			return -EINVAL;
1576 
1577 	track = kzalloc(sizeof(*track), GFP_KERNEL);
1578 	if (!track)
1579 		return -ENOMEM;
1580 	track->domain = domain;
1581 	track->area = area;
1582 	list_add_tail(&track->elm, &pages->dmabuf.tracker);
1583 
1584 	return 0;
1585 }
1586 
1587 void iopt_dmabuf_untrack_domain(struct iopt_pages *pages,
1588 				struct iopt_area *area,
1589 				struct iommu_domain *domain)
1590 {
1591 	struct iopt_pages_dmabuf_track *track;
1592 
1593 	lockdep_assert_held(&pages->mutex);
1594 	WARN_ON(!iopt_is_dmabuf(pages));
1595 
1596 	list_for_each_entry(track, &pages->dmabuf.tracker, elm) {
1597 		if (track->domain == domain && track->area == area) {
1598 			list_del(&track->elm);
1599 			kfree(track);
1600 			return;
1601 		}
1602 	}
1603 	WARN_ON(true);
1604 }
1605 
1606 int iopt_dmabuf_track_all_domains(struct iopt_area *area,
1607 				  struct iopt_pages *pages)
1608 {
1609 	struct iopt_pages_dmabuf_track *track;
1610 	struct iommu_domain *domain;
1611 	unsigned long index;
1612 	int rc;
1613 
1614 	list_for_each_entry(track, &pages->dmabuf.tracker, elm)
1615 		if (WARN_ON(track->area == area))
1616 			return -EINVAL;
1617 
1618 	xa_for_each(&area->iopt->domains, index, domain) {
1619 		rc = iopt_dmabuf_track_domain(pages, area, domain);
1620 		if (rc)
1621 			goto err_untrack;
1622 	}
1623 	return 0;
1624 err_untrack:
1625 	iopt_dmabuf_untrack_all_domains(area, pages);
1626 	return rc;
1627 }
1628 
1629 void iopt_dmabuf_untrack_all_domains(struct iopt_area *area,
1630 				     struct iopt_pages *pages)
1631 {
1632 	struct iopt_pages_dmabuf_track *track;
1633 	struct iopt_pages_dmabuf_track *tmp;
1634 
1635 	list_for_each_entry_safe(track, tmp, &pages->dmabuf.tracker,
1636 				 elm) {
1637 		if (track->area == area) {
1638 			list_del(&track->elm);
1639 			kfree(track);
1640 		}
1641 	}
1642 }
1643 
1644 void iopt_release_pages(struct kref *kref)
1645 {
1646 	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1647 
1648 	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1649 	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1650 	WARN_ON(pages->npinned);
1651 	WARN_ON(!xa_empty(&pages->pinned_pfns));
1652 	mmdrop(pages->source_mm);
1653 	mutex_destroy(&pages->mutex);
1654 	put_task_struct(pages->source_task);
1655 	free_uid(pages->source_user);
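	/* Drop the dmabuf attachment and reference, or the backing file, if any. */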
1656 	if (iopt_is_dmabuf(pages) && pages->dmabuf.attach) {
1657 		struct dma_buf *dmabuf = pages->dmabuf.attach->dmabuf;
1658 
1659 		dma_buf_detach(dmabuf, pages->dmabuf.attach);
1660 		dma_buf_put(dmabuf);
1661 		WARN_ON(!list_empty(&pages->dmabuf.tracker));
1662 	} else if (pages->type == IOPT_ADDRESS_FILE) {
1663 		fput(pages->file);
1664 	}
1665 	kfree(pages);
1666 }
1667 
1668 static void
1669 iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1670 		       struct iopt_pages *pages, struct iommu_domain *domain,
1671 		       unsigned long start_index, unsigned long last_index,
1672 		       unsigned long *unmapped_end_index,
1673 		       unsigned long real_last_index)
1674 {
1675 	while (start_index <= last_index) {
1676 		unsigned long batch_last_index;
1677 
1678 		if (*unmapped_end_index <= last_index) {
1679 			unsigned long start =
1680 				max(start_index, *unmapped_end_index);
1681 
1682 			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1683 			    batch->total_pfns)
1684 				WARN_ON(*unmapped_end_index -
1685 						batch->total_pfns !=
1686 					start_index);
1687 			batch_from_domain(batch, domain, area, start,
1688 					  last_index);
1689 			batch_last_index = start_index + batch->total_pfns - 1;
1690 		} else {
1691 			batch_last_index = last_index;
1692 		}
1693 
1694 		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1695 			WARN_ON(batch_last_index > real_last_index);
1696 
1697 		/*
1698 		 * unmaps must always 'cut' at a place where the pfns are not
1699 		 * contiguous to pair with the maps that always install
1700 		 * contiguous pages. Thus, if we have to stop unpinning in the
1701 		 * middle of the domains we need to keep reading pfns until we
1702 		 * find a cut point to do the unmap. The pfns we read are
1703 		 * carried over and either skipped or integrated into the next
1704 		 * batch.
1705 		 */
1706 		if (batch_last_index == last_index &&
1707 		    last_index != real_last_index)
1708 			batch_from_domain_continue(batch, domain, area,
1709 						   last_index + 1,
1710 						   real_last_index);
1711 
1712 		if (*unmapped_end_index <= batch_last_index) {
1713 			iopt_area_unmap_domain_range(
1714 				area, domain, *unmapped_end_index,
1715 				start_index + batch->total_pfns - 1);
1716 			*unmapped_end_index = start_index + batch->total_pfns;
1717 		}
1718 
1719 		/* unpin must follow unmap */
1720 		batch_unpin(batch, pages, 0,
1721 			    batch_last_index - start_index + 1);
1722 		start_index = batch_last_index + 1;
1723 
1724 		batch_clear_carry(batch,
1725 				  *unmapped_end_index - batch_last_index - 1);
1726 	}
1727 }
1728 
1729 static void __iopt_area_unfill_domain(struct iopt_area *area,
1730 				      struct iopt_pages *pages,
1731 				      struct iommu_domain *domain,
1732 				      unsigned long last_index)
1733 {
1734 	struct interval_tree_double_span_iter span;
1735 	unsigned long start_index = iopt_area_index(area);
1736 	unsigned long unmapped_end_index = start_index;
1737 	u64 backup[BATCH_BACKUP_SIZE];
1738 	struct pfn_batch batch;
1739 
1740 	lockdep_assert_held(&pages->mutex);
1741 
1742 	if (iopt_is_dmabuf(pages)) {
1743 		if (WARN_ON(iopt_dmabuf_revoked(pages)))
1744 			return;
1745 		iopt_area_unmap_domain_range(area, domain, start_index,
1746 					     last_index);
1747 		return;
1748 	}
1749 
1750 	/*
1751 	 * For security we must not unpin something that is still DMA mapped,
1752 	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1753 	 * This creates a complexity where we need to skip over unpinning pages
1754 	 * held in the xarray, but continue to unmap from the domain.
1755 	 *
1756 	 * The domain unmap cannot stop in the middle of a contiguous range of
1757 	 * PFNs. To solve this problem the unpinning step will read ahead to the
1758 	 * end of any contiguous span, unmap that whole span, and then only
1759 	 * unpin the leading part that does not have any accesses. The residual
1760 	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1761 	 * batch as they are moved to the front of the PFN list and continue on
1762 	 * to the next iteration(s).
1763 	 */
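	/*
	 * For example (a sketch with illustrative indexes): if this area
	 * covers indexes 0..9 and an in-kernel access still pins 6..9, the
	 * hole 0..5 is unpinned here, but when PFNs 0..9 are physically
	 * contiguous the unmap must still cover the whole run 0..9. PFNs 6..9
	 * then become the carry and are skipped by batch_skip_carry() when
	 * the used span is reached.
	 */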
1764 	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1765 	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1766 					   &pages->access_itree, start_index,
1767 					   last_index) {
1768 		if (span.is_used) {
1769 			batch_skip_carry(&batch,
1770 					 span.last_used - span.start_used + 1);
1771 			continue;
1772 		}
1773 		iopt_area_unpin_domain(&batch, area, pages, domain,
1774 				       span.start_hole, span.last_hole,
1775 				       &unmapped_end_index, last_index);
1776 	}
1777 	/*
1778 	 * If the range ends in an access then we do the residual unmap without
1779 	 * any unpins.
1780 	 */
1781 	if (unmapped_end_index != last_index + 1)
1782 		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1783 					     last_index);
1784 	WARN_ON(batch.total_pfns);
1785 	batch_destroy(&batch, backup);
1786 	update_unpinned(pages);
1787 }
1788 
1789 static void iopt_area_unfill_partial_domain(struct iopt_area *area,
1790 					    struct iopt_pages *pages,
1791 					    struct iommu_domain *domain,
1792 					    unsigned long end_index)
1793 {
1794 	if (end_index != iopt_area_index(area))
1795 		__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
1796 }
1797 
1798 /**
1799  * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1800  * @area: The IOVA range to unmap
1801  * @domain: The domain to unmap
1802  *
1803  * The caller must know that unpinning is not required, usually because there
1804  * are other domains in the iopt.
1805  */
1806 void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1807 {
1808 	iommu_unmap_nofail(domain, iopt_area_iova(area),
1809 			   iopt_area_length(area));
1810 }
1811 
1812 /**
1813  * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
1814  * @area: IOVA area to use
1815  * @pages: page supplier for the area (area->pages is NULL)
1816  * @domain: Domain to unmap from
1817  *
1818  * The domain should be removed from the domains_itree before calling. The
1819  * domain will always be unmapped, but the PFNs may not be unpinned if there are
1820  * still accesses.
1821  */
1822 void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
1823 			     struct iommu_domain *domain)
1824 {
1825 	if (iopt_dmabuf_revoked(pages))
1826 		return;
1827 
1828 	__iopt_area_unfill_domain(area, pages, domain,
1829 				  iopt_area_last_index(area));
1830 }
1831 
1832 /**
1833  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1834  * @area: IOVA area to use
1835  * @domain: Domain to load PFNs into
1836  *
1837  * Read the pfns from the area's underlying iopt_pages and map them into the
1838  * given domain. Called when attaching a new domain to an io_pagetable.
1839  */
1840 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1841 {
1842 	unsigned long done_end_index;
1843 	struct pfn_reader pfns;
1844 	int rc;
1845 
1846 	lockdep_assert_held(&area->pages->mutex);
1847 
1848 	if (iopt_dmabuf_revoked(area->pages))
1849 		return 0;
1850 
1851 	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1852 			      iopt_area_last_index(area));
1853 	if (rc)
1854 		return rc;
1855 
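	/*
	 * done_end_index records how far the domain has actually been
	 * populated so the error path only unfills what was mapped.
	 */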
1856 	while (!pfn_reader_done(&pfns)) {
1857 		done_end_index = pfns.batch_start_index;
1858 		rc = batch_to_domain(&pfns.batch, domain, area,
1859 				     pfns.batch_start_index);
1860 		if (rc)
1861 			goto out_unmap;
1862 		done_end_index = pfns.batch_end_index;
1863 
1864 		rc = pfn_reader_next(&pfns);
1865 		if (rc)
1866 			goto out_unmap;
1867 	}
1868 
1869 	rc = pfn_reader_update_pinned(&pfns);
1870 	if (rc)
1871 		goto out_unmap;
1872 	goto out_destroy;
1873 
1874 out_unmap:
1875 	pfn_reader_release_pins(&pfns);
1876 	iopt_area_unfill_partial_domain(area, area->pages, domain,
1877 					done_end_index);
1878 out_destroy:
1879 	pfn_reader_destroy(&pfns);
1880 	return rc;
1881 }
1882 
1883 /**
1884  * iopt_area_fill_domains() - Install PFNs into the area's domains
1885  * @area: The area to act on
1886  * @pages: The pages associated with the area (area->pages is NULL)
1887  *
1888  * Called during area creation. The area is freshly created and not inserted in
1889  * the domains_itree yet. PFNs are read and loaded into every domain held in the
1890  * area's io_pagetable and the area is installed in the domains_itree.
1891  *
1892  * On failure all domains are left unchanged.
1893  */
1894 int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1895 {
1896 	unsigned long done_first_end_index;
1897 	unsigned long done_all_end_index;
1898 	struct iommu_domain *domain;
1899 	unsigned long unmap_index;
1900 	struct pfn_reader pfns;
1901 	unsigned long index;
1902 	int rc;
1903 
1904 	lockdep_assert_held(&area->iopt->domains_rwsem);
1905 
1906 	if (xa_empty(&area->iopt->domains))
1907 		return 0;
1908 
1909 	mutex_lock(&pages->mutex);
1910 	if (iopt_is_dmabuf(pages)) {
1911 		rc = iopt_dmabuf_track_all_domains(area, pages);
1912 		if (rc)
1913 			goto out_unlock;
1914 	}
1915 
1916 	if (!iopt_dmabuf_revoked(pages)) {
1917 		rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1918 				      iopt_area_last_index(area));
1919 		if (rc)
1920 			goto out_untrack;
1921 
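		/*
		 * Two cursors track progress for the error path:
		 * done_first_end_index is how far the domains that already
		 * received the current batch are mapped, done_all_end_index
		 * is how far every domain is mapped. The unwind below picks
		 * the appropriate one per domain.
		 */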
1922 		while (!pfn_reader_done(&pfns)) {
1923 			done_first_end_index = pfns.batch_end_index;
1924 			done_all_end_index = pfns.batch_start_index;
1925 			xa_for_each(&area->iopt->domains, index, domain) {
1926 				rc = batch_to_domain(&pfns.batch, domain, area,
1927 						     pfns.batch_start_index);
1928 				if (rc)
1929 					goto out_unmap;
1930 			}
1931 			done_all_end_index = done_first_end_index;
1932 
1933 			rc = pfn_reader_next(&pfns);
1934 			if (rc)
1935 				goto out_unmap;
1936 		}
1937 		rc = pfn_reader_update_pinned(&pfns);
1938 		if (rc)
1939 			goto out_unmap;
1940 
1941 		pfn_reader_destroy(&pfns);
1942 	}
1943 
1944 	area->storage_domain = xa_load(&area->iopt->domains, 0);
1945 	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1946 	mutex_unlock(&pages->mutex);
1947 	return 0;
1948 
1949 out_unmap:
1950 	pfn_reader_release_pins(&pfns);
1951 	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1952 		unsigned long end_index;
1953 
1954 		if (unmap_index < index)
1955 			end_index = done_first_end_index;
1956 		else
1957 			end_index = done_all_end_index;
1958 
1959 		/*
1960 		 * The area is not yet part of the domains_itree so we have to
1961 		 * manage the unpinning specially. The last domain does the
1962 		 * unpin, every other domain is just unmapped.
1963 		 */
1964 		if (unmap_index != area->iopt->next_domain_id - 1) {
1965 			if (end_index != iopt_area_index(area))
1966 				iopt_area_unmap_domain_range(
1967 					area, domain, iopt_area_index(area),
1968 					end_index - 1);
1969 		} else {
1970 			iopt_area_unfill_partial_domain(area, pages, domain,
1971 							end_index);
1972 		}
1973 	}
1974 	pfn_reader_destroy(&pfns);
1975 out_untrack:
1976 	if (iopt_is_dmabuf(pages))
1977 		iopt_dmabuf_untrack_all_domains(area, pages);
1978 out_unlock:
1979 	mutex_unlock(&pages->mutex);
1980 	return rc;
1981 }
1982 
1983 /**
1984  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1985  * @area: The area to act on
1986  * @pages: The pages associated with the area (area->pages is NULL)
1987  *
1988  * Called during area destruction. This unmaps the iova's covered by all the
1989  * Called during area destruction. This unmaps the IOVAs covered by all the
1990  */
1991 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1992 {
1993 	struct io_pagetable *iopt = area->iopt;
1994 	struct iommu_domain *domain;
1995 	unsigned long index;
1996 
1997 	lockdep_assert_held(&iopt->domains_rwsem);
1998 
1999 	mutex_lock(&pages->mutex);
2000 	if (!area->storage_domain)
2001 		goto out_unlock;
2002 
2003 	xa_for_each(&iopt->domains, index, domain) {
2004 		if (domain == area->storage_domain)
2005 			continue;
2006 
2007 		if (!iopt_dmabuf_revoked(pages))
2008 			iopt_area_unmap_domain_range(
2009 				area, domain, iopt_area_index(area),
2010 				iopt_area_last_index(area));
2011 	}
2012 
2013 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
2014 		WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
2015 	interval_tree_remove(&area->pages_node, &pages->domains_itree);
2016 	iopt_area_unfill_domain(area, pages, area->storage_domain);
2017 	if (iopt_is_dmabuf(pages))
2018 		iopt_dmabuf_untrack_all_domains(area, pages);
2019 	area->storage_domain = NULL;
2020 out_unlock:
2021 	mutex_unlock(&pages->mutex);
2022 }
2023 
2024 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
2025 				    struct iopt_pages *pages,
2026 				    unsigned long start_index,
2027 				    unsigned long end_index)
2028 {
2029 	while (start_index <= end_index) {
2030 		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
2031 					end_index);
2032 		batch_unpin(batch, pages, 0, batch->total_pfns);
2033 		start_index += batch->total_pfns;
2034 		batch_clear(batch);
2035 	}
2036 }
2037 
2038 /**
2039  * iopt_pages_unfill_xarray() - Update the xarray after removing an access
2040  * @pages: The pages to act on
2041  * @start_index: Starting PFN index
2042  * @last_index: Last PFN index
2043  *
2044  * Called when an iopt_pages_access is removed; unpins and removes pages from the xarray as needed.
2045  * The access should already be removed from the access_itree.
2046  */
2047 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
2048 			      unsigned long start_index,
2049 			      unsigned long last_index)
2050 {
2051 	struct interval_tree_double_span_iter span;
2052 	u64 backup[BATCH_BACKUP_SIZE];
2053 	struct pfn_batch batch;
2054 	bool batch_inited = false;
2055 
2056 	lockdep_assert_held(&pages->mutex);
2057 
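	/*
	 * Three cases per span: still covered by another access (leave it
	 * alone), covered only by a domain (drop the xarray copy, the pin is
	 * owned by the domain storage), or covered by nothing (unpin and drop
	 * the xarray copy).
	 */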
2058 	interval_tree_for_each_double_span(&span, &pages->access_itree,
2059 					   &pages->domains_itree, start_index,
2060 					   last_index) {
2061 		if (!span.is_used) {
2062 			if (!batch_inited) {
2063 				batch_init_backup(&batch,
2064 						  last_index - start_index + 1,
2065 						  backup, sizeof(backup));
2066 				batch_inited = true;
2067 			}
2068 			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
2069 						span.last_hole);
2070 		} else if (span.is_used == 2) {
2071 			/* Covered by a domain */
2072 			clear_xarray(&pages->pinned_pfns, span.start_used,
2073 				     span.last_used);
2074 		}
2075 		/* Otherwise covered by an existing access */
2076 	}
2077 	if (batch_inited)
2078 		batch_destroy(&batch, backup);
2079 	update_unpinned(pages);
2080 }
2081 
2082 /**
2083  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
2084  * @pages: The pages to act on
2085  * @start_index: The first page index in the range
2086  * @last_index: The last page index in the range
2087  * @out_pages: The output array to return the pages
2088  *
2089  * This can be called if the caller is holding a refcount on an
2090  * iopt_pages_access that is known to have already been filled. It quickly reads
2091  * the pages directly from the xarray.
2092  *
2093  * This is part of the SW iommu interface to read pages for in-kernel use.
2094  */
2095 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
2096 				 unsigned long start_index,
2097 				 unsigned long last_index,
2098 				 struct page **out_pages)
2099 {
2100 	XA_STATE(xas, &pages->pinned_pfns, start_index);
2101 	void *entry;
2102 
2103 	rcu_read_lock();
2104 	while (start_index <= last_index) {
2105 		entry = xas_next(&xas);
2106 		if (xas_retry(&xas, entry))
2107 			continue;
2108 		WARN_ON(!xa_is_value(entry));
2109 		*(out_pages++) = pfn_to_page(xa_to_value(entry));
2110 		start_index++;
2111 	}
2112 	rcu_read_unlock();
2113 }
2114 
2115 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
2116 				       unsigned long start_index,
2117 				       unsigned long last_index,
2118 				       struct page **out_pages)
2119 {
2120 	while (start_index != last_index + 1) {
2121 		unsigned long domain_last;
2122 		struct iopt_area *area;
2123 
2124 		area = iopt_pages_find_domain_area(pages, start_index);
2125 		if (WARN_ON(!area))
2126 			return -EINVAL;
2127 
2128 		domain_last = min(iopt_area_last_index(area), last_index);
2129 		out_pages = raw_pages_from_domain(area->storage_domain, area,
2130 						  start_index, domain_last,
2131 						  out_pages);
2132 		start_index = domain_last + 1;
2133 	}
2134 	return 0;
2135 }
2136 
2137 static int iopt_pages_fill(struct iopt_pages *pages,
2138 			   struct pfn_reader_user *user,
2139 			   unsigned long start_index,
2140 			   unsigned long last_index,
2141 			   struct page **out_pages)
2142 {
2143 	unsigned long cur_index = start_index;
2144 	int rc;
2145 
2146 	while (cur_index != last_index + 1) {
2147 		user->upages = out_pages + (cur_index - start_index);
2148 		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
2149 		if (rc)
2150 			goto out_unpin;
2151 		cur_index = user->upages_end;
2152 	}
2153 	return 0;
2154 
2155 out_unpin:
2156 	if (start_index != cur_index)
2157 		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
2158 				     out_pages);
2159 	return rc;
2160 }
2161 
2162 /**
2163  * iopt_pages_fill_xarray() - Read PFNs
2164  * @pages: The pages to act on
2165  * @start_index: The first page index in the range
2166  * @last_index: The last page index in the range
2167  * @out_pages: The output array to return the pages, may be NULL
2168  *
2169  * This populates the xarray and returns the pages in out_pages. As the slow
2170  * path this is able to copy pages from other storage tiers into the xarray.
2171  * path, it is able to copy pages from the other storage tiers into the xarray.
2172  * On failure the xarray is left unchanged.
2173  *
2174  * This is part of the SW iommu interface to read pages for in-kernel use.
2175  */
2176 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
2177 			   unsigned long last_index, struct page **out_pages)
2178 {
2179 	struct interval_tree_double_span_iter span;
2180 	unsigned long xa_end = start_index;
2181 	struct pfn_reader_user user;
2182 	int rc;
2183 
2184 	lockdep_assert_held(&pages->mutex);
2185 
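	/*
	 * xa_end is the end (exclusive) of the range that has been loaded
	 * into the xarray so far; on failure only [start_index, xa_end) has
	 * to be unwound.
	 */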
2186 	pfn_reader_user_init(&user, pages);
2187 	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
2188 	interval_tree_for_each_double_span(&span, &pages->access_itree,
2189 					   &pages->domains_itree, start_index,
2190 					   last_index) {
2191 		struct page **cur_pages;
2192 
2193 		if (span.is_used == 1) {
2194 			cur_pages = out_pages + (span.start_used - start_index);
2195 			iopt_pages_fill_from_xarray(pages, span.start_used,
2196 						    span.last_used, cur_pages);
2197 			continue;
2198 		}
2199 
2200 		if (span.is_used == 2) {
2201 			cur_pages = out_pages + (span.start_used - start_index);
2202 			iopt_pages_fill_from_domain(pages, span.start_used,
2203 						    span.last_used, cur_pages);
2204 			rc = pages_to_xarray(&pages->pinned_pfns,
2205 					     span.start_used, span.last_used,
2206 					     cur_pages);
2207 			if (rc)
2208 				goto out_clean_xa;
2209 			xa_end = span.last_used + 1;
2210 			continue;
2211 		}
2212 
2213 		/* hole */
2214 		cur_pages = out_pages + (span.start_hole - start_index);
2215 		rc = iopt_pages_fill(pages, &user, span.start_hole,
2216 				     span.last_hole, cur_pages);
2217 		if (rc)
2218 			goto out_clean_xa;
2219 		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
2220 				     span.last_hole, cur_pages);
2221 		if (rc) {
2222 			iopt_pages_err_unpin(pages, span.start_hole,
2223 					     span.last_hole, cur_pages);
2224 			goto out_clean_xa;
2225 		}
2226 		xa_end = span.last_hole + 1;
2227 	}
2228 	rc = pfn_reader_user_update_pinned(&user, pages);
2229 	if (rc)
2230 		goto out_clean_xa;
2231 	user.upages = NULL;
2232 	pfn_reader_user_destroy(&user, pages);
2233 	return 0;
2234 
2235 out_clean_xa:
2236 	if (start_index != xa_end)
2237 		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
2238 	user.upages = NULL;
2239 	pfn_reader_user_destroy(&user, pages);
2240 	return rc;
2241 }
2242 
2243 /*
2244  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
2245  * do every scenario and is fully consistent with what an iommu_domain would
2246  * see.
2247  */
2248 static int iopt_pages_rw_slow(struct iopt_pages *pages,
2249 			      unsigned long start_index,
2250 			      unsigned long last_index, unsigned long offset,
2251 			      void *data, unsigned long length,
2252 			      unsigned int flags)
2253 {
2254 	struct pfn_reader pfns;
2255 	int rc;
2256 
2257 	mutex_lock(&pages->mutex);
2258 
2259 	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
2260 	if (rc)
2261 		goto out_unlock;
2262 
2263 	while (!pfn_reader_done(&pfns)) {
2264 		unsigned long done;
2265 
2266 		done = batch_rw(&pfns.batch, data, offset, length, flags);
2267 		data += done;
2268 		length -= done;
2269 		offset = 0;
2270 		pfn_reader_unpin(&pfns);
2271 
2272 		rc = pfn_reader_next(&pfns);
2273 		if (rc)
2274 			goto out_destroy;
2275 	}
2276 	if (WARN_ON(length != 0))
2277 		rc = -EINVAL;
2278 out_destroy:
2279 	pfn_reader_destroy(&pfns);
2280 out_unlock:
2281 	mutex_unlock(&pages->mutex);
2282 	return rc;
2283 }
2284 
2285 /*
2286  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
2287  * memory allocations or interval tree searches.
2288  */
2289 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
2290 			      unsigned long offset, void *data,
2291 			      unsigned long length, unsigned int flags)
2292 {
2293 	struct page *page = NULL;
2294 	int rc;
2295 
2296 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2297 	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
2298 		return -EINVAL;
2299 
2300 	if (!mmget_not_zero(pages->source_mm))
2301 		return iopt_pages_rw_slow(pages, index, index, offset, data,
2302 					  length, flags);
2303 
2304 	if (iommufd_should_fail()) {
2305 		rc = -EINVAL;
2306 		goto out_mmput;
2307 	}
2308 
2309 	mmap_read_lock(pages->source_mm);
2310 	rc = pin_user_pages_remote(
2311 		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
2312 		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
2313 		NULL);
2314 	mmap_read_unlock(pages->source_mm);
2315 	if (rc != 1) {
2316 		if (WARN_ON(rc >= 0))
2317 			rc = -EINVAL;
2318 		goto out_mmput;
2319 	}
2320 	copy_data_page(page, data, offset, length, flags);
2321 	unpin_user_page(page);
2322 	rc = 0;
2323 
2324 out_mmput:
2325 	mmput(pages->source_mm);
2326 	return rc;
2327 }
2328 
2329 /**
2330  * iopt_pages_rw_access - Copy to/from a linear slice of the pages
2331  * @pages: pages to act on
2332  * @start_byte: First byte of pages to copy to/from
2333  * @data: Kernel buffer to get/put the data
2334  * @length: Number of bytes to copy
2335  * @flags: IOMMUFD_ACCESS_RW_* flags
2336  *
2337  * This will find each page in the range, kmap it and then memcpy to/from
2338  * the given kernel buffer.
2339  */
2340 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
2341 			 void *data, unsigned long length, unsigned int flags)
2342 {
2343 	unsigned long start_index = start_byte / PAGE_SIZE;
2344 	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
2345 	bool change_mm = current->mm != pages->source_mm;
2346 	int rc = 0;
2347 
2348 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
2349 	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
2350 		change_mm = true;
2351 
2352 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2353 		return -EPERM;
2354 
2355 	if (iopt_is_dmabuf(pages))
2356 		return -EINVAL;
2357 
2358 	if (pages->type != IOPT_ADDRESS_USER)
2359 		return iopt_pages_rw_slow(pages, start_index, last_index,
2360 					  start_byte % PAGE_SIZE, data, length,
2361 					  flags);
2362 
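	/*
	 * Roughly: a single page copied across mm's uses the per-page pin
	 * fast path, larger cross-mm copies fall back to the fully consistent
	 * pfn_reader slow path, and same-mm (or kthread-flagged) copies try
	 * plain copy_to_user()/copy_from_user() below, adopting the source mm
	 * when needed.
	 */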
2363 	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
2364 		if (start_index == last_index)
2365 			return iopt_pages_rw_page(pages, start_index,
2366 						  start_byte % PAGE_SIZE, data,
2367 						  length, flags);
2368 		return iopt_pages_rw_slow(pages, start_index, last_index,
2369 					  start_byte % PAGE_SIZE, data, length,
2370 					  flags);
2371 	}
2372 
2373 	/*
2374 	 * Try to copy using copy_to_user(). We do this as a fast path and
2375 	 * ignore any pinning inconsistencies, unlike a real DMA path.
2376 	 */
2377 	if (change_mm) {
2378 		if (!mmget_not_zero(pages->source_mm))
2379 			return iopt_pages_rw_slow(pages, start_index,
2380 						  last_index,
2381 						  start_byte % PAGE_SIZE, data,
2382 						  length, flags);
2383 		kthread_use_mm(pages->source_mm);
2384 	}
2385 
2386 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
2387 		if (copy_to_user(pages->uptr + start_byte, data, length))
2388 			rc = -EFAULT;
2389 	} else {
2390 		if (copy_from_user(data, pages->uptr + start_byte, length))
2391 			rc = -EFAULT;
2392 	}
2393 
2394 	if (change_mm) {
2395 		kthread_unuse_mm(pages->source_mm);
2396 		mmput(pages->source_mm);
2397 	}
2398 
2399 	return rc;
2400 }
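/*
 * A minimal usage sketch (hypothetical caller, not part of this file):
 * writing an 8 byte value into the pages at byte offset 0x1000 could look
 * like:
 *
 *	u64 val = 0;
 *	int rc = iopt_pages_rw_access(pages, 0x1000, &val, sizeof(val),
 *				      IOMMUFD_ACCESS_RW_WRITE);
 */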
2401 
2402 static struct iopt_pages_access *
2403 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
2404 			    unsigned long last)
2405 {
2406 	struct interval_tree_node *node;
2407 
2408 	lockdep_assert_held(&pages->mutex);
2409 
2410 	/* There can be overlapping ranges in this interval tree */
2411 	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
2412 	     node; node = interval_tree_iter_next(node, index, last))
2413 		if (node->start == index && node->last == last)
2414 			return container_of(node, struct iopt_pages_access,
2415 					    node);
2416 	return NULL;
2417 }
2418 
2419 /**
2420  * iopt_area_add_access() - Record an in-kernel access for PFNs
2421  * @area: The source of PFNs
2422  * @start_index: First page index
2423  * @last_index: Inclusive last page index
2424  * @out_pages: Output list of struct page's representing the PFNs
2425  * @flags: IOMMUFD_ACCESS_RW_* flags
2426  * @lock_area: Fail userspace munmap on this area
2427  *
2428  * Record that an in-kernel access will be accessing the pages, ensure they are
2429  * pinned, and return the PFNs as a simple list of 'struct page *'.
2430  *
2431  * This should be undone through a matching call to iopt_area_remove_access().
2432  */
2433 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
2434 			 unsigned long last_index, struct page **out_pages,
2435 			 unsigned int flags, bool lock_area)
2436 {
2437 	struct iopt_pages *pages = area->pages;
2438 	struct iopt_pages_access *access;
2439 	int rc;
2440 
2441 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2442 		return -EPERM;
2443 
2444 	mutex_lock(&pages->mutex);
2445 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2446 	if (access) {
2447 		area->num_accesses++;
2448 		if (lock_area)
2449 			area->num_locks++;
2450 		access->users++;
2451 		iopt_pages_fill_from_xarray(pages, start_index, last_index,
2452 					    out_pages);
2453 		mutex_unlock(&pages->mutex);
2454 		return 0;
2455 	}
2456 
2457 	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
2458 	if (!access) {
2459 		rc = -ENOMEM;
2460 		goto err_unlock;
2461 	}
2462 
2463 	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
2464 	if (rc)
2465 		goto err_free;
2466 
2467 	access->node.start = start_index;
2468 	access->node.last = last_index;
2469 	access->users = 1;
2470 	area->num_accesses++;
2471 	if (lock_area)
2472 		area->num_locks++;
2473 	interval_tree_insert(&access->node, &pages->access_itree);
2474 	mutex_unlock(&pages->mutex);
2475 	return 0;
2476 
2477 err_free:
2478 	kfree(access);
2479 err_unlock:
2480 	mutex_unlock(&pages->mutex);
2481 	return rc;
2482 }
2483 
2484 /**
2485  * iopt_area_remove_access() - Release an in-kernel access for PFNs
2486  * @area: The source of PFNs
2487  * @start_index: First page index
2488  * @last_index: Inclusive last page index
2489  * @unlock_area: Must match the matching iopt_area_add_access()'s lock_area
2490  *
2491  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
2492  * must stop using the PFNs before calling this.
2493  */
2494 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
2495 			     unsigned long last_index, bool unlock_area)
2496 {
2497 	struct iopt_pages *pages = area->pages;
2498 	struct iopt_pages_access *access;
2499 
2500 	mutex_lock(&pages->mutex);
2501 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
2502 	if (WARN_ON(!access))
2503 		goto out_unlock;
2504 
2505 	WARN_ON(area->num_accesses == 0 || access->users == 0);
2506 	if (unlock_area) {
2507 		WARN_ON(area->num_locks == 0);
2508 		area->num_locks--;
2509 	}
2510 	area->num_accesses--;
2511 	access->users--;
2512 	if (access->users)
2513 		goto out_unlock;
2514 
2515 	interval_tree_remove(&access->node, &pages->access_itree);
2516 	iopt_pages_unfill_xarray(pages, start_index, last_index);
2517 	kfree(access);
2518 out_unlock:
2519 	mutex_unlock(&pages->mutex);
2520 }
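/*
 * A minimal usage sketch (hypothetical caller, not part of this file): an
 * in-kernel access pins a range, uses the returned pages, then releases it:
 *
 *	rc = iopt_area_add_access(area, start_index, last_index, out_pages,
 *				  IOMMUFD_ACCESS_RW_WRITE, false);
 *	if (rc)
 *		return rc;
 *	... read or write through out_pages ...
 *	iopt_area_remove_access(area, start_index, last_index, false);
 */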
2521