xref: /linux/drivers/iommu/generic_pt/iommu_pt.h (revision dd8a3c6cd531dca5917111a94fa3074077f6ba5a)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  *
5  * "Templated C code" for implementing the iommu operations for page tables.
6  * This is compiled multiple times, over all the page table formats to pick up
7  * the per-format definitions.
8  */
9 #ifndef __GENERIC_PT_IOMMU_PT_H
10 #define __GENERIC_PT_IOMMU_PT_H
11 
12 #include "pt_iter.h"
13 
14 #include <linux/export.h>
15 #include <linux/iommu.h>
16 #include "../iommu-pages.h"
17 #include <linux/cleanup.h>
18 #include <linux/dma-mapping.h>
19 
/*
 * Software bits kept in table entries. SW_BIT_CACHE_FLUSH_DONE is set with
 * release semantics by pt_iommu_new_table() after the new table item was
 * flushed for PT_FEAT_DMA_INCOHERENT. It is tested with acquire semantics by
 * __map_range(), which flushes the item itself if the bit is not yet set.
 */
enum { SW_BIT_CACHE_FLUSH_DONE = 0 };
23 
24 static void flush_writes_range(const struct pt_state *pts,
25 			       unsigned int start_index, unsigned int end_index)
26 {
27 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT))
28 		iommu_pages_flush_incoherent(
29 			iommu_from_common(pts->range->common)->iommu_device,
30 			pts->table, start_index * PT_ITEM_WORD_SIZE,
31 			(end_index - start_index) * PT_ITEM_WORD_SIZE);
32 }
33 
34 static void flush_writes_item(const struct pt_state *pts)
35 {
36 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT))
37 		iommu_pages_flush_incoherent(
38 			iommu_from_common(pts->range->common)->iommu_device,
39 			pts->table, pts->index * PT_ITEM_WORD_SIZE,
40 			PT_ITEM_WORD_SIZE);
41 }
42 
/*
 * Local accumulator for one step of a walk. gather_add_table() and
 * gather_add_leaf() fill it, and gather_range_pending() merges it into
 * iotlb_gather and resets it.
 */
struct iommupt_pending_gather {
	/* Destination gather that receives the pages and bitmaps below */
	struct iommu_iotlb_gather *iotlb_gather;
	/* Table pages to free once the IOTLB sync has run */
	struct iommu_pages_list free_list;
	/* BIT(level) of each level a leaf was gathered at (DETAILED_GATHER) */
	u8 leaf_levels_bitmap;
	/* BIT(level) of each level a table was gathered at (DETAILED_GATHER) */
	u8 table_levels_bitmap;
};
49 
50 static void gather_add_table(struct iommupt_pending_gather *pending,
51 			     const struct pt_state *pts,
52 			     struct pt_table_p *table)
53 {
54 	iommu_pages_list_add(&pending->free_list, table);
55 	if (pts_feature(pts, PT_FEAT_DETAILED_GATHER))
56 		pending->table_levels_bitmap |= BIT(pts->level);
57 }
58 
59 static void gather_add_leaf(struct iommupt_pending_gather *pending,
60 			    const struct pt_state *pts)
61 {
62 	if (!pts_feature(pts, PT_FEAT_DETAILED_GATHER))
63 		return;
64 
65 	pending->leaf_levels_bitmap |= BIT(pts->level);
66 }
67 
68 static void gather_range_pending(struct iommupt_pending_gather *pending,
69 				 struct pt_iommu *iommu_table, pt_vaddr_t iova,
70 				 pt_vaddr_t len)
71 {
72 	struct iommu_iotlb_gather *iotlb_gather = pending->iotlb_gather;
73 	struct pt_common *common = common_from_iommu(iommu_table);
74 
75 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
76 		iommu_pages_stop_incoherent_list(&pending->free_list,
77 						 iommu_table->iommu_device);
78 
79 	/*
80 	 * If running in DMA-FQ mode then the unmap will be followed by an IOTLB
81 	 * flush all so we need to optimize by never flushing the IOTLB here.
82 	 *
83 	 * For NO_GAPS the user gets to pick if flushing all or doing micro
84 	 * flushes is better for their work load by choosing DMA vs DMA-FQ
85 	 * operation. Drivers should also see shadow_on_flush.
86 	 */
87 	if (!iommu_iotlb_gather_queued(iotlb_gather)) {
88 		if (pt_feature(common, PT_FEAT_FLUSH_RANGE_NO_GAPS) &&
89 		    iommu_iotlb_gather_is_disjoint(iotlb_gather, iova, len)) {
90 			iommu_iotlb_sync(&iommu_table->domain, iotlb_gather);
91 			/*
92 			 * Note that the sync frees the gather's free list, so
93 			 * we must not have any pages on that list that are
94 			 * covered by iova/len
95 			 */
96 		}
97 		iommu_iotlb_gather_add_range(iotlb_gather, iova, len);
98 	}
99 
100 	iommu_pages_list_splice(&pending->free_list, &iotlb_gather->freelist);
101 	INIT_LIST_HEAD(&pending->free_list.pages);
102 
103 	if (pt_feature(common, PT_FEAT_DETAILED_GATHER)) {
104 		iotlb_gather->pt.leaf_levels_bitmap |=
105 			pending->leaf_levels_bitmap;
106 		iotlb_gather->pt.table_levels_bitmap |=
107 			pending->table_levels_bitmap;
108 		pending->leaf_levels_bitmap = 0;
109 		pending->table_levels_bitmap = 0;
110 	}
111 }
112 
/* Name of an exported per-format symbol: pt_iommu_ ## PTPFX ## op */
#define DOMAIN_NS(op) CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), op)
114 
115 static int make_range_ul(struct pt_common *common, struct pt_range *range,
116 			 unsigned long iova, unsigned long len)
117 {
118 	unsigned long last;
119 
120 	if (unlikely(len == 0))
121 		return -EINVAL;
122 
123 	if (check_add_overflow(iova, len - 1, &last))
124 		return -EOVERFLOW;
125 
126 	*range = pt_make_range(common, iova, last);
127 	if (sizeof(iova) > sizeof(range->va)) {
128 		if (unlikely(range->va != iova || range->last_va != last))
129 			return -EOVERFLOW;
130 	}
131 	return 0;
132 }
133 
134 static __maybe_unused int make_range_u64(struct pt_common *common,
135 					 struct pt_range *range, u64 iova,
136 					 u64 len)
137 {
138 	if (unlikely(iova > ULONG_MAX || len > ULONG_MAX))
139 		return -EOVERFLOW;
140 	return make_range_ul(common, range, iova, len);
141 }
142 
/*
 * Some APIs use unsigned long, while others use dma_addr_t as the type.
 * Dispatch to the correct validation based on the type. Anything wider than
 * unsigned long goes through make_range_u64(), which rejects values that do
 * not fit before calling make_range_ul().
 *
 * The result variable has a macro-private name. A plain 'ret' would shadow
 * the 'ret' that every caller declares, and would capture it if it appeared
 * in an argument expression.
 */
#define make_range_no_check(common, range, iova, len)                          \
	({                                                                     \
		int __mrnc_ret;                                                \
		if (sizeof(iova) > sizeof(unsigned long) ||                    \
		    sizeof(len) > sizeof(unsigned long))                       \
			__mrnc_ret = make_range_u64(common, range, iova, len); \
		else                                                           \
			__mrnc_ret = make_range_ul(common, range, iova, len);  \
		__mrnc_ret;                                                    \
	})
157 
/*
 * make_range_no_check() followed, when that succeeded, by pt_check_range()
 * on the resulting range. Evaluates to 0 or a negative errno.
 */
#define make_range(common, range, iova, len)                                  \
	({                                                                    \
		int __mr_ret = make_range_no_check(common, range, iova, len); \
		if (__mr_ret == 0)                                            \
			__mr_ret = pt_check_range(range);                     \
		__mr_ret;                                                     \
	})
