xref: /linux/drivers/iommu/generic_pt/iommu_pt.h (revision a5210135489ae7bc1ef1cb4a8157361dd7b468cd)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  *
5  * "Templated C code" for implementing the iommu operations for page tables.
6  * This is compiled multiple times, over all the page table formats to pick up
7  * the per-format definitions.
8  */
9 #ifndef __GENERIC_PT_IOMMU_PT_H
10 #define __GENERIC_PT_IOMMU_PT_H
11 
12 #include "pt_iter.h"
13 
14 #include <linux/export.h>
15 #include <linux/iommu.h>
16 #include "../iommu-pages.h"
17 #include <linux/cleanup.h>
18 #include <linux/dma-mapping.h>
19 
20 enum {
21 	SW_BIT_CACHE_FLUSH_DONE = 0,
22 };
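/*
 * Note on the SW bit above: when PT_FEAT_DMA_INCOHERENT is enabled,
 * pt_iommu_new_table() sets SW_BIT_CACHE_FLUSH_DONE (with release semantics)
 * after flushing the newly installed table item, and __map_range() tests it
 * (with acquire semantics) so a racing mapper knows whether it still has to
 * flush the shared item itself.
 */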
23 
24 static void flush_writes_range(const struct pt_state *pts,
25 			       unsigned int start_index, unsigned int end_index)
26 {
27 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT))
28 		iommu_pages_flush_incoherent(
29 			iommu_from_common(pts->range->common)->iommu_device,
30 			pts->table, start_index * PT_ITEM_WORD_SIZE,
31 			(end_index - start_index) * PT_ITEM_WORD_SIZE);
32 }
33 
34 static void flush_writes_item(const struct pt_state *pts)
35 {
36 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT))
37 		iommu_pages_flush_incoherent(
38 			iommu_from_common(pts->range->common)->iommu_device,
39 			pts->table, pts->index * PT_ITEM_WORD_SIZE,
40 			PT_ITEM_WORD_SIZE);
41 }
42 
43 static void gather_range_pages(struct iommu_iotlb_gather *iotlb_gather,
44 			       struct pt_iommu *iommu_table, pt_vaddr_t iova,
45 			       pt_vaddr_t len,
46 			       struct iommu_pages_list *free_list)
47 {
48 	struct pt_common *common = common_from_iommu(iommu_table);
49 
50 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
51 		iommu_pages_stop_incoherent_list(free_list,
52 						 iommu_table->iommu_device);
53 
54 	/*
55 	 * If running in DMA-FQ mode then the unmap will be followed by a full
56 	 * IOTLB flush, so we can optimize by never flushing the IOTLB here.
57 	 *
58 	 * For NO_GAPS the user gets to pick whether flushing everything or doing
59 	 * micro flushes is better for their workload by choosing DMA vs DMA-FQ
60 	 * operation. Drivers should also see shadow_on_flush.
61 	 */
62 	if (!iommu_iotlb_gather_queued(iotlb_gather)) {
63 		if (pt_feature(common, PT_FEAT_FLUSH_RANGE_NO_GAPS) &&
64 		    iommu_iotlb_gather_is_disjoint(iotlb_gather, iova, len)) {
65 			iommu_iotlb_sync(&iommu_table->domain, iotlb_gather);
66 			/*
67 			 * Note that the sync frees the gather's free list, so
68 			 * we must not have any pages on that list that are
69 			 * covered by iova/len
70 			 */
71 		}
72 		iommu_iotlb_gather_add_range(iotlb_gather, iova, len);
73 	}
74 
75 	iommu_pages_list_splice(free_list, &iotlb_gather->freelist);
76 }
77 
78 #define DOMAIN_NS(op) CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), op)
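/*
 * Illustration only: with a hypothetical format prefix PTPFX of amdv1,
 * DOMAIN_NS(iova_to_phys) expands to pt_iommu_amdv1_iova_to_phys, so each
 * compiled format gets its own set of exported entry points.
 */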
79 
80 static int make_range_ul(struct pt_common *common, struct pt_range *range,
81 			 unsigned long iova, unsigned long len)
82 {
83 	unsigned long last;
84 
85 	if (unlikely(len == 0))
86 		return -EINVAL;
87 
88 	if (check_add_overflow(iova, len - 1, &last))
89 		return -EOVERFLOW;
90 
91 	*range = pt_make_range(common, iova, last);
92 	if (sizeof(iova) > sizeof(range->va)) {
93 		if (unlikely(range->va != iova || range->last_va != last))
94 			return -EOVERFLOW;
95 	}
96 	return 0;
97 }
98 
99 static __maybe_unused int make_range_u64(struct pt_common *common,
100 					 struct pt_range *range, u64 iova,
101 					 u64 len)
102 {
103 	if (unlikely(iova > ULONG_MAX || len > ULONG_MAX))
104 		return -EOVERFLOW;
105 	return make_range_ul(common, range, iova, len);
106 }
107 
108 /*
109  * Some APIs use unsigned long, while others use dma_addr_t as the type. Dispatch
110  * to the correct validation based on the type.
111  */
112 #define make_range_no_check(common, range, iova, len)                   \
113 	({                                                              \
114 		int ret;                                                \
115 		if (sizeof(iova) > sizeof(unsigned long) ||             \
116 		    sizeof(len) > sizeof(unsigned long))                \
117 			ret = make_range_u64(common, range, iova, len); \
118 		else                                                    \
119 			ret = make_range_ul(common, range, iova, len);  \
120 		ret;                                                    \
121 	})
122 
123 #define make_range(common, range, iova, len)                             \
124 	({                                                               \
125 		int ret = make_range_no_check(common, range, iova, len); \
126 		if (!ret)                                                \
127 			ret = pt_check_range(range);                     \
128 		ret;                                                     \
129 	})
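/*
 * Rough sketch of the dispatch: on a 32-bit kernel with a 64-bit dma_addr_t,
 * sizeof(iova) > sizeof(unsigned long), so make_range() goes through
 * make_range_u64(), which rejects values above ULONG_MAX before handing off to
 * make_range_ul(). When everything fits in unsigned long the u64 path is
 * compiled out.
 */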
130 
131 static inline unsigned int compute_best_pgsize(struct pt_state *pts,
132 					       pt_oaddr_t oa)
133 {
134 	struct pt_iommu *iommu_table = iommu_from_common(pts->range->common);
135 
136 	if (!pt_can_have_leaf(pts))
137 		return 0;
138 
139 	/*
140 	 * The page size is limited by the domain's bitmap. This allows the core
141 	 * code to reduce the supported page sizes by changing the bitmap.
142 	 */
143 	return pt_compute_best_pgsize(pt_possible_sizes(pts) &
144 					      iommu_table->domain.pgsize_bitmap,
145 				      pts->range->va, pts->range->last_va, oa);
146 }
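/*
 * Worked example (sizes are format dependent and only illustrative): with 4K
 * and 2M both possible here and enabled in the domain's pgsize_bitmap, a
 * mapping whose VA, remaining length and OA are all 2M aligned selects 21
 * (lg2 of 2M); anything less aligned falls back to 12 for 4K items.
 */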
147 
148 static __always_inline int __do_iova_to_phys(struct pt_range *range, void *arg,
149 					     unsigned int level,
150 					     struct pt_table_p *table,
151 					     pt_level_fn_t descend_fn)
152 {
153 	struct pt_state pts = pt_init(range, level, table);
154 	pt_oaddr_t *res = arg;
155 
156 	switch (pt_load_single_entry(&pts)) {
157 	case PT_ENTRY_EMPTY:
158 		return -ENOENT;
159 	case PT_ENTRY_TABLE:
160 		return pt_descend(&pts, arg, descend_fn);
161 	case PT_ENTRY_OA:
162 		*res = pt_entry_oa_exact(&pts);
163 		return 0;
164 	}
165 	return -ENOENT;
166 }
167 PT_MAKE_LEVELS(__iova_to_phys, __do_iova_to_phys);
168 
169 /**
170  * iova_to_phys() - Return the output address for the given IOVA
171  * @domain: Table to query
172  * @iova: IO virtual address to query
173  *
174  * Determine the output address from the given IOVA. @iova may have any
175  * alignment, the returned physical address will be adjusted with any sub-page offset.
176  *
177  * Context: The caller must hold a read range lock that includes @iova.
178  *
179  * Return: The output address, or 0 if there is no translation for the given iova.
180  */
181 phys_addr_t DOMAIN_NS(iova_to_phys)(struct iommu_domain *domain,
182 				    dma_addr_t iova)
183 {
184 	struct pt_iommu *iommu_table =
185 		container_of(domain, struct pt_iommu, domain);
186 	struct pt_range range;
187 	pt_oaddr_t res;
188 	int ret;
189 
190 	ret = make_range(common_from_iommu(iommu_table), &range, iova, 1);
191 	if (ret)
192 		return ret;
193 
194 	ret = pt_walk_range(&range, __iova_to_phys, &res);
195 	/* PHYS_ADDR_MAX would be a better error code */
196 	if (ret)
197 		return 0;
198 	return res;
199 }
200 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(iova_to_phys), "GENERIC_PT_IOMMU");
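/*
 * Hedged usage sketch (assumption, not part of this file): a format driver
 * built with a hypothetical PTPFX of amdv1 would typically wire this directly
 * into its domain ops, e.g.:
 *
 *	static const struct iommu_domain_ops amdv1_domain_ops = {
 *		.iova_to_phys = pt_iommu_amdv1_iova_to_phys,
 *		...
 *	};
 */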
201 
202 struct pt_iommu_dirty_args {
203 	struct iommu_dirty_bitmap *dirty;
204 	unsigned int flags;
205 };
206 
207 static void record_dirty(struct pt_state *pts,
208 			 struct pt_iommu_dirty_args *dirty,
209 			 unsigned int num_contig_lg2)
210 {
211 	pt_vaddr_t dirty_len;
212 
213 	if (num_contig_lg2 != ilog2(1)) {
214 		unsigned int index = pts->index;
215 		unsigned int end_index = log2_set_mod_max_t(
216 			unsigned int, pts->index, num_contig_lg2);
217 
218 		/* Adjust for being contained inside a contiguous page */
219 		end_index = min(end_index, pts->end_index);
220 		dirty_len = (end_index - index) *
221 				log2_to_int(pt_table_item_lg2sz(pts));
222 	} else {
223 		dirty_len = log2_to_int(pt_table_item_lg2sz(pts));
224 	}
225 
226 	if (dirty->dirty->bitmap)
227 		iova_bitmap_set(dirty->dirty->bitmap, pts->range->va,
228 				dirty_len);
229 
230 	if (!(dirty->flags & IOMMU_DIRTY_NO_CLEAR)) {
231 		/*
232 		 * No write log required because DMA incoherence and atomic
233 		 * dirty tracking bits can't work together
234 		 */
235 		pt_entry_make_write_clean(pts);
236 		iommu_iotlb_gather_add_range(dirty->dirty->gather,
237 					     pts->range->va, dirty_len);
238 	}
239 }
240 
241 static inline int __read_and_clear_dirty(struct pt_range *range, void *arg,
242 					 unsigned int level,
243 					 struct pt_table_p *table)
244 {
245 	struct pt_state pts = pt_init(range, level, table);
246 	struct pt_iommu_dirty_args *dirty = arg;
247 	int ret;
248 
249 	for_each_pt_level_entry(&pts) {
250 		if (pts.type == PT_ENTRY_TABLE) {
251 			ret = pt_descend(&pts, arg, __read_and_clear_dirty);
252 			if (ret)
253 				return ret;
254 			continue;
255 		}
256 		if (pts.type == PT_ENTRY_OA && pt_entry_is_write_dirty(&pts))
257 			record_dirty(&pts, dirty,
258 				     pt_entry_num_contig_lg2(&pts));
259 	}
260 	return 0;
261 }
262 
263 /**
264  * read_and_clear_dirty() - Manipulate the HW set write dirty state
265  * @domain: Domain to manipulate
266  * @iova: IO virtual address to start
267  * @size: Length of the IOVA
268  * @flags: A bitmap of IOMMU_DIRTY_NO_CLEAR
269  * @dirty: Place to store the dirty bits
270  *
271  * Iterate over all the entries in the mapped range and record their write dirty
272  * status in iommu_dirty_bitmap. If IOMMU_DIRTY_NO_CLEAR is specified then
273  * the entries will be left dirty, otherwise they are returned to being not
274  * write dirty.
275  *
276  * Context: The caller must hold a read range lock that includes @iova.
277  *
278  * Return: -ERRNO on failure, 0 on success.
279  */
280 int DOMAIN_NS(read_and_clear_dirty)(struct iommu_domain *domain,
281 				    unsigned long iova, size_t size,
282 				    unsigned long flags,
283 				    struct iommu_dirty_bitmap *dirty)
284 {
285 	struct pt_iommu *iommu_table =
286 		container_of(domain, struct pt_iommu, domain);
287 	struct pt_iommu_dirty_args dirty_args = {
288 		.dirty = dirty,
289 		.flags = flags,
290 	};
291 	struct pt_range range;
292 	int ret;
293 
294 #if !IS_ENABLED(CONFIG_IOMMUFD_DRIVER) || !defined(pt_entry_is_write_dirty)
295 	return -EOPNOTSUPP;
296 #endif
297 
298 	ret = make_range(common_from_iommu(iommu_table), &range, iova, size);
299 	if (ret)
300 		return ret;
301 
302 	ret = pt_walk_range(&range, __read_and_clear_dirty, &dirty_args);
303 	PT_WARN_ON(ret);
304 	return ret;
305 }
306 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(read_and_clear_dirty), "GENERIC_PT_IOMMU");
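/*
 * Hedged usage sketch (assumption, not part of this file): drivers supporting
 * dirty tracking would typically expose this through their iommu_dirty_ops,
 * e.g.:
 *
 *	static const struct iommu_dirty_ops amdv1_dirty_ops = {
 *		.read_and_clear_dirty = pt_iommu_amdv1_read_and_clear_dirty,
 *		...
 *	};
 */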
307 
308 static inline int __set_dirty(struct pt_range *range, void *arg,
309 			      unsigned int level, struct pt_table_p *table)
310 {
311 	struct pt_state pts = pt_init(range, level, table);
312 
313 	switch (pt_load_single_entry(&pts)) {
314 	case PT_ENTRY_EMPTY:
315 		return -ENOENT;
316 	case PT_ENTRY_TABLE:
317 		return pt_descend(&pts, arg, __set_dirty);
318 	case PT_ENTRY_OA:
319 		if (!pt_entry_make_write_dirty(&pts))
320 			return -EAGAIN;
321 		return 0;
322 	}
323 	return -ENOENT;
324 }
325 
326 static int __maybe_unused NS(set_dirty)(struct pt_iommu *iommu_table,
327 					dma_addr_t iova)
328 {
329 	struct pt_range range;
330 	int ret;
331 
332 	ret = make_range(common_from_iommu(iommu_table), &range, iova, 1);
333 	if (ret)
334 		return ret;
335 
336 	/*
337 	 * Note: There is no locking here yet, if the test suite races this it
338 	 * can crash. It should use RCU locking eventually.
339 	 */
340 	return pt_walk_range(&range, __set_dirty, NULL);
341 }
342 
343 struct pt_iommu_collect_args {
344 	struct iommu_pages_list free_list;
345 	/* Fail if any OAs are within the range */
346 	u8 check_mapped : 1;
347 };
348 
349 static int __collect_tables(struct pt_range *range, void *arg,
350 			    unsigned int level, struct pt_table_p *table)
351 {
352 	struct pt_state pts = pt_init(range, level, table);
353 	struct pt_iommu_collect_args *collect = arg;
354 	int ret;
355 
356 	if (!collect->check_mapped && !pt_can_have_table(&pts))
357 		return 0;
358 
359 	for_each_pt_level_entry(&pts) {
360 		if (pts.type == PT_ENTRY_TABLE) {
361 			iommu_pages_list_add(&collect->free_list, pts.table_lower);
362 			ret = pt_descend(&pts, arg, __collect_tables);
363 			if (ret)
364 				return ret;
365 			continue;
366 		}
367 		if (pts.type == PT_ENTRY_OA && collect->check_mapped)
368 			return -EADDRINUSE;
369 	}
370 	return 0;
371 }
372 
373 enum alloc_mode { ALLOC_NORMAL, ALLOC_DEFER_COHERENT_FLUSH };
374 
375 /* Allocate a table, the empty table will be ready to be installed. */
376 static inline struct pt_table_p *_table_alloc(struct pt_common *common,
377 					      size_t lg2sz, gfp_t gfp,
378 					      enum alloc_mode mode)
379 {
380 	struct pt_iommu *iommu_table = iommu_from_common(common);
381 	struct pt_table_p *table_mem;
382 
383 	table_mem = iommu_alloc_pages_node_sz(iommu_table->nid, gfp,
384 					      log2_to_int(lg2sz));
385 	if (!table_mem)
386 		return ERR_PTR(-ENOMEM);
387 
388 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) &&
389 	    mode == ALLOC_NORMAL) {
390 		int ret = iommu_pages_start_incoherent(
391 			table_mem, iommu_table->iommu_device);
392 		if (ret) {
393 			iommu_free_pages(table_mem);
394 			return ERR_PTR(ret);
395 		}
396 	}
397 	return table_mem;
398 }
399 
400 static inline struct pt_table_p *table_alloc_top(struct pt_common *common,
401 						 uintptr_t top_of_table,
402 						 gfp_t gfp,
403 						 enum alloc_mode mode)
404 {
405 	/*
406 	 * The top doesn't need the free list or the other iommu-pages metadata,
407 	 * so it technically doesn't need to use iommu pages. Use the API anyhow,
408 	 * as the top is usually not smaller than PAGE_SIZE, to keep things simple.
409 	 */
410 	return _table_alloc(common, pt_top_memsize_lg2(common, top_of_table),
411 			    gfp, mode);
412 }
413 
414 /* Allocate an interior table */
415 static inline struct pt_table_p *table_alloc(const struct pt_state *parent_pts,
416 					     gfp_t gfp, enum alloc_mode mode)
417 {
418 	struct pt_state child_pts =
419 		pt_init(parent_pts->range, parent_pts->level - 1, NULL);
420 
421 	return _table_alloc(parent_pts->range->common,
422 			    pt_num_items_lg2(&child_pts) +
423 				    ilog2(PT_ITEM_WORD_SIZE),
424 			    gfp, mode);
425 }
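/*
 * Size arithmetic, for illustration only: a child level with 512 items
 * (pt_num_items_lg2() == 9) and 8 byte item words
 * (ilog2(PT_ITEM_WORD_SIZE) == 3) allocates log2_to_int(9 + 3) == 4096 bytes,
 * one page per table. The real numbers depend on the format.
 */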
426 
427 static inline int pt_iommu_new_table(struct pt_state *pts,
428 				     struct pt_write_attrs *attrs)
429 {
430 	struct pt_table_p *table_mem;
431 	phys_addr_t phys;
432 
433 	/* Given PA/VA/length can't be represented */
434 	if (PT_WARN_ON(!pt_can_have_table(pts)))
435 		return -ENXIO;
436 
437 	table_mem = table_alloc(pts, attrs->gfp, ALLOC_NORMAL);
438 	if (IS_ERR(table_mem))
439 		return PTR_ERR(table_mem);
440 
441 	phys = virt_to_phys(table_mem);
442 	if (!pt_install_table(pts, phys, attrs)) {
443 		iommu_pages_free_incoherent(
444 			table_mem,
445 			iommu_from_common(pts->range->common)->iommu_device);
446 		return -EAGAIN;
447 	}
448 
449 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT)) {
450 		flush_writes_item(pts);
451 		pt_set_sw_bit_release(pts, SW_BIT_CACHE_FLUSH_DONE);
452 	}
453 
454 	if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) {
455 		/*
456 		 * The underlying table can't store the physical table address.
457 		 * This happens when kunit testing tables outside their normal
458 		 * environment where a CPU might be limited.
459 		 */
460 		pt_load_single_entry(pts);
461 		if (PT_WARN_ON(pt_table_pa(pts) != phys)) {
462 			pt_clear_entries(pts, ilog2(1));
463 			iommu_pages_free_incoherent(
464 				table_mem, iommu_from_common(pts->range->common)
465 						   ->iommu_device);
466 			return -EINVAL;
467 		}
468 	}
469 
470 	pts->table_lower = table_mem;
471 	return 0;
472 }
473 
474 struct pt_iommu_map_args {
475 	struct iommu_iotlb_gather *iotlb_gather;
476 	struct pt_write_attrs attrs;
477 	pt_oaddr_t oa;
478 	unsigned int leaf_pgsize_lg2;
479 	unsigned int leaf_level;
480 	pt_vaddr_t num_leaves;
481 };
482 
483 /*
484  * This will recursively check any tables in the block to validate they are
485  * empty and then free them through the gather.
486  */
487 static int clear_contig(const struct pt_state *start_pts,
488 			struct iommu_iotlb_gather *iotlb_gather,
489 			unsigned int step, unsigned int pgsize_lg2)
490 {
491 	struct pt_iommu *iommu_table =
492 		iommu_from_common(start_pts->range->common);
493 	struct pt_range range = *start_pts->range;
494 	struct pt_state pts =
495 		pt_init(&range, start_pts->level, start_pts->table);
496 	struct pt_iommu_collect_args collect = { .check_mapped = true };
497 	int ret;
498 
499 	pts.index = start_pts->index;
500 	pts.end_index = start_pts->index + step;
501 	for (; _pt_iter_load(&pts); pt_next_entry(&pts)) {
502 		if (pts.type == PT_ENTRY_TABLE) {
503 			collect.free_list =
504 				IOMMU_PAGES_LIST_INIT(collect.free_list);
505 			ret = pt_walk_descend_all(&pts, __collect_tables,
506 						  &collect);
507 			if (ret)
508 				return ret;
509 
510 			/*
511 			 * The table item must be cleared before we can update
512 			 * the gather
513 			 */
514 			pt_clear_entries(&pts, ilog2(1));
515 			flush_writes_item(&pts);
516 
517 			iommu_pages_list_add(&collect.free_list,
518 					     pt_table_ptr(&pts));
519 			gather_range_pages(
520 				iotlb_gather, iommu_table, range.va,
521 				log2_to_int(pt_table_item_lg2sz(&pts)),
522 				&collect.free_list);
523 		} else if (pts.type != PT_ENTRY_EMPTY) {
524 			return -EADDRINUSE;
525 		}
526 	}
527 	return 0;
528 }
529 
530 static int __map_range_leaf(struct pt_range *range, void *arg,
531 			    unsigned int level, struct pt_table_p *table)
532 {
533 	struct pt_iommu *iommu_table = iommu_from_common(range->common);
534 	struct pt_state pts = pt_init(range, level, table);
535 	struct pt_iommu_map_args *map = arg;
536 	unsigned int leaf_pgsize_lg2 = map->leaf_pgsize_lg2;
537 	unsigned int start_index;
538 	pt_oaddr_t oa = map->oa;
539 	unsigned int num_leaves;
540 	unsigned int orig_end;
541 	pt_vaddr_t last_va;
542 	unsigned int step;
543 	bool need_contig;
544 	int ret = 0;
545 
546 	PT_WARN_ON(map->leaf_level != level);
547 	PT_WARN_ON(!pt_can_have_leaf(&pts));
548 
549 	step = log2_to_int_t(unsigned int,
550 			     leaf_pgsize_lg2 - pt_table_item_lg2sz(&pts));
551 	need_contig = leaf_pgsize_lg2 != pt_table_item_lg2sz(&pts);
552 
553 	_pt_iter_first(&pts);
554 	start_index = pts.index;
555 	orig_end = pts.end_index;
556 	if (pts.index + map->num_leaves < pts.end_index) {
557 		/* Need to stop in the middle of the table to change sizes */
558 		pts.end_index = pts.index + map->num_leaves;
559 		num_leaves = 0;
560 	} else {
561 		num_leaves = map->num_leaves - (pts.end_index - pts.index);
562 	}
563 
564 	do {
565 		pts.type = pt_load_entry_raw(&pts);
566 		if (pts.type != PT_ENTRY_EMPTY || need_contig) {
567 			if (pts.index != start_index)
568 				pt_index_to_va(&pts);
569 			ret = clear_contig(&pts, map->iotlb_gather, step,
570 					   leaf_pgsize_lg2);
571 			if (ret)
572 				break;
573 		}
574 
575 		if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) {
576 			pt_index_to_va(&pts);
577 			PT_WARN_ON(compute_best_pgsize(&pts, oa) !=
578 				   leaf_pgsize_lg2);
579 		}
580 		pt_install_leaf_entry(&pts, oa, leaf_pgsize_lg2, &map->attrs);
581 
582 		oa += log2_to_int(leaf_pgsize_lg2);
583 		pts.index += step;
584 	} while (pts.index < pts.end_index);
585 
586 	flush_writes_range(&pts, start_index, pts.index);
587 
588 	map->oa = oa;
589 	map->num_leaves = num_leaves;
590 	if (ret || num_leaves)
591 		return ret;
592 
593 	/* range->va is not valid if we reached the end of the table */
594 	pts.index -= step;
595 	pt_index_to_va(&pts);
596 	pts.index += step;
597 	last_va = range->va + log2_to_int(leaf_pgsize_lg2);
598 
599 	if (last_va - 1 == range->last_va) {
600 		PT_WARN_ON(pts.index != orig_end);
601 		return 0;
602 	}
603 
604 	/*
605 	 * Reached a point where the page size changed, compute the new
606 	 * parameters.
607 	 */
608 	map->leaf_pgsize_lg2 = pt_compute_best_pgsize(
609 		iommu_table->domain.pgsize_bitmap, last_va, range->last_va, oa);
610 	map->leaf_level =
611 		pt_pgsz_lg2_to_level(range->common, map->leaf_pgsize_lg2);
612 	map->num_leaves = pt_pgsz_count(iommu_table->domain.pgsize_bitmap,
613 					last_va, range->last_va, oa,
614 					map->leaf_pgsize_lg2);
615 
616 	/* Didn't finish this table level, caller will repeat it */
617 	if (pts.index != orig_end) {
618 		if (pts.index != start_index)
619 			pt_index_to_va(&pts);
620 		return -EAGAIN;
621 	}
622 	return 0;
623 }
624 
625 static int __map_range(struct pt_range *range, void *arg, unsigned int level,
626 		       struct pt_table_p *table)
627 {
628 	struct pt_state pts = pt_init(range, level, table);
629 	struct pt_iommu_map_args *map = arg;
630 	int ret;
631 
632 	PT_WARN_ON(map->leaf_level == level);
633 	PT_WARN_ON(!pt_can_have_table(&pts));
634 
635 	_pt_iter_first(&pts);
636 
637 	/* Descend to a child table */
638 	do {
639 		pts.type = pt_load_entry_raw(&pts);
640 
641 		if (pts.type != PT_ENTRY_TABLE) {
642 			if (pts.type != PT_ENTRY_EMPTY)
643 				return -EADDRINUSE;
644 			ret = pt_iommu_new_table(&pts, &map->attrs);
645 			/* EAGAIN on a race will loop again */
646 			if (ret)
647 				return ret;
648 		} else {
649 			pts.table_lower = pt_table_ptr(&pts);
650 			/*
651 			 * Racing with a shared pt_iommu_new_table()? The other
652 			 * thread is still flushing the cache, so we have to
653 			 * also flush it to ensure that when our thread's map
654 			 * completes all the table items leading to our mapping
655 			 * are visible.
656 			 *
657 			 * This requires the pt_set_sw_bit_release() to be a
658 			 * release of the cache flush so that this can acquire
659 			 * visibility at the iommu.
660 			 */
661 			if (pts_feature(&pts, PT_FEAT_DMA_INCOHERENT) &&
662 			    !pt_test_sw_bit_acquire(&pts,
663 						    SW_BIT_CACHE_FLUSH_DONE))
664 				flush_writes_item(&pts);
665 		}
666 
667 		/*
668 		 * The already present table can possibly be shared with another
669 		 * concurrent map.
670 		 */
671 		do {
672 			if (map->leaf_level == level - 1)
673 				ret = pt_descend(&pts, arg, __map_range_leaf);
674 			else
675 				ret = pt_descend(&pts, arg, __map_range);
676 		} while (ret == -EAGAIN);
677 		if (ret)
678 			return ret;
679 
680 		pts.index++;
681 		pt_index_to_va(&pts);
682 		if (pts.index >= pts.end_index)
683 			break;
684 
685 		/*
686 		 * This level is currently running __map_range() which is not
687 		 * correct if the target leaf level has been updated to this
688 		 * level. Have the caller invoke __map_range_leaf() instead.
689 		 */
690 		if (map->leaf_level == level)
691 			return -EAGAIN;
692 	} while (true);
693 	return 0;
694 }
695 
696 /*
697  * Fast path for the easy case of mapping a 4k page to an already allocated
698  * table. This is a common workload. If it returns EAGAIN run the full algorithm
699  * instead.
700  */
701 static __always_inline int __do_map_single_page(struct pt_range *range,
702 						void *arg, unsigned int level,
703 						struct pt_table_p *table,
704 						pt_level_fn_t descend_fn)
705 {
706 	struct pt_state pts = pt_init(range, level, table);
707 	struct pt_iommu_map_args *map = arg;
708 
709 	pts.type = pt_load_single_entry(&pts);
710 	if (pts.level == 0) {
711 		if (pts.type != PT_ENTRY_EMPTY)
712 			return -EADDRINUSE;
713 		pt_install_leaf_entry(&pts, map->oa, PAGE_SHIFT,
714 				      &map->attrs);
715 		/* No flush, not used when incoherent */
716 		map->oa += PAGE_SIZE;
717 		return 0;
718 	}
719 	if (pts.type == PT_ENTRY_TABLE)
720 		return pt_descend(&pts, arg, descend_fn);
721 	/* Something else, use the slow path */
722 	return -EAGAIN;
723 }
724 PT_MAKE_LEVELS(__map_single_page, __do_map_single_page);
725 
726 /*
727  * Add a table to the top, increasing the top level as much as necessary to
728  * encompass range.
729  */
730 static int increase_top(struct pt_iommu *iommu_table, struct pt_range *range,
731 			struct pt_iommu_map_args *map)
732 {
733 	struct iommu_pages_list free_list = IOMMU_PAGES_LIST_INIT(free_list);
734 	struct pt_common *common = common_from_iommu(iommu_table);
735 	uintptr_t top_of_table = READ_ONCE(common->top_of_table);
736 	uintptr_t new_top_of_table = top_of_table;
737 	struct pt_table_p *table_mem;
738 	unsigned int new_level;
739 	spinlock_t *domain_lock;
740 	unsigned long flags;
741 	int ret;
742 
743 	while (true) {
744 		struct pt_range top_range =
745 			_pt_top_range(common, new_top_of_table);
746 		struct pt_state pts = pt_init_top(&top_range);
747 
748 		top_range.va = range->va;
749 		top_range.last_va = range->last_va;
750 
751 		if (!pt_check_range(&top_range) &&
752 		    map->leaf_level <= pts.level) {
753 			new_level = pts.level;
754 			break;
755 		}
756 
757 		pts.level++;
758 		if (pts.level > PT_MAX_TOP_LEVEL ||
759 		    pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2) {
760 			ret = -ERANGE;
761 			goto err_free;
762 		}
763 
764 		table_mem =
765 			table_alloc_top(common, _pt_top_set(NULL, pts.level),
766 					map->attrs.gfp, ALLOC_DEFER_COHERENT_FLUSH);
767 		if (IS_ERR(table_mem)) {
768 			ret = PTR_ERR(table_mem);
769 			goto err_free;
770 		}
771 		iommu_pages_list_add(&free_list, table_mem);
772 
773 		/* The new table always links to the lower table at index 0 */
774 		top_range.va = 0;
775 		top_range.top_level = pts.level;
776 		pts.table_lower = pts.table;
777 		pts.table = table_mem;
778 		pt_load_single_entry(&pts);
779 		PT_WARN_ON(pts.index != 0);
780 		pt_install_table(&pts, virt_to_phys(pts.table_lower),
781 				 &map->attrs);
782 		new_top_of_table = _pt_top_set(pts.table, pts.level);
783 	}
784 
785 	/*
786 	 * Avoid double flushing; flush once after all the pt_install_table() calls
787 	 */
788 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
789 		ret = iommu_pages_start_incoherent_list(
790 			&free_list, iommu_table->iommu_device);
791 		if (ret)
792 			goto err_free;
793 	}
794 
795 	/*
796 	 * top_of_table is write locked by the spinlock, but readers can use
797 	 * READ_ONCE() to get the value. Since we encode both the level and the
798 	 * pointer in a single word, the lockless reader will always see something
799 	 * valid. The HW must be updated to the new level under the spinlock
800 	 * before top_of_table is updated so that concurrent readers don't map
801 	 * into the new level until it is fully functional. If another thread
802 	 * already updated it while we were working then throw everything away
803 	 * and try again.
804 	 */
805 	domain_lock = iommu_table->driver_ops->get_top_lock(iommu_table);
806 	spin_lock_irqsave(domain_lock, flags);
807 	if (common->top_of_table != top_of_table ||
808 	    top_of_table == new_top_of_table) {
809 		spin_unlock_irqrestore(domain_lock, flags);
810 		ret = -EAGAIN;
811 		goto err_free;
812 	}
813 
814 	/*
815 	 * We do not issue any flushes for change_top on the expectation that
816 	 * any walk cache will not become a problem by adding another layer to
817 	 * the tree. Misses will rewalk from the updated top pointer, hits
818 	 * continue to be correct. Negative caching is fine too since all the
819 	 * new IOVA added by the new top is non-present.
820 	 */
821 	iommu_table->driver_ops->change_top(
822 		iommu_table, virt_to_phys(table_mem), new_level);
823 	WRITE_ONCE(common->top_of_table, new_top_of_table);
824 	spin_unlock_irqrestore(domain_lock, flags);
825 	return 0;
826 
827 err_free:
828 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
829 		iommu_pages_stop_incoherent_list(&free_list,
830 						 iommu_table->iommu_device);
831 	iommu_put_pages_list(&free_list);
832 	return ret;
833 }
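/*
 * Example of the growth this enables (levels and widths are format dependent
 * and only illustrative): a table created with a 3 level top covering 39 bits
 * of VA can be grown to a 4 level top covering 48 bits when a mapping above
 * the current aperture arrives, provided PT_FEAT_DYNAMIC_TOP is enabled and
 * the driver supplies change_top() and get_top_lock().
 */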
834 
835 static int check_map_range(struct pt_iommu *iommu_table, struct pt_range *range,
836 			   struct pt_iommu_map_args *map)
837 {
838 	struct pt_common *common = common_from_iommu(iommu_table);
839 	int ret;
840 
841 	do {
842 		ret = pt_check_range(range);
843 		if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP))
844 			return ret;
845 
846 		if (!ret && map->leaf_level <= range->top_level)
847 			break;
848 
849 		ret = increase_top(iommu_table, range, map);
850 		if (ret && ret != -EAGAIN)
851 			return ret;
852 
853 		/* Reload the new top */
854 		*range = pt_make_range(common, range->va, range->last_va);
855 	} while (ret);
856 	PT_WARN_ON(pt_check_range(range));
857 	return 0;
858 }
859 
860 static int do_map(struct pt_range *range, struct pt_common *common,
861 		  bool single_page, struct pt_iommu_map_args *map)
862 {
863 	int ret;
864 
865 	/*
866 	 * The __map_single_page() fast path does not support DMA_INCOHERENT
867 	 * flushing to keep its .text small.
868 	 */
869 	if (single_page && !pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
870 
871 		ret = pt_walk_range(range, __map_single_page, map);
872 		if (ret != -EAGAIN)
873 			return ret;
874 		/* EAGAIN falls through to the full path */
875 	}
876 
877 	do {
878 		if (map->leaf_level == range->top_level)
879 			ret = pt_walk_range(range, __map_range_leaf, map);
880 		else
881 			ret = pt_walk_range(range, __map_range, map);
882 	} while (ret == -EAGAIN);
883 	return ret;
884 }
885 
886 static int NS(map_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
887 			 phys_addr_t paddr, dma_addr_t len, unsigned int prot,
888 			 gfp_t gfp, size_t *mapped)
889 {
890 	pt_vaddr_t pgsize_bitmap = iommu_table->domain.pgsize_bitmap;
891 	struct pt_common *common = common_from_iommu(iommu_table);
892 	struct iommu_iotlb_gather iotlb_gather;
893 	struct pt_iommu_map_args map = {
894 		.iotlb_gather = &iotlb_gather,
895 		.oa = paddr,
896 	};
897 	bool single_page = false;
898 	struct pt_range range;
899 	int ret;
900 
901 	iommu_iotlb_gather_init(&iotlb_gather);
902 
903 	if (WARN_ON(!(prot & (IOMMU_READ | IOMMU_WRITE))))
904 		return -EINVAL;
905 
906 	/* Check the paddr doesn't exceed what the table can store */
907 	if ((sizeof(pt_oaddr_t) < sizeof(paddr) &&
908 	     (pt_vaddr_t)paddr > PT_VADDR_MAX) ||
909 	    (common->max_oasz_lg2 != PT_VADDR_MAX_LG2 &&
910 	     oalog2_div(paddr, common->max_oasz_lg2)))
911 		return -ERANGE;
912 
913 	ret = pt_iommu_set_prot(common, &map.attrs, prot);
914 	if (ret)
915 		return ret;
916 	map.attrs.gfp = gfp;
917 
918 	ret = make_range_no_check(common, &range, iova, len);
919 	if (ret)
920 		return ret;
921 
922 	/* Calculate target page size and level for the leaves */
923 	if (pt_has_system_page_size(common) && len == PAGE_SIZE) {
924 		PT_WARN_ON(!(pgsize_bitmap & PAGE_SIZE));
925 		if (log2_mod(iova | paddr, PAGE_SHIFT))
926 			return -ENXIO;
927 		map.leaf_pgsize_lg2 = PAGE_SHIFT;
928 		map.leaf_level = 0;
929 		map.num_leaves = 1;
930 		single_page = true;
931 	} else {
932 		map.leaf_pgsize_lg2 = pt_compute_best_pgsize(
933 			pgsize_bitmap, range.va, range.last_va, paddr);
934 		if (!map.leaf_pgsize_lg2)
935 			return -ENXIO;
936 		map.leaf_level =
937 			pt_pgsz_lg2_to_level(common, map.leaf_pgsize_lg2);
938 		map.num_leaves = pt_pgsz_count(pgsize_bitmap, range.va,
939 					       range.last_va, paddr,
940 					       map.leaf_pgsize_lg2);
941 	}
942 
943 	ret = check_map_range(iommu_table, &range, &map);
944 	if (ret)
945 		return ret;
946 
947 	PT_WARN_ON(map.leaf_level > range.top_level);
948 
949 	ret = do_map(&range, common, single_page, &map);
950 
951 	/*
952 	 * Table levels were freed and replaced with large items, so flush any walk
953 	 * cache that may refer to the freed levels.
954 	 */
955 	if (!iommu_pages_list_empty(&iotlb_gather.freelist))
956 		iommu_iotlb_sync(&iommu_table->domain, &iotlb_gather);
957 
958 	/* Bytes successfully mapped */
959 	PT_WARN_ON(!ret && map.oa - paddr != len);
960 	*mapped += map.oa - paddr;
961 	return ret;
962 }
963 
964 struct pt_unmap_args {
965 	struct iommu_pages_list free_list;
966 	pt_vaddr_t unmapped;
967 };
968 
969 static __maybe_unused int __unmap_range(struct pt_range *range, void *arg,
970 					unsigned int level,
971 					struct pt_table_p *table)
972 {
973 	struct pt_state pts = pt_init(range, level, table);
974 	unsigned int flush_start_index = UINT_MAX;
975 	unsigned int flush_end_index = UINT_MAX;
976 	struct pt_unmap_args *unmap = arg;
977 	unsigned int num_oas = 0;
978 	unsigned int start_index;
979 	int ret = 0;
980 
981 	_pt_iter_first(&pts);
982 	start_index = pts.index;
983 	pts.type = pt_load_entry_raw(&pts);
984 	/*
985 	 * A starting index is in the middle of a contiguous entry
986 	 *
987 	 * The IOMMU API does not require drivers to support unmapping parts of
988 	 * large pages. Long ago VFIO would try to split maps but the current
989 	 * version never does.
990 	 *
991 	 * Instead when unmap reaches a partial unmap of the start of a large
992 	 * IOPTE it should remove the entire IOPTE and return that size to the
993 	 * caller.
994 	 */
995 	if (pts.type == PT_ENTRY_OA) {
996 		if (log2_mod(range->va, pt_entry_oa_lg2sz(&pts)))
997 			return -EINVAL;
998 		/* Micro optimization */
999 		goto start_oa;
1000 	}
1001 
1002 	do {
1003 		if (pts.type != PT_ENTRY_OA) {
1004 			bool fully_covered;
1005 
1006 			if (pts.type != PT_ENTRY_TABLE) {
1007 				ret = -EINVAL;
1008 				break;
1009 			}
1010 
1011 			if (pts.index != start_index)
1012 				pt_index_to_va(&pts);
1013 			pts.table_lower = pt_table_ptr(&pts);
1014 
1015 			fully_covered = pt_entry_fully_covered(
1016 				&pts, pt_table_item_lg2sz(&pts));
1017 
1018 			ret = pt_descend(&pts, arg, __unmap_range);
1019 			if (ret)
1020 				break;
1021 
1022 			/*
1023 			 * If the unmapping range fully covers the table then we
1024 			 * can free it as well. The clear is delayed until we
1025 			 * succeed in clearing the lower table levels.
1026 			 */
1027 			if (fully_covered) {
1028 				iommu_pages_list_add(&unmap->free_list,
1029 						     pts.table_lower);
1030 				pt_clear_entries(&pts, ilog2(1));
1031 				if (pts.index < flush_start_index)
1032 					flush_start_index = pts.index;
1033 				flush_end_index = pts.index + 1;
1034 			}
1035 			pts.index++;
1036 		} else {
1037 			unsigned int num_contig_lg2;
1038 start_oa:
1039 			/*
1040 			 * If the caller requested a last that falls within a
1041 			 * single entry then the entire entry is unmapped and
1042 			 * the length returned will be larger than requested.
1043 			 */
1044 			num_contig_lg2 = pt_entry_num_contig_lg2(&pts);
1045 			pt_clear_entries(&pts, num_contig_lg2);
1046 			num_oas += log2_to_int(num_contig_lg2);
1047 			if (pts.index < flush_start_index)
1048 				flush_start_index = pts.index;
1049 			pts.index += log2_to_int(num_contig_lg2);
1050 			flush_end_index = pts.index;
1051 		}
1052 		if (pts.index >= pts.end_index)
1053 			break;
1054 		pts.type = pt_load_entry_raw(&pts);
1055 	} while (true);
1056 
1057 	unmap->unmapped += log2_mul(num_oas, pt_table_item_lg2sz(&pts));
1058 	if (flush_start_index != flush_end_index)
1059 		flush_writes_range(&pts, flush_start_index, flush_end_index);
1060 
1061 	return ret;
1062 }
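/*
 * Unmap semantics by example (sizes illustrative): a 4K unmap starting in the
 * middle of a 2M leaf fails with -EINVAL because the start is not aligned to
 * the leaf, while an unmap whose start is leaf aligned but whose end falls
 * inside the leaf clears the whole 2M entry and reports 2M as unmapped.
 */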
1063 
1064 static size_t NS(unmap_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
1065 			      dma_addr_t len,
1066 			      struct iommu_iotlb_gather *iotlb_gather)
1067 {
1068 	struct pt_unmap_args unmap = { .free_list = IOMMU_PAGES_LIST_INIT(
1069 					       unmap.free_list) };
1070 	struct pt_range range;
1071 	int ret;
1072 
1073 	ret = make_range(common_from_iommu(iommu_table), &range, iova, len);
1074 	if (ret)
1075 		return 0;
1076 
1077 	pt_walk_range(&range, __unmap_range, &unmap);
1078 
1079 	gather_range_pages(iotlb_gather, iommu_table, iova, unmap.unmapped,
1080 			   &unmap.free_list);
1081 
1082 	return unmap.unmapped;
1083 }
1084 
1085 static void NS(get_info)(struct pt_iommu *iommu_table,
1086 			 struct pt_iommu_info *info)
1087 {
1088 	struct pt_common *common = common_from_iommu(iommu_table);
1089 	struct pt_range range = pt_top_range(common);
1090 	struct pt_state pts = pt_init_top(&range);
1091 	pt_vaddr_t pgsize_bitmap = 0;
1092 
1093 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP)) {
1094 		for (pts.level = 0; pts.level <= PT_MAX_TOP_LEVEL;
1095 		     pts.level++) {
1096 			if (pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2)
1097 				break;
1098 			pgsize_bitmap |= pt_possible_sizes(&pts);
1099 		}
1100 	} else {
1101 		for (pts.level = 0; pts.level <= range.top_level; pts.level++)
1102 			pgsize_bitmap |= pt_possible_sizes(&pts);
1103 	}
1104 
1105 	/* Hide page sizes larger than the maximum OA */
1106 	info->pgsize_bitmap = oalog2_mod(pgsize_bitmap, common->max_oasz_lg2);
1107 }
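/*
 * Worked example (purely illustrative, the real values come from the format):
 * if level 0 reports 4K, level 1 reports 2M and level 2 reports 1G, then
 * info->pgsize_bitmap becomes SZ_4K | SZ_2M | SZ_1G, minus any sizes that
 * exceed max_oasz_lg2.
 */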
1108 
1109 static void NS(deinit)(struct pt_iommu *iommu_table)
1110 {
1111 	struct pt_common *common = common_from_iommu(iommu_table);
1112 	struct pt_range range = pt_all_range(common);
1113 	struct pt_iommu_collect_args collect = {
1114 		.free_list = IOMMU_PAGES_LIST_INIT(collect.free_list),
1115 	};
1116 
1117 	iommu_pages_list_add(&collect.free_list, range.top_table);
1118 	pt_walk_range(&range, __collect_tables, &collect);
1119 
1120 	/*
1121 	 * The driver has to already have fenced the HW access to the page table
1122 	 * and invalidated any caching referring to this memory.
1123 	 */
1124 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
1125 		iommu_pages_stop_incoherent_list(&collect.free_list,
1126 						 iommu_table->iommu_device);
1127 	iommu_put_pages_list(&collect.free_list);
1128 }
1129 
1130 static const struct pt_iommu_ops NS(ops) = {
1131 	.map_range = NS(map_range),
1132 	.unmap_range = NS(unmap_range),
1133 #if IS_ENABLED(CONFIG_IOMMUFD_DRIVER) && defined(pt_entry_is_write_dirty) && \
1134 	IS_ENABLED(CONFIG_IOMMUFD_TEST) && defined(pt_entry_make_write_dirty)
1135 	.set_dirty = NS(set_dirty),
1136 #endif
1137 	.get_info = NS(get_info),
1138 	.deinit = NS(deinit),
1139 };
1140 
1141 static int pt_init_common(struct pt_common *common)
1142 {
1143 	struct pt_range top_range = pt_top_range(common);
1144 
1145 	if (PT_WARN_ON(top_range.top_level > PT_MAX_TOP_LEVEL))
1146 		return -EINVAL;
1147 
1148 	if (top_range.top_level == PT_MAX_TOP_LEVEL ||
1149 	    common->max_vasz_lg2 == top_range.max_vasz_lg2)
1150 		common->features &= ~BIT(PT_FEAT_DYNAMIC_TOP);
1151 
1152 	if (top_range.max_vasz_lg2 == PT_VADDR_MAX_LG2)
1153 		common->features |= BIT(PT_FEAT_FULL_VA);
1154 
1155 	/* Requested features must match features compiled into this format */
1156 	if ((common->features & ~(unsigned int)PT_SUPPORTED_FEATURES) ||
1157 	    (!IS_ENABLED(CONFIG_DEBUG_GENERIC_PT) &&
1158 	     (common->features & PT_FORCE_ENABLED_FEATURES) !=
1159 		     PT_FORCE_ENABLED_FEATURES))
1160 		return -EOPNOTSUPP;
1161 
1162 	/*
1163 	 * Check if the top level of the page table is too small to hold the
1164 	 * specified max_vasz_lg2.
1165 	 */
1166 	if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
1167 	    top_range.top_level != PT_MAX_TOP_LEVEL) {
1168 		struct pt_state pts = { .range = &top_range,
1169 					.level = top_range.top_level };
1170 
1171 		if (common->max_vasz_lg2 >
1172 		    pt_num_items_lg2(&pts) + pt_table_item_lg2sz(&pts))
1173 			return -EOPNOTSUPP;
1174 	}
1175 
1176 	if (common->max_oasz_lg2 == 0)
1177 		common->max_oasz_lg2 = pt_max_oa_lg2(common);
1178 	else
1179 		common->max_oasz_lg2 = min(common->max_oasz_lg2,
1180 					   pt_max_oa_lg2(common));
1181 	return 0;
1182 }
1183 
1184 static int pt_iommu_init_domain(struct pt_iommu *iommu_table,
1185 				struct iommu_domain *domain)
1186 {
1187 	struct pt_common *common = common_from_iommu(iommu_table);
1188 	struct pt_iommu_info info;
1189 	struct pt_range range;
1190 
1191 	NS(get_info)(iommu_table, &info);
1192 
1193 	domain->type = __IOMMU_DOMAIN_PAGING;
1194 	domain->pgsize_bitmap = info.pgsize_bitmap;
1195 	domain->is_iommupt = true;
1196 
1197 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP))
1198 		range = _pt_top_range(common,
1199 				      _pt_top_set(NULL, PT_MAX_TOP_LEVEL));
1200 	else
1201 		range = pt_top_range(common);
1202 
1203 	/* A 64-bit high address space table on a 32-bit system cannot work. */
1204 	domain->geometry.aperture_start = (unsigned long)range.va;
1205 	if ((pt_vaddr_t)domain->geometry.aperture_start != range.va)
1206 		return -EOVERFLOW;
1207 
1208 	/*
1209 	 * The aperture is limited to what the API can do after considering all
1210 	 * the different types dma_addr_t/unsigned long/pt_vaddr_t that are used
1211 	 * to store a VA. Set the aperture to something that is valid for all
1212 	 * cases. Saturate instead of truncate the end if the types are smaller
1213 	 * than the top range. aperture_end should be called aperture_last.
1214 	 */
1215 	domain->geometry.aperture_end = (unsigned long)range.last_va;
1216 	if ((pt_vaddr_t)domain->geometry.aperture_end != range.last_va) {
1217 		domain->geometry.aperture_end = ULONG_MAX;
1218 		domain->pgsize_bitmap &= ULONG_MAX;
1219 	}
1220 	domain->geometry.force_aperture = true;
1221 
1222 	return 0;
1223 }
1224 
1225 static void pt_iommu_zero(struct pt_iommu_table *fmt_table)
1226 {
1227 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1228 	struct pt_iommu cfg = *iommu_table;
1229 
1230 	static_assert(offsetof(struct pt_iommu_table, iommu.domain) == 0);
1231 	memset_after(fmt_table, 0, iommu.domain);
1232 
1233 	/* The caller can initialize some of these values */
1234 	iommu_table->iommu_device = cfg.iommu_device;
1235 	iommu_table->driver_ops = cfg.driver_ops;
1236 	iommu_table->nid = cfg.nid;
1237 }
1238 
1239 #define pt_iommu_table_cfg CONCATENATE(pt_iommu_table, _cfg)
1240 #define pt_iommu_init CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), init)
1241 
1242 int pt_iommu_init(struct pt_iommu_table *fmt_table,
1243 		  const struct pt_iommu_table_cfg *cfg, gfp_t gfp)
1244 {
1245 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1246 	struct pt_common *common = common_from_iommu(iommu_table);
1247 	struct pt_table_p *table_mem;
1248 	int ret;
1249 
1250 	if (cfg->common.hw_max_vasz_lg2 > PT_MAX_VA_ADDRESS_LG2 ||
1251 	    !cfg->common.hw_max_vasz_lg2 || !cfg->common.hw_max_oasz_lg2)
1252 		return -EINVAL;
1253 
1254 	pt_iommu_zero(fmt_table);
1255 	common->features = cfg->common.features;
1256 	common->max_vasz_lg2 = cfg->common.hw_max_vasz_lg2;
1257 	common->max_oasz_lg2 = cfg->common.hw_max_oasz_lg2;
1258 	ret = pt_iommu_fmt_init(fmt_table, cfg);
1259 	if (ret)
1260 		return ret;
1261 
1262 	if (cfg->common.hw_max_oasz_lg2 > pt_max_oa_lg2(common))
1263 		return -EINVAL;
1264 
1265 	ret = pt_init_common(common);
1266 	if (ret)
1267 		return ret;
1268 
1269 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
1270 	    WARN_ON(!iommu_table->driver_ops ||
1271 		    !iommu_table->driver_ops->change_top ||
1272 		    !iommu_table->driver_ops->get_top_lock))
1273 		return -EINVAL;
1274 
1275 	if (pt_feature(common, PT_FEAT_SIGN_EXTEND) &&
1276 	    (pt_feature(common, PT_FEAT_FULL_VA) ||
1277 	     pt_feature(common, PT_FEAT_DYNAMIC_TOP)))
1278 		return -EINVAL;
1279 
1280 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) &&
1281 	    WARN_ON(!iommu_table->iommu_device))
1282 		return -EINVAL;
1283 
1284 	ret = pt_iommu_init_domain(iommu_table, &iommu_table->domain);
1285 	if (ret)
1286 		return ret;
1287 
1288 	table_mem = table_alloc_top(common, common->top_of_table, gfp,
1289 				    ALLOC_NORMAL);
1290 	if (IS_ERR(table_mem))
1291 		return PTR_ERR(table_mem);
1292 	pt_top_set(common, table_mem, pt_top_get_level(common));
1293 
1294 	/* Must be last, see pt_iommu_deinit() */
1295 	iommu_table->ops = &NS(ops);
1296 	return 0;
1297 }
1298 EXPORT_SYMBOL_NS_GPL(pt_iommu_init, "GENERIC_PT_IOMMU");
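/*
 * Hedged initialization sketch (the amdv1-prefixed names are hypothetical, the
 * cfg fields are the ones consumed above):
 *
 *	struct pt_iommu_amdv1_cfg cfg = {
 *		.common.hw_max_vasz_lg2 = 48,
 *		.common.hw_max_oasz_lg2 = 52,
 *		.common.features = BIT(PT_FEAT_DYNAMIC_TOP),
 *	};
 *	int rc = pt_iommu_amdv1_init(&table->amdv1, &cfg, GFP_KERNEL);
 *
 * On success the caller registers iommu_table->domain with the core and uses
 * the installed ops for map/unmap/deinit.
 */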
1299 
1300 #ifdef pt_iommu_fmt_hw_info
1301 #define pt_iommu_table_hw_info CONCATENATE(pt_iommu_table, _hw_info)
1302 #define pt_iommu_hw_info CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), hw_info)
1303 void pt_iommu_hw_info(struct pt_iommu_table *fmt_table,
1304 		      struct pt_iommu_table_hw_info *info)
1305 {
1306 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1307 	struct pt_common *common = common_from_iommu(iommu_table);
1308 	struct pt_range top_range = pt_top_range(common);
1309 
1310 	pt_iommu_fmt_hw_info(fmt_table, &top_range, info);
1311 }
1312 EXPORT_SYMBOL_NS_GPL(pt_iommu_hw_info, "GENERIC_PT_IOMMU");
1313 #endif
1314 
1315 MODULE_LICENSE("GPL");
1316 MODULE_DESCRIPTION("IOMMU Page table implementation for " __stringify(PTPFX_RAW));
1317 MODULE_IMPORT_NS("GENERIC_PT");
1318 /* For iommu_dirty_bitmap_record() */
1319 MODULE_IMPORT_NS("IOMMUFD");
1320 
1321 #endif  /* __GENERIC_PT_IOMMU_PT_H */
1322