165 
166 static inline unsigned int compute_best_pgsize(struct pt_state *pts,
167 					       pt_oaddr_t oa)
168 {
169 	struct pt_iommu *iommu_table = iommu_from_common(pts->range->common);
170 
171 	if (!pt_can_have_leaf(pts))
172 		return 0;
173 
174 	/*
175 	 * The page size is limited by the domain's bitmap. This allows the core
176 	 * code to reduce the supported page sizes by changing the bitmap.
177 	 */
178 	return pt_compute_best_pgsize(pt_possible_sizes(pts) &
179 					      iommu_table->domain.pgsize_bitmap,
180 				      pts->range->va, pts->range->last_va, oa);
181 }
182 
183 static __always_inline int __do_iova_to_phys(struct pt_range *range, void *arg,
184 					     unsigned int level,
185 					     struct pt_table_p *table,
186 					     pt_level_fn_t descend_fn)
187 {
188 	struct pt_state pts = pt_init(range, level, table);
189 	pt_oaddr_t *res = arg;
190 
191 	switch (pt_load_single_entry(&pts)) {
192 	case PT_ENTRY_EMPTY:
193 		return -ENOENT;
194 	case PT_ENTRY_TABLE:
195 		return pt_descend(&pts, arg, descend_fn);
196 	case PT_ENTRY_OA:
197 		*res = pt_entry_oa_exact(&pts);
198 		return 0;
199 	}
200 	return -ENOENT;
201 }
202 PT_MAKE_LEVELS(__iova_to_phys, __do_iova_to_phys);
203 
204 /**
205  * iova_to_phys() - Return the output address for the given IOVA
206  * @domain: Table to query
207  * @iova: IO virtual address to query
208  *
209  * Determine the output address from the given IOVA. @iova may have any
210  * alignment, the returned physical will be adjusted with any sub page offset.
211  *
212  * Context: The caller must hold a read range lock that includes @iova.
213  *
214  * Return: 0 if there is no translation for the given iova.
215  */
216 phys_addr_t DOMAIN_NS(iova_to_phys)(struct iommu_domain *domain,
217 				    dma_addr_t iova)
218 {
219 	struct pt_iommu *iommu_table =
220 		container_of(domain, struct pt_iommu, domain);
221 	struct pt_range range;
222 	pt_oaddr_t res;
223 	int ret;
224 
225 	ret = make_range(common_from_iommu(iommu_table), &range, iova, 1);
226 	if (ret)
227 		return ret;
228 
229 	ret = pt_walk_range(&range, __iova_to_phys, &res);
230 	/* PHYS_ADDR_MAX would be a better error code */
231 	if (ret)
232 		return 0;
233 	return res;
234 }
235 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(iova_to_phys), "GENERIC_PT_IOMMU");
236 
/* Arguments passed through pt_walk_range() by read_and_clear_dirty() */
struct pt_iommu_dirty_args {
	/* Caller's dirty bitmap (->bitmap may be NULL) and its IOTLB gather */
	struct iommu_dirty_bitmap *dirty;
	/* IOMMU_DIRTY_* flags, e.g. IOMMU_DIRTY_NO_CLEAR */
	unsigned int flags;
};
241 
/*
 * Record the write-dirty leaf at pts->index in the caller's dirty bitmap.
 * Unless IOMMU_DIRTY_NO_CLEAR is set, also make the entry write-clean and
 * add its IOVA span to the dirty bitmap's gather.
 *
 * @num_contig_lg2 is the log2 number of table items the leaf occupies
 * (ilog2(1) == 0 for a single-item leaf). For a contiguous leaf the recorded
 * length is clamped so it does not extend past pts->end_index.
 */
static void record_dirty(struct pt_state *pts,
			 struct pt_iommu_dirty_args *dirty,
			 unsigned int num_contig_lg2)
{
	pt_vaddr_t dirty_len;

	if (num_contig_lg2 != ilog2(1)) {
		unsigned int index = pts->index;
		/*
		 * NOTE(review): if log2_set_mod_max_t() yields the index of
		 * the last item of the contiguous block (low bits all set,
		 * i.e. an inclusive end), then end_index - index below is one
		 * item short of the block. Confirm against the log2 helpers.
		 */
		unsigned int end_index = log2_set_mod_max_t(
			unsigned int, pts->index, num_contig_lg2);

		/* Adjust for being contained inside a contiguous page */
		end_index = min(end_index, pts->end_index);
		dirty_len = (end_index - index) *
				log2_to_int(pt_table_item_lg2sz(pts));
	} else {
		dirty_len = log2_to_int(pt_table_item_lg2sz(pts));
	}

	/* The bitmap is optional; clearing below still happens without it */
	if (dirty->dirty->bitmap)
		iova_bitmap_set(dirty->dirty->bitmap, pts->range->va,
				dirty_len);

	if (!(dirty->flags & IOMMU_DIRTY_NO_CLEAR)) {
		/*
		 * No write log required because DMA incoherence and atomic
		 * dirty tracking bits can't work together
		 */
		pt_entry_make_write_clean(pts);
		iommu_iotlb_gather_add_range(dirty->dirty->gather,
					     pts->range->va, dirty_len);
	}
}
275 
276 static inline int __read_and_clear_dirty(struct pt_range *range, void *arg,
277 					 unsigned int level,
278 					 struct pt_table_p *table)
279 {
280 	struct pt_state pts = pt_init(range, level, table);
281 	struct pt_iommu_dirty_args *dirty = arg;
282 	int ret;
283 
284 	for_each_pt_level_entry(&pts) {
285 		if (pts.type == PT_ENTRY_TABLE) {
286 			ret = pt_descend(&pts, arg, __read_and_clear_dirty);
287 			if (ret)
288 				return ret;
289 			continue;
290 		}
291 		if (pts.type == PT_ENTRY_OA && pt_entry_is_write_dirty(&pts))
292 			record_dirty(&pts, dirty,
293 				     pt_entry_num_contig_lg2(&pts));
294 	}
295 	return 0;
296 }
297 
298 /**
299  * read_and_clear_dirty() - Manipulate the HW set write dirty state
300  * @domain: Domain to manipulate
301  * @iova: IO virtual address to start
302  * @size: Length of the IOVA
303  * @flags: A bitmap of IOMMU_DIRTY_NO_CLEAR
304  * @dirty: Place to store the dirty bits
305  *
306  * Iterate over all the entries in the mapped range and record their write dirty
307  * status in iommu_dirty_bitmap. If IOMMU_DIRTY_NO_CLEAR is not specified then
308  * the entries will be left dirty, otherwise they are returned to being not
309  * write dirty.
310  *
311  * Context: The caller must hold a read range lock that includes @iova.
312  *
313  * Returns: -ERRNO on failure, 0 on success.
314  */
315 int DOMAIN_NS(read_and_clear_dirty)(struct iommu_domain *domain,
316 				    unsigned long iova, size_t size,
317 				    unsigned long flags,
318 				    struct iommu_dirty_bitmap *dirty)
319 {
320 	struct pt_iommu *iommu_table =
321 		container_of(domain, struct pt_iommu, domain);
322 	struct pt_iommu_dirty_args dirty_args = {
323 		.dirty = dirty,
324 		.flags = flags,
325 	};
326 	struct pt_range range;
327 	int ret;
328 
329 #if !IS_ENABLED(CONFIG_IOMMUFD_DRIVER) || !defined(pt_entry_is_write_dirty)
330 	return -EOPNOTSUPP;
331 #endif
332 
333 	ret = make_range(common_from_iommu(iommu_table), &range, iova, size);
334 	if (ret)
335 		return ret;
336 
337 	ret = pt_walk_range(&range, __read_and_clear_dirty, &dirty_args);
338 	PT_WARN_ON(ret);
339 	return ret;
340 }
341 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(read_and_clear_dirty), "GENERIC_PT_IOMMU");
342 
343 static inline int __set_dirty(struct pt_range *range, void *arg,
344 			      unsigned int level, struct pt_table_p *table)
345 {
346 	struct pt_state pts = pt_init(range, level, table);
347 
348 	switch (pt_load_single_entry(&pts)) {
349 	case PT_ENTRY_EMPTY:
350 		return -ENOENT;
351 	case PT_ENTRY_TABLE:
352 		return pt_descend(&pts, arg, __set_dirty);
353 	case PT_ENTRY_OA:
354 		if (!pt_entry_make_write_dirty(&pts))
355 			return -EAGAIN;
356 		return 0;
357 	}
358 	return -ENOENT;
359 }
360 
361 static int __maybe_unused NS(set_dirty)(struct pt_iommu *iommu_table,
362 					dma_addr_t iova)
363 {
364 	struct pt_range range;
365 	int ret;
366 
367 	ret = make_range(common_from_iommu(iommu_table), &range, iova, 1);
368 	if (ret)
369 		return ret;
370 
371 	/*
372 	 * Note: There is no locking here yet, if the test suite races this it
373 	 * can crash. It should use RCU locking eventually.
374 	 */
375 	return pt_walk_range(&range, __set_dirty, NULL);
376 }
377 
/* Arguments for the __collect_tables() walker */
struct pt_iommu_collect_args {
	/* Every table found during the walk is added here */
	struct iommupt_pending_gather pending;
	/* Fail if any OAs are within the range */
	u8 check_mapped : 1;
};
383 
384 static int __collect_tables(struct pt_range *range, void *arg,
385 			    unsigned int level, struct pt_table_p *table)
386 {
387 	struct pt_state pts = pt_init(range, level, table);
388 	struct pt_iommu_collect_args *collect = arg;
389 	int ret;
390 
391 	if (!collect->check_mapped && !pt_can_have_table(&pts))
392 		return 0;
393 
394 	for_each_pt_level_entry(&pts) {
395 		if (pts.type == PT_ENTRY_TABLE) {
396 			gather_add_table(&collect->pending, &pts,
397 					 pts.table_lower);
398 			ret = pt_descend(&pts, arg, __collect_tables);
399 			if (ret)
400 				return ret;
401 			continue;
402 		}
403 		if (pts.type == PT_ENTRY_OA && collect->check_mapped)
404 			return -EADDRINUSE;
405 	}
406 	return 0;
407 }
408 
/* How _table_alloc() handles PT_FEAT_DMA_INCOHERENT tables */
enum alloc_mode {
	/* Call iommu_pages_start_incoherent() on the new table right away */
	ALLOC_NORMAL,
	/*
	 * Skip that step. The caller runs iommu_pages_start_incoherent_list()
	 * itself once the tables are filled in, as increase_top() does.
	 */
	ALLOC_DEFER_COHERENT_FLUSH,
};
410 
411 /* Allocate a table, the empty table will be ready to be installed. */
412 static inline struct pt_table_p *_table_alloc(struct pt_common *common,
413 					      size_t lg2sz, gfp_t gfp,
414 					      enum alloc_mode mode)
415 {
416 	struct pt_iommu *iommu_table = iommu_from_common(common);
417 	struct pt_table_p *table_mem;
418 
419 	table_mem = iommu_alloc_pages_node_sz(iommu_table->nid, gfp,
420 					      log2_to_int(lg2sz));
421 	if (!table_mem)
422 		return ERR_PTR(-ENOMEM);
423 
424 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) &&
425 	    mode == ALLOC_NORMAL) {
426 		int ret = iommu_pages_start_incoherent(
427 			table_mem, iommu_table->iommu_device);
428 		if (ret) {
429 			iommu_free_pages(table_mem);
430 			return ERR_PTR(ret);
431 		}
432 	}
433 	return table_mem;
434 }
435 
436 static inline struct pt_table_p *table_alloc_top(struct pt_common *common,
437 						 uintptr_t top_of_table,
438 						 gfp_t gfp,
439 						 enum alloc_mode mode)
440 {
441 	/*
442 	 * Top doesn't need the free list or otherwise, so it technically
443 	 * doesn't need to use iommu pages. Use the API anyhow as the top is
444 	 * usually not smaller than PAGE_SIZE to keep things simple.
445 	 */
446 	return _table_alloc(common, pt_top_memsize_lg2(common, top_of_table),
447 			    gfp, mode);
448 }
449 
450 /* Allocate an interior table */
451 static inline struct pt_table_p *table_alloc(const struct pt_state *parent_pts,
452 					     gfp_t gfp, enum alloc_mode mode)
453 {
454 	struct pt_state child_pts =
455 		pt_init(parent_pts->range, parent_pts->level - 1, NULL);
456 
457 	return _table_alloc(parent_pts->range->common,
458 			    pt_num_items_lg2(&child_pts) +
459 				    ilog2(PT_ITEM_WORD_SIZE),
460 			    gfp, mode);
461 }
462 
/*
 * Allocate an empty lower-level table and install it in the item at @pts.
 * On success pts->table_lower points at the new table.
 *
 * Return: 0 on success.
 * -ENXIO if this level cannot hold a table.
 * The table_alloc() error (e.g. -ENOMEM) if allocation fails.
 * -EAGAIN if pt_install_table() did not install the table (the "race" case
 * that callers of __map_range() loop on).
 * -EINVAL, with CONFIG_DEBUG_GENERIC_PT only, if the entry cannot store the
 * table's physical address.
 */
static inline int pt_iommu_new_table(struct pt_state *pts,
				     struct pt_write_attrs *attrs)
{
	struct pt_table_p *table_mem;
	phys_addr_t phys;

	/* Given PA/VA/length can't be represented */
	if (PT_WARN_ON(!pt_can_have_table(pts)))
		return -ENXIO;

	table_mem = table_alloc(pts, attrs->gfp, ALLOC_NORMAL);
	if (IS_ERR(table_mem))
		return PTR_ERR(table_mem);

	phys = virt_to_phys(table_mem);
	/* Not installed (presumably lost a race for the item): discard it */
	if (!pt_install_table(pts, phys, attrs)) {
		iommu_pages_free_incoherent(
			table_mem,
			iommu_from_common(pts->range->common)->iommu_device);
		return -EAGAIN;
	}

	/*
	 * Flush the new item, then publish SW_BIT_CACHE_FLUSH_DONE with
	 * release semantics. A concurrent __map_range() that finds this table
	 * before the bit is set flushes the item itself.
	 */
	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT)) {
		flush_writes_item(pts);
		pt_set_sw_bit_release(pts, SW_BIT_CACHE_FLUSH_DONE);
	}

	if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) {
		/*
		 * The underlying table can't store the physical table address.
		 * This happens when kunit testing tables outside their normal
		 * environment where a CPU might be limited.
		 */
		pt_load_single_entry(pts);
		if (PT_WARN_ON(pt_table_pa(pts) != phys)) {
			pt_clear_entries(pts, ilog2(1));
			iommu_pages_free_incoherent(
				table_mem, iommu_from_common(pts->range->common)
						   ->iommu_device);
			return -EINVAL;
		}
	}

	pts->table_lower = table_mem;
	return 0;
}
509 
/* State shared by the map walkers (__map_range*, __map_single_page) */
struct pt_iommu_map_args {
	/* Receives tables freed by clear_contig() while mapping */
	struct iommu_iotlb_gather *iotlb_gather;
	/* Entry attributes from pt_iommu_set_prot() plus the allocation gfp */
	struct pt_write_attrs attrs;
	/* Next output address to map, advanced as leaves are installed */
	pt_oaddr_t oa;
	/* log2 size of the leaves currently being installed */
	unsigned int leaf_pgsize_lg2;
	/* Table level those leaves are installed at */
	unsigned int leaf_level;
	/* Leaves of leaf_pgsize_lg2 still to be installed */
	pt_vaddr_t num_leaves;
};
518 
/*
 * This will recursively check any tables in the block to validate they are
 * empty and then free them through the gather.
 *
 * The block is the @step items of start_pts's table starting at
 * start_pts->index. The walk uses a private copy of the range, so @start_pts
 * is not modified. Each table item is walked by __collect_tables() with
 * check_mapped set, so any OA below it fails with -EADDRINUSE. The item is
 * then cleared, and its memory plus the collected lower tables are handed to
 * @iotlb_gather via gather_range_pending(). Any other non-empty item also
 * fails with -EADDRINUSE.
 *
 * NOTE(review): @pgsize_lg2 is not used by this function.
 *
 * Return: 0 if the whole block is now empty, otherwise a negative errno.
 */
static int clear_contig(const struct pt_state *start_pts,
			struct iommu_iotlb_gather *iotlb_gather,
			unsigned int step, unsigned int pgsize_lg2)
{
	struct pt_iommu *iommu_table =
		iommu_from_common(start_pts->range->common);
	struct pt_range range = *start_pts->range;
	struct pt_state pts =
		pt_init(&range, start_pts->level, start_pts->table);
	struct pt_iommu_collect_args collect = {
		.check_mapped = true,
		.pending.iotlb_gather = iotlb_gather,
		.pending.free_list = IOMMU_PAGES_LIST_INIT(
			collect.pending.free_list),
	};
	int ret;

	/* Restrict the iteration to the @step items of this block */
	pts.index = start_pts->index;
	pts.end_index = start_pts->index + step;
	for (; _pt_iter_load(&pts); pt_next_entry(&pts)) {
		if (pts.type == PT_ENTRY_TABLE) {
			ret = pt_walk_descend_all(&pts, __collect_tables,
						  &collect);
			if (ret)
				return ret;

			/*
			 * The table item must be cleared before we can update
			 * the gather
			 */
			pt_clear_entries(&pts, ilog2(1));
			flush_writes_item(&pts);

			gather_add_table(&collect.pending, &pts,
					 pts.table_lower);
			gather_range_pending(
				&collect.pending, iommu_table, range.va,
				log2_to_int(pt_table_item_lg2sz(&pts)));
		} else if (pts.type != PT_ENTRY_EMPTY) {
			return -EADDRINUSE;
		}
	}
	return 0;
}
567 
/*
 * Install leaves of size 2^map->leaf_pgsize_lg2 into the table at @level,
 * starting at range->va and consuming map->oa.
 *
 * Each leaf spans 'step' table items (more than one for a contiguous leaf).
 * An item that is not empty, or any item when a contiguous leaf is needed,
 * first goes through clear_contig(). That frees empty sub-tables and fails
 * with -EADDRINUSE on existing mappings.
 *
 * On return map->oa and map->num_leaves describe the remaining work. If all
 * leaves of the current size were placed but the range continues, the next
 * leaf size, level and count are computed into @map. If that happens before
 * the end of this table's span, -EAGAIN tells the caller to walk this table
 * again with the new parameters.
 *
 * Return: 0, -EAGAIN as above, or the clear_contig() error.
 */
static int __map_range_leaf(struct pt_range *range, void *arg,
			    unsigned int level, struct pt_table_p *table)
{
	struct pt_iommu *iommu_table = iommu_from_common(range->common);
	struct pt_state pts = pt_init(range, level, table);
	struct pt_iommu_map_args *map = arg;
	unsigned int leaf_pgsize_lg2 = map->leaf_pgsize_lg2;
	unsigned int leaves_avail;
	unsigned int start_index;
	pt_oaddr_t oa = map->oa;
	pt_vaddr_t num_leaves;
	unsigned int orig_end;
	unsigned int step_lg2;
	pt_vaddr_t last_va;
	unsigned int step;
	bool need_contig;
	int ret = 0;

	PT_WARN_ON(map->leaf_level != level);
	PT_WARN_ON(!pt_can_have_leaf(&pts));

	/* Number of table items covered by one leaf */
	step_lg2 = leaf_pgsize_lg2 - pt_table_item_lg2sz(&pts);
	step = log2_to_int_t(unsigned int, step_lg2);
	need_contig = step_lg2 != 0;

	_pt_iter_first(&pts);
	start_index = pts.index;
	orig_end = pts.end_index;
	leaves_avail =
		log2_div_t(unsigned int, pts.end_index - pts.index, step_lg2);
	if (map->num_leaves <= leaves_avail) {
		/* Need to stop in the middle of the table to change sizes */
		pts.end_index = pts.index + log2_mul(map->num_leaves, step_lg2);
		num_leaves = 0;
	} else {
		/* Leaves left over for the following table(s) */
		num_leaves = map->num_leaves - leaves_avail;
	}

	PT_WARN_ON(
		log2_mod_t(unsigned int, pts.end_index - pts.index, step_lg2));
	do {
		pts.type = pt_load_entry_raw(&pts);
		if (pts.type != PT_ENTRY_EMPTY || need_contig) {
			if (pts.index != start_index)
				pt_index_to_va(&pts);
			ret = clear_contig(&pts, map->iotlb_gather, step,
					   leaf_pgsize_lg2);
			if (ret)
				break;
		}

		if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) {
			pt_index_to_va(&pts);
			PT_WARN_ON(compute_best_pgsize(&pts, oa) !=
				   leaf_pgsize_lg2);
		}
		pt_install_leaf_entry(&pts, oa, leaf_pgsize_lg2, &map->attrs);

		oa += log2_to_int(leaf_pgsize_lg2);
		pts.index += step;
	} while (pts.index < pts.end_index);

	/* Flush what was installed, also when clear_contig() failed */
	flush_writes_range(&pts, start_index, pts.index);

	/* Report progress so the caller can continue (or count bytes mapped) */
	map->oa = oa;
	map->num_leaves = num_leaves;
	if (ret || num_leaves)
		return ret;

	/* range->va is not valid if we reached the end of the table */
	pts.index -= step;
	pt_index_to_va(&pts);
	pts.index += step;
	/* last_va is one past the end of the final leaf installed */
	last_va = range->va + log2_to_int(leaf_pgsize_lg2);

	if (last_va - 1 == range->last_va) {
		PT_WARN_ON(pts.index != orig_end);
		return 0;
	}

	/*
	 * Reached a point where the page size changed, compute the new
	 * parameters.
	 */
	map->leaf_pgsize_lg2 = pt_compute_best_pgsize(
		iommu_table->domain.pgsize_bitmap, last_va, range->last_va, oa);
	map->leaf_level =
		pt_pgsz_lg2_to_level(range->common, map->leaf_pgsize_lg2);
	map->num_leaves = pt_pgsz_count(iommu_table->domain.pgsize_bitmap,
					last_va, range->last_va, oa,
					map->leaf_pgsize_lg2);

	/* Didn't finish this table level, caller will repeat it */
	if (pts.index != orig_end) {
		if (pts.index != start_index)
			pt_index_to_va(&pts);
		return -EAGAIN;
	}
	return 0;
}
668 
/*
 * Walk an interior level towards map->leaf_level. Missing tables are created
 * with pt_iommu_new_table(). Each child is handed to __map_range_leaf() when
 * it is the leaf level, or to __map_range() otherwise.
 *
 * Return: 0 when this table's part of the range is mapped. -EAGAIN when the
 * caller must re-dispatch: map->leaf_level became this level, or
 * pt_iommu_new_table() lost a race. -EADDRINUSE when a non-table entry
 * occupies a slot that needs a table. Otherwise another error from table
 * allocation or the descent.
 */
static int __map_range(struct pt_range *range, void *arg, unsigned int level,
		       struct pt_table_p *table)
{
	struct pt_state pts = pt_init(range, level, table);
	struct pt_iommu_map_args *map = arg;
	int ret;

	PT_WARN_ON(map->leaf_level == level);
	PT_WARN_ON(!pt_can_have_table(&pts));

	_pt_iter_first(&pts);

	/* Descend to a child table */
	do {
		pts.type = pt_load_entry_raw(&pts);

		if (pts.type != PT_ENTRY_TABLE) {
			if (pts.type != PT_ENTRY_EMPTY)
				return -EADDRINUSE;
			ret = pt_iommu_new_table(&pts, &map->attrs);
			/* EAGAIN on a race will loop again */
			if (ret)
				return ret;
		} else {
			pts.table_lower = pt_table_ptr(&pts);
			/*
			 * Racing with a shared pt_iommu_new_table()? The other
			 * thread is still flushing the cache, so we have to
			 * also flush it to ensure that when our thread's map
			 * completes all the table items leading to our mapping
			 * are visible.
			 *
			 * This requires the pt_set_sw_bit_release() to be a
			 * release of the cache flush so that this can acquire
			 * visibility at the iommu.
			 */
			if (pts_feature(&pts, PT_FEAT_DMA_INCOHERENT) &&
			    !pt_test_sw_bit_acquire(&pts,
						    SW_BIT_CACHE_FLUSH_DONE))
				flush_writes_item(&pts);
		}

		/*
		 * The already present table can possibly be shared with another
		 * concurrent map.
		 */
		do {
			if (map->leaf_level == level - 1)
				ret = pt_descend(&pts, arg, __map_range_leaf);
			else
				ret = pt_descend(&pts, arg, __map_range);
		} while (ret == -EAGAIN);
		if (ret)
			return ret;

		pts.index++;
		pt_index_to_va(&pts);
		if (pts.index >= pts.end_index)
			break;

		/*
		 * This level is currently running __map_range(), which is not
		 * correct if __map_range_leaf() below has updated the target
		 * leaf level to this level. Return -EAGAIN so the caller
		 * invokes __map_range_leaf() for this level instead.
		 */
		if (map->leaf_level == level)
			return -EAGAIN;
	} while (true);
	return 0;
}
739 
740 /*
741  * Fast path for the easy case of mapping a 4k page to an already allocated
742  * table. This is a common workload. If it returns EAGAIN run the full algorithm
743  * instead.
744  */
745 static __always_inline int __do_map_single_page(struct pt_range *range,
746 						void *arg, unsigned int level,
747 						struct pt_table_p *table,
748 						pt_level_fn_t descend_fn)
749 {
750 	struct pt_state pts = pt_init(range, level, table);
751 	struct pt_iommu_map_args *map = arg;
752 
753 	pts.type = pt_load_single_entry(&pts);
754 	if (pts.level == 0) {
755 		if (pts.type != PT_ENTRY_EMPTY)
756 			return -EADDRINUSE;
757 		pt_install_leaf_entry(&pts, map->oa, PAGE_SHIFT,
758 				      &map->attrs);
759 		/* No flush, not used when incoherent */
760 		map->oa += PAGE_SIZE;
761 		return 0;
762 	}
763 	if (pts.type == PT_ENTRY_TABLE)
764 		return pt_descend(&pts, arg, descend_fn);
765 	/* Something else, use the slow path */
766 	return -EAGAIN;
767 }
768 PT_MAKE_LEVELS(__map_single_page, __do_map_single_page);
769 
/*
 * Add a table to the top, increasing the top level as much as necessary to
 * encompass range.
 *
 * Each loop pass allocates one new top table, one level higher, whose item 0
 * links to the previous top. This repeats until @range passes
 * pt_check_range() and map->leaf_level is at or below the top level. The
 * result is published under the driver's top lock with
 * driver_ops->change_top() and then WRITE_ONCE(common->top_of_table).
 *
 * Return: 0 on success.
 * -EAGAIN if common->top_of_table changed meanwhile or no new level turned
 * out to be needed; the caller reloads the top and retries.
 * -ERANGE if PT_MAX_TOP_LEVEL / max_vasz_lg2 cannot cover the range.
 * An allocation or iommu_pages_start_incoherent_list() error.
 * All tables allocated here are freed on failure.
 */
static int increase_top(struct pt_iommu *iommu_table, struct pt_range *range,
			struct pt_iommu_map_args *map)
{
	struct iommu_pages_list free_list = IOMMU_PAGES_LIST_INIT(free_list);
	struct pt_common *common = common_from_iommu(iommu_table);
	uintptr_t top_of_table = READ_ONCE(common->top_of_table);
	uintptr_t new_top_of_table = top_of_table;
	struct pt_table_p *table_mem;
	unsigned int new_level;
	spinlock_t *domain_lock;
	unsigned long flags;
	int ret;

	while (true) {
		struct pt_range top_range =
			_pt_top_range(common, new_top_of_table);
		struct pt_state pts = pt_init_top(&top_range);

		top_range.va = range->va;
		top_range.last_va = range->last_va;

		/* Done once the candidate top covers the range and leaf level */
		if (!pt_check_range(&top_range) &&
		    map->leaf_level <= pts.level) {
			new_level = pts.level;
			break;
		}

		pts.level++;
		if (pts.level > PT_MAX_TOP_LEVEL ||
		    pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2) {
			ret = -ERANGE;
			goto err_free;
		}

		table_mem =
			table_alloc_top(common, _pt_top_set(NULL, pts.level),
					map->attrs.gfp, ALLOC_DEFER_COHERENT_FLUSH);
		if (IS_ERR(table_mem)) {
			ret = PTR_ERR(table_mem);
			goto err_free;
		}
		iommu_pages_list_add(&free_list, table_mem);

		/* The new table links to the lower table always at index 0 */
		top_range.va = 0;
		top_range.top_level = pts.level;
		pts.table_lower = pts.table;
		pts.table = table_mem;
		pt_load_single_entry(&pts);
		PT_WARN_ON(pts.index != 0);
		pt_install_table(&pts, virt_to_phys(pts.table_lower),
				 &map->attrs);
		new_top_of_table = _pt_top_set(pts.table, pts.level);
	}

	/*
	 * Avoid double flushing, flush it once after all pt_install_table()
	 */
	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
		ret = iommu_pages_start_incoherent_list(
			&free_list, iommu_table->iommu_device);
		if (ret)
			goto err_free;
	}

	/*
	 * top_of_table is write locked by the spinlock, but readers can use
	 * READ_ONCE() to get the value. Since we encode both the level and the
	 * pointer in one quanta the lockless reader will always see something
	 * valid. The HW must be updated to the new level under the spinlock
	 * before top_of_table is updated so that concurrent readers don't map
	 * into the new level until it is fully functional. If another thread
	 * already updated it while we were working then throw everything away
	 * and try again.
	 */
	domain_lock = iommu_table->driver_ops->get_top_lock(iommu_table);
	spin_lock_irqsave(domain_lock, flags);
	if (common->top_of_table != top_of_table ||
	    top_of_table == new_top_of_table) {
		spin_unlock_irqrestore(domain_lock, flags);
		ret = -EAGAIN;
		goto err_free;
	}

	/*
	 * We do not issue any flushes for change_top on the expectation that
	 * any walk cache will not become a problem by adding another layer to
	 * the tree. Misses will rewalk from the updated top pointer, hits
	 * continue to be correct. Negative caching is fine too since all the
	 * new IOVA added by the new top is non-present.
	 *
	 * table_mem is the last (highest) table allocated. It was assigned
	 * because new_top_of_table != top_of_table was checked above.
	 */
	iommu_table->driver_ops->change_top(
		iommu_table, virt_to_phys(table_mem), new_level);
	WRITE_ONCE(common->top_of_table, new_top_of_table);
	spin_unlock_irqrestore(domain_lock, flags);
	/* The new tables are now part of the tree; free_list is just dropped */
	return 0;

err_free:
	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
		iommu_pages_stop_incoherent_list(&free_list,
						 iommu_table->iommu_device);
	iommu_put_pages_list(&free_list);
	return ret;
}
878 
/*
 * Ensure @range can be mapped at map->leaf_level.
 *
 * Without PT_FEAT_DYNAMIC_TOP this is just pt_check_range(). With it, the top
 * is grown with increase_top() until pt_check_range() passes and the top
 * level is at least map->leaf_level. @range is rebuilt from the current top
 * after every attempt.
 *
 * Return: 0 on success, with @range describing the current top. Otherwise
 * the pt_check_range() error (fixed top) or an increase_top() error other
 * than -EAGAIN.
 */
static int check_map_range(struct pt_iommu *iommu_table, struct pt_range *range,
			   struct pt_iommu_map_args *map)
{
	struct pt_common *common = common_from_iommu(iommu_table);
	int ret;

	do {
		ret = pt_check_range(range);
		if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP))
			return ret;

		if (!ret && map->leaf_level <= range->top_level)
			break;

		ret = increase_top(iommu_table, range, map);
		/* -EAGAIN: the top changed under us, reload and check again */
		if (ret && ret != -EAGAIN)
			return ret;

		/* Reload the new top */
		*range = pt_make_range(common, range->va, range->last_va);
	} while (ret);
	PT_WARN_ON(pt_check_range(range));
	return 0;
}
903 
904 static int do_map(struct pt_range *range, struct pt_common *common,
905 		  bool single_page, struct pt_iommu_map_args *map)
906 {
907 	int ret;
908 
909 	/*
910 	 * The __map_single_page() fast path does not support DMA_INCOHERENT
911 	 * flushing to keep its .text small.
912 	 */
913 	if (single_page && !pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
914 
915 		ret = pt_walk_range(range, __map_single_page, map);
916 		if (ret != -EAGAIN)
917 			return ret;
918 		/* EAGAIN falls through to the full path */
919 	}
920 
921 	do {
922 		if (map->leaf_level == range->top_level)
923 			ret = pt_walk_range(range, __map_range_leaf, map);
924 		else
925 			ret = pt_walk_range(range, __map_range, map);
926 	} while (ret == -EAGAIN);
927 	return ret;
928 }
929 
/*
 * Map the physically contiguous range [@paddr, @paddr + @len) at @iova.
 *
 * @prot must contain IOMMU_READ and/or IOMMU_WRITE; it is converted into the
 * format's entry attributes by pt_iommu_set_prot(). @gfp is stored in the map
 * attributes, presumably for table allocations made during the walk.
 *
 * A single PAGE_SIZE request is flagged as single_page for do_map(). This
 * requires that the format uses the system page size and that PAGE_SIZE is in
 * the domain's pgsize_bitmap. Otherwise pt_compute_best_pgsize() picks the
 * leaf size and the leaves are placed at the matching level.
 *
 * *@mapped is incremented (not overwritten) by map.oa - @paddr after the walk,
 * including when do_map() returns an error. Failures detected before the walk
 * leave *@mapped untouched.
 *
 * Returns 0 on success. Returns -EINVAL if @prot has neither read nor write,
 * -ERANGE if @paddr exceeds what the table can store, and -ENXIO if no page
 * size fits the alignment of @iova/@paddr. Otherwise it returns the error from
 * pt_iommu_set_prot(), make_range_no_check(), check_map_range() or do_map().
 */
static int NS(map_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
			 phys_addr_t paddr, dma_addr_t len, unsigned int prot,
			 gfp_t gfp, size_t *mapped)
{
	pt_vaddr_t pgsize_bitmap = iommu_table->domain.pgsize_bitmap;
	struct pt_common *common = common_from_iommu(iommu_table);
	/* Collects table levels freed by do_map(); synced after the walk */
	struct iommu_iotlb_gather iotlb_gather;
	struct pt_iommu_map_args map = {
		.iotlb_gather = &iotlb_gather,
		.oa = paddr,
	};
	bool single_page = false;
	struct pt_range range;
	int ret;

	iommu_iotlb_gather_init(&iotlb_gather);

	if (WARN_ON(!(prot & (IOMMU_READ | IOMMU_WRITE))))
		return -EINVAL;

	/* Check the paddr doesn't exceed what the table can store */
	if ((sizeof(pt_oaddr_t) < sizeof(paddr) &&
	     (pt_vaddr_t)paddr > PT_VADDR_MAX) ||
	    (common->max_oasz_lg2 != PT_VADDR_MAX_LG2 &&
	     oalog2_div(paddr, common->max_oasz_lg2)))
		return -ERANGE;

	ret = pt_iommu_set_prot(common, &map.attrs, prot);
	if (ret)
		return ret;
	map.attrs.gfp = gfp;

	ret = make_range_no_check(common, &range, iova, len);
	if (ret)
		return ret;

	/* Calculate target page size and level for the leaves */
	if (pt_has_system_page_size(common) && len == PAGE_SIZE &&
		likely(pgsize_bitmap & PAGE_SIZE)) {
		/* One system page: both addresses must be page aligned */
		if (log2_mod(iova | paddr, PAGE_SHIFT))
			return -ENXIO;
		map.leaf_pgsize_lg2 = PAGE_SHIFT;
		map.leaf_level = 0;
		map.num_leaves = 1;
		single_page = true;
	} else {
		/* 0 means no page size in the bitmap fits the range/alignment */
		map.leaf_pgsize_lg2 = pt_compute_best_pgsize(
			pgsize_bitmap, range.va, range.last_va, paddr);
		if (!map.leaf_pgsize_lg2)
			return -ENXIO;
		map.leaf_level =
			pt_pgsz_lg2_to_level(common, map.leaf_pgsize_lg2);
		map.num_leaves = pt_pgsz_count(pgsize_bitmap, range.va,
					       range.last_va, paddr,
					       map.leaf_pgsize_lg2);
	}

	ret = check_map_range(iommu_table, &range, &map);
	if (ret)
		return ret;

	PT_WARN_ON(map.leaf_level > range.top_level);

	ret = do_map(&range, common, single_page, &map);

	/*
	 * Table levels were freed and replaced with large items, flush any walk
	 * cache that may refer to the freed levels.
	 */
	if (!iommu_pages_list_empty(&iotlb_gather.freelist))
		iommu_iotlb_sync(&iommu_table->domain, &iotlb_gather);

	/* Bytes successfully mapped */
	PT_WARN_ON(!ret && map.oa - paddr != len);
	*mapped += map.oa - paddr;
	return ret;
}
1006 }
1007 
/* Walk state shared by every level visited by __unmap_range() */
struct pt_unmap_args {
	/* Freed tables, leaf/table level bitmaps and the caller's gather */
	struct iommupt_pending_gather pending;
	/* Running total, in bytes, of the leaf entries cleared so far */
	pt_vaddr_t unmapped;
};
1011 };
1012 
/*
 * pt_walk_range() callback that clears the entries of @table (at @level) that
 * fall inside the walk range.
 *
 * Leaf (OA) entries are cleared as a whole, including every item of a
 * contiguous entry, recorded with gather_add_leaf() and counted into
 * unmap->unmapped. Table entries are descended into. A lower table is queued
 * with gather_add_table() and its entry cleared only if the range fully covers
 * it and the descent succeeded. Only the span of entries actually modified at
 * this level is passed to flush_writes_range().
 *
 * Returns -EINVAL if the range starts inside a leaf entry, or if the walk hits
 * an entry that is neither a leaf nor a table (presumably an empty entry).
 * Otherwise it returns the first error from a lower level, or 0.
 */
static __maybe_unused int __unmap_range(struct pt_range *range, void *arg,
					unsigned int level,
					struct pt_table_p *table)
{
	struct pt_state pts = pt_init(range, level, table);
	/* Equal sentinels: no entry modified yet, so nothing to flush */
	unsigned int flush_start_index = UINT_MAX;
	unsigned int flush_end_index = UINT_MAX;
	struct pt_unmap_args *unmap = arg;
	unsigned int num_oas = 0;
	unsigned int start_index;
	int ret = 0;

	_pt_iter_first(&pts);
	start_index = pts.index;
	pts.type = pt_load_entry_raw(&pts);
	/*
	 * A starting index is in the middle of a contiguous entry
	 *
	 * The IOMMU API does not require drivers to support unmapping parts of
	 * large pages. Long ago VFIO would try to split maps but the current
	 * version never does.
	 *
	 * Instead when unmap reaches a partial unmap of the start of a large
	 * IOPTE it should remove the entire IOPTE and return that size to the
	 * caller.
	 */
	if (pts.type == PT_ENTRY_OA) {
		if (log2_mod(range->va, pt_entry_oa_lg2sz(&pts)))
			return -EINVAL;
		/* Micro optimization: skip the entry type test at the loop top */
		goto start_oa;
	}

	do {
		if (pts.type != PT_ENTRY_OA) {
			bool fully_covered;

			if (pts.type != PT_ENTRY_TABLE) {
				ret = -EINVAL;
				break;
			}

			/* The first entry's VA is already set by the iterator */
			if (pts.index != start_index)
				pt_index_to_va(&pts);
			pts.table_lower = pt_table_ptr(&pts);

			fully_covered = pt_entry_fully_covered(
				&pts, pt_table_item_lg2sz(&pts));

			ret = pt_descend(&pts, arg, __unmap_range);
			if (ret)
				break;

			/*
			 * If the unmapping range fully covers the table then we
			 * can free it as well. The clear is delayed until we
			 * succeed in clearing the lower table levels.
			 */
			if (fully_covered) {
				gather_add_table(&unmap->pending, &pts,
						 pts.table_lower);
				pt_clear_entries(&pts, ilog2(1));
				if (pts.index < flush_start_index)
					flush_start_index = pts.index;
				flush_end_index = pts.index + 1;
			}
			pts.index++;
		} else {
			unsigned int num_contig_lg2;
start_oa:
			/*
			 * If the caller requested a last that falls within a
			 * single entry then the entire entry is unmapped and
			 * the length returned will be larger than requested.
			 */
			num_contig_lg2 = pt_entry_num_contig_lg2(&pts);
			pt_clear_entries(&pts, num_contig_lg2);
			gather_add_leaf(&unmap->pending, &pts);
			num_oas += log2_to_int(num_contig_lg2);
			if (pts.index < flush_start_index)
				flush_start_index = pts.index;
			pts.index += log2_to_int(num_contig_lg2);
			flush_end_index = pts.index;
		}
		if (pts.index >= pts.end_index)
			break;
		pts.type = pt_load_entry_raw(&pts);
	} while (true);

	/* Each cleared leaf item covers one item size of this level */
	unmap->unmapped += log2_mul(num_oas, pt_table_item_lg2sz(&pts));
	if (flush_start_index != flush_end_index)
		flush_writes_range(&pts, flush_start_index, flush_end_index);

	return ret;
}
1107 }
1108 
1109 static size_t NS(unmap_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
1110 			      dma_addr_t len,
1111 			      struct iommu_iotlb_gather *iotlb_gather)
1112 {
1113 	struct pt_unmap_args unmap = {
1114 		.pending.iotlb_gather = iotlb_gather,
1115 		.pending.free_list = IOMMU_PAGES_LIST_INIT(
1116 			unmap.pending.free_list),
1117 	};
1118 	struct pt_range range;
1119 	int ret;
1120 
1121 	ret = make_range(common_from_iommu(iommu_table), &range, iova, len);
1122 	if (ret)
1123 		return 0;
1124 
1125 	pt_walk_range(&range, __unmap_range, &unmap);
1126 
1127 	gather_range_pending(&unmap.pending, iommu_table, iova, unmap.unmapped);
1128 
1129 	return unmap.unmapped;
1130 }
1131 
1132 static void NS(get_info)(struct pt_iommu *iommu_table,
1133 			 struct pt_iommu_info *info)
1134 {
1135 	struct pt_common *common = common_from_iommu(iommu_table);
1136 	struct pt_range range = pt_top_range(common);
1137 	struct pt_state pts = pt_init_top(&range);
1138 	pt_vaddr_t pgsize_bitmap = 0;
1139 
1140 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP)) {
1141 		for (pts.level = 0; pts.level <= PT_MAX_TOP_LEVEL;
1142 		     pts.level++) {
1143 			if (pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2)
1144 				break;
1145 			pgsize_bitmap |= pt_possible_sizes(&pts);
1146 		}
1147 	} else {
1148 		for (pts.level = 0; pts.level <= range.top_level; pts.level++)
1149 			pgsize_bitmap |= pt_possible_sizes(&pts);
1150 	}
1151 
1152 	/*
1153 	 * Hide page sizes larger than the maximum. -1 because a whole table
1154 	 * pgsize is not allowed
1155 	 */
1156 	info->pgsize_bitmap = log2_mod(pgsize_bitmap, common->max_vasz_lg2 - 1);
1157 	info->pgsize_bitmap = oalog2_mod(info->pgsize_bitmap, common->max_oasz_lg2);
1158 }
1159 
1160 static void NS(deinit)(struct pt_iommu *iommu_table)
1161 {
1162 	struct pt_common *common = common_from_iommu(iommu_table);
1163 	struct pt_range range = pt_all_range(common);
1164 	struct pt_iommu_collect_args collect = {
1165 		.pending.free_list = IOMMU_PAGES_LIST_INIT(
1166 			collect.pending.free_list),
1167 	};
1168 
1169 	iommu_pages_list_add(&collect.pending.free_list, range.top_table);
1170 	pt_walk_range(&range, __collect_tables, &collect);
1171 
1172 	/*
1173 	 * The driver has to already have fenced the HW access to the page table
1174 	 * and invalidated any caching referring to this memory.
1175 	 */
1176 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
1177 		iommu_pages_stop_incoherent_list(&collect.pending.free_list,
1178 						 iommu_table->iommu_device);
1179 	iommu_put_pages_list(&collect.pending.free_list);
1180 }
1181 
1182 static const struct pt_iommu_ops NS(ops) = {
1183 	.map_range = NS(map_range),
1184 	.unmap_range = NS(unmap_range),
1185 #if IS_ENABLED(CONFIG_IOMMUFD_DRIVER) && defined(pt_entry_is_write_dirty) && \
1186 	IS_ENABLED(CONFIG_IOMMUFD_TEST) && defined(pt_entry_make_write_dirty)
1187 	.set_dirty = NS(set_dirty),
1188 #endif
1189 	.get_info = NS(get_info),
1190 	.deinit = NS(deinit),
1191 };
1192 
1193 static int pt_init_common(struct pt_common *common)
1194 {
1195 	struct pt_range top_range = pt_top_range(common);
1196 
1197 	if (PT_WARN_ON(top_range.top_level > PT_MAX_TOP_LEVEL))
1198 		return -EINVAL;
1199 
1200 	if (top_range.max_vasz_lg2 == PT_VADDR_MAX_LG2)
1201 		common->features |= BIT(PT_FEAT_FULL_VA);
1202 
1203 	/* Requested features must match features compiled into this format */
1204 	if ((common->features & ~(unsigned int)PT_SUPPORTED_FEATURES) ||
1205 	    (!IS_ENABLED(CONFIG_DEBUG_GENERIC_PT) &&
1206 	     (common->features & PT_FORCE_ENABLED_FEATURES) !=
1207 		     PT_FORCE_ENABLED_FEATURES))
1208 		return -EOPNOTSUPP;
1209 
1210 	/*
1211 	 * Check if the top level of the page table is too small to hold the
1212 	 * specified maxvasz.
1213 	 */
1214 	if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
1215 	    top_range.top_level != PT_MAX_TOP_LEVEL) {
1216 		struct pt_state pts = { .range = &top_range,
1217 					.level = top_range.top_level };
1218 
1219 		if (common->max_vasz_lg2 >
1220 		    pt_num_items_lg2(&pts) + pt_table_item_lg2sz(&pts))
1221 			return -EOPNOTSUPP;
1222 	}
1223 
1224 	if (common->max_oasz_lg2 == 0)
1225 		common->max_oasz_lg2 = pt_max_oa_lg2(common);
1226 	else
1227 		common->max_oasz_lg2 = min(common->max_oasz_lg2,
1228 					   pt_max_oa_lg2(common));
1229 	return 0;
1230 }
1231 
1232 static int pt_iommu_init_domain(struct pt_iommu *iommu_table,
1233 				struct iommu_domain *domain)
1234 {
1235 	struct pt_common *common = common_from_iommu(iommu_table);
1236 	struct pt_iommu_info info;
1237 	struct pt_range range;
1238 
1239 	NS(get_info)(iommu_table, &info);
1240 
1241 	domain->type = __IOMMU_DOMAIN_PAGING;
1242 	domain->pgsize_bitmap = info.pgsize_bitmap;
1243 	domain->is_iommupt = true;
1244 
1245 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP))
1246 		range = _pt_top_range(common,
1247 				      _pt_top_set(NULL, PT_MAX_TOP_LEVEL));
1248 	else
1249 		range = pt_top_range(common);
1250 
1251 	/* A 64-bit high address space table on a 32-bit system cannot work. */
1252 	domain->geometry.aperture_start = (unsigned long)range.va;
1253 	if ((pt_vaddr_t)domain->geometry.aperture_start != range.va)
1254 		return -EOVERFLOW;
1255 
1256 	/*
1257 	 * The aperture is limited to what the API can do after considering all
1258 	 * the different types dma_addr_t/unsigned long/pt_vaddr_t that are used
1259 	 * to store a VA. Set the aperture to something that is valid for all
1260 	 * cases. Saturate instead of truncate the end if the types are smaller
1261 	 * than the top range. aperture_end should be called aperture_last.
1262 	 */
1263 	domain->geometry.aperture_end = (unsigned long)range.last_va;
1264 	if ((pt_vaddr_t)domain->geometry.aperture_end != range.last_va) {
1265 		domain->geometry.aperture_end = ULONG_MAX;
1266 		domain->pgsize_bitmap &= ULONG_MAX;
1267 	}
1268 	domain->geometry.force_aperture = true;
1269 
1270 	return 0;
1271 }
1272 
1273 static void pt_iommu_zero(struct pt_iommu_table *fmt_table)
1274 {
1275 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1276 	struct pt_iommu cfg = *iommu_table;
1277 
1278 	static_assert(offsetof(struct pt_iommu_table, iommu.domain) == 0);
1279 	memset_after(fmt_table, 0, iommu.domain);
1280 
1281 	/* The caller can initialize some of these values */
1282 	iommu_table->iommu_device = cfg.iommu_device;
1283 	iommu_table->driver_ops = cfg.driver_ops;
1284 	iommu_table->nid = cfg.nid;
1285 }
1286 
/*
 * Per-format names for the exported init API. pt_iommu_table_cfg is
 * pt_iommu_table with a _cfg suffix. pt_iommu_init is pt_iommu_ followed by
 * PTPFX and then init.
 */
#define pt_iommu_table_cfg CONCATENATE(pt_iommu_table, _cfg)
#define pt_iommu_init CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), init)
1289 
/*
 * Initialize @fmt_table from @cfg and allocate its top table.
 *
 * @fmt_table: table to initialize. iommu_device, driver_ops and nid in its
 *             embedded pt_iommu are preserved; everything after the domain
 *             is zeroed first.
 * @cfg: hw_max_vasz_lg2 and hw_max_oasz_lg2 must be non-zero, and
 *       hw_max_vasz_lg2 may not exceed PT_MAX_VA_ADDRESS_LG2.
 * @gfp: allocation flags for the top table.
 *
 * Returns 0 on success. Returns -EINVAL for invalid sizes, missing
 * driver_ops for a dynamic top, or SIGN_EXTEND combined with FULL_VA or
 * DYNAMIC_TOP. Returns -EINVAL if DMA_INCOHERENT is set without an
 * iommu_device. Otherwise it returns the error from pt_iommu_fmt_init(),
 * pt_init_common(), pt_iommu_init_domain() or the top table allocation.
 */
int pt_iommu_init(struct pt_iommu_table *fmt_table,
		  const struct pt_iommu_table_cfg *cfg, gfp_t gfp)
{
	struct pt_iommu *iommu_table = &fmt_table->iommu;
	struct pt_common *common = common_from_iommu(iommu_table);
	struct pt_table_p *table_mem;
	int ret;

	if (cfg->common.hw_max_vasz_lg2 > PT_MAX_VA_ADDRESS_LG2 ||
	    !cfg->common.hw_max_vasz_lg2 || !cfg->common.hw_max_oasz_lg2)
		return -EINVAL;

	pt_iommu_zero(fmt_table);
	common->features = cfg->common.features;
	common->max_vasz_lg2 = cfg->common.hw_max_vasz_lg2;
	common->max_oasz_lg2 = cfg->common.hw_max_oasz_lg2;
	ret = pt_iommu_fmt_init(fmt_table, cfg);
	if (ret)
		return ret;

	/* The format must be able to encode the requested OA size */
	if (cfg->common.hw_max_oasz_lg2 > pt_max_oa_lg2(common))
		return -EINVAL;

	ret = pt_init_common(common);
	if (ret)
		return ret;

	/* A dynamic top needs driver callbacks to change the top level */
	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
	    WARN_ON(!iommu_table->driver_ops ||
		    !iommu_table->driver_ops->change_top ||
		    !iommu_table->driver_ops->get_top_lock))
		return -EINVAL;

	if (pt_feature(common, PT_FEAT_SIGN_EXTEND) &&
	    (pt_feature(common, PT_FEAT_FULL_VA) ||
	     pt_feature(common, PT_FEAT_DYNAMIC_TOP)))
		return -EINVAL;

	/* Cache maintenance for incoherent tables goes through iommu_device */
	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) &&
	    WARN_ON(!iommu_table->iommu_device))
		return -EINVAL;

	ret = pt_iommu_init_domain(iommu_table, &iommu_table->domain);
	if (ret)
		return ret;

	table_mem = table_alloc_top(common, common->top_of_table, gfp,
				    ALLOC_NORMAL);
	if (IS_ERR(table_mem))
		return PTR_ERR(table_mem);
	pt_top_set(common, table_mem, pt_top_get_level(common));

	/* Must be last, see pt_iommu_deinit() */
	iommu_table->ops = &NS(ops);
	return 0;
}
EXPORT_SYMBOL_NS_GPL(pt_iommu_init, "GENERIC_PT_IOMMU");
1347 
#ifdef pt_iommu_fmt_hw_info
#define pt_iommu_table_hw_info CONCATENATE(pt_iommu_table, _hw_info)
#define pt_iommu_hw_info CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), hw_info)
/*
 * Fill @info with the format's hardware description of @fmt_table. The work
 * is delegated to pt_iommu_fmt_hw_info() with the table's current top range.
 */
void pt_iommu_hw_info(struct pt_iommu_table *fmt_table,
		      struct pt_iommu_table_hw_info *info)
{
	struct pt_range top =
		pt_top_range(common_from_iommu(&fmt_table->iommu));

	pt_iommu_fmt_hw_info(fmt_table, &top, info);
}
EXPORT_SYMBOL_NS_GPL(pt_iommu_hw_info, "GENERIC_PT_IOMMU");
#endif
1361 #endif
1362 
/* Module metadata for the format this file is compiled into */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("IOMMU Page table implementation for " __stringify(PTPFX_RAW));
/* Symbols exported under the GENERIC_PT namespace */
MODULE_IMPORT_NS("GENERIC_PT");
/* For iommu_dirty_bitmap_record() */
MODULE_IMPORT_NS("IOMMUFD");

#endif  /* __GENERIC_PT_IOMMU_PT_H */
1370