xref: /linux/drivers/iommu/generic_pt/iommu_pt.h (revision ce5cfb0fa20dc6454da039612e34325b7b4a8243)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  *
5  * "Templated C code" for implementing the iommu operations for page tables.
6  * This is compiled multiple times, over all the page table formats to pick up
7  * the per-format definitions.
8  */
9 #ifndef __GENERIC_PT_IOMMU_PT_H
10 #define __GENERIC_PT_IOMMU_PT_H
11 
12 #include "pt_iter.h"
13 
14 #include <linux/export.h>
15 #include <linux/iommu.h>
16 #include "../iommu-pages.h"
17 #include <linux/cleanup.h>
18 #include <linux/dma-mapping.h>
19 
20 enum {
21 	SW_BIT_CACHE_FLUSH_DONE = 0,
22 };
23 
24 static void flush_writes_range(const struct pt_state *pts,
25 			       unsigned int start_index, unsigned int end_index)
26 {
27 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT))
28 		iommu_pages_flush_incoherent(
29 			iommu_from_common(pts->range->common)->iommu_device,
30 			pts->table, start_index * PT_ITEM_WORD_SIZE,
31 			(end_index - start_index) * PT_ITEM_WORD_SIZE);
32 }
33 
34 static void flush_writes_item(const struct pt_state *pts)
35 {
36 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT))
37 		iommu_pages_flush_incoherent(
38 			iommu_from_common(pts->range->common)->iommu_device,
39 			pts->table, pts->index * PT_ITEM_WORD_SIZE,
40 			PT_ITEM_WORD_SIZE);
41 }
42 
43 static void gather_range_pages(struct iommu_iotlb_gather *iotlb_gather,
44 			       struct pt_iommu *iommu_table, pt_vaddr_t iova,
45 			       pt_vaddr_t len,
46 			       struct iommu_pages_list *free_list)
47 {
48 	struct pt_common *common = common_from_iommu(iommu_table);
49 
50 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
51 		iommu_pages_stop_incoherent_list(free_list,
52 						 iommu_table->iommu_device);
53 
54 	if (pt_feature(common, PT_FEAT_FLUSH_RANGE_NO_GAPS) &&
55 	    iommu_iotlb_gather_is_disjoint(iotlb_gather, iova, len)) {
56 		iommu_iotlb_sync(&iommu_table->domain, iotlb_gather);
57 		/*
58 		 * Note that the sync frees the gather's free list, so we must
59 		 * not have any pages on that list that are covered by iova/len
60 		 */
61 	} else if (pt_feature(common, PT_FEAT_FLUSH_RANGE)) {
62 		iommu_iotlb_gather_add_range(iotlb_gather, iova, len);
63 	}
64 
65 	iommu_pages_list_splice(free_list, &iotlb_gather->freelist);
66 }
67 
68 #define DOMAIN_NS(op) CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), op)
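/*
 * Illustration only (a sketch, not something defined by this header): when a
 * format header sets PTPFX to, say, amdv1, DOMAIN_NS(map_pages) expands to
 * pt_iommu_amdv1_map_pages. A driver can then wire the generated entry points
 * straight into its iommu_domain_ops; the ops variable name below is made up
 * for the example:
 *
 *	static const struct iommu_domain_ops amdv1_domain_ops = {
 *		.iova_to_phys	= pt_iommu_amdv1_iova_to_phys,
 *		.map_pages	= pt_iommu_amdv1_map_pages,
 *		.unmap_pages	= pt_iommu_amdv1_unmap_pages,
 *	};
 */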
69 
70 static int make_range_ul(struct pt_common *common, struct pt_range *range,
71 			 unsigned long iova, unsigned long len)
72 {
73 	unsigned long last;
74 
75 	if (unlikely(len == 0))
76 		return -EINVAL;
77 
78 	if (check_add_overflow(iova, len - 1, &last))
79 		return -EOVERFLOW;
80 
81 	*range = pt_make_range(common, iova, last);
82 	if (sizeof(iova) > sizeof(range->va)) {
83 		if (unlikely(range->va != iova || range->last_va != last))
84 			return -EOVERFLOW;
85 	}
86 	return 0;
87 }
88 
89 static __maybe_unused int make_range_u64(struct pt_common *common,
90 					 struct pt_range *range, u64 iova,
91 					 u64 len)
92 {
93 	if (unlikely(iova > ULONG_MAX || len > ULONG_MAX))
94 		return -EOVERFLOW;
95 	return make_range_ul(common, range, iova, len);
96 }
97 
98 /*
99  * Some APIs use unsigned long, while others use dma_addr_t as the type. Dispatch
100  * to the correct validation based on the type.
101  */
102 #define make_range_no_check(common, range, iova, len)                   \
103 	({                                                              \
104 		int ret;                                                \
105 		if (sizeof(iova) > sizeof(unsigned long) ||             \
106 		    sizeof(len) > sizeof(unsigned long))                \
107 			ret = make_range_u64(common, range, iova, len); \
108 		else                                                    \
109 			ret = make_range_ul(common, range, iova, len);  \
110 		ret;                                                    \
111 	})
112 
113 #define make_range(common, range, iova, len)                             \
114 	({                                                               \
115 		int ret = make_range_no_check(common, range, iova, len); \
116 		if (!ret)                                                \
117 			ret = pt_check_range(range);                     \
118 		ret;                                                     \
119 	})
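/*
 * A minimal sketch of why the dispatch above matters: on a 32-bit kernel with
 * a 64-bit dma_addr_t the sizeof() test selects make_range_u64(), so an IOVA
 * that does not fit in unsigned long is rejected rather than silently
 * truncated:
 *
 *	struct pt_range range;
 *	dma_addr_t iova = 0x100000000ULL;
 *	int ret = make_range(common, &range, iova, SZ_4K);
 *
 * On such a configuration ret is -EOVERFLOW.
 */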
120 
121 static inline unsigned int compute_best_pgsize(struct pt_state *pts,
122 					       pt_oaddr_t oa)
123 {
124 	struct pt_iommu *iommu_table = iommu_from_common(pts->range->common);
125 
126 	if (!pt_can_have_leaf(pts))
127 		return 0;
128 
129 	/*
130 	 * The page size is limited by the domain's bitmap. This allows the core
131 	 * code to reduce the supported page sizes by changing the bitmap.
132 	 */
133 	return pt_compute_best_pgsize(pt_possible_sizes(pts) &
134 					      iommu_table->domain.pgsize_bitmap,
135 				      pts->range->va, pts->range->last_va, oa);
136 }
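/*
 * Worked example of the selection above (sizes are illustrative): if this
 * level can hold 4K and 2M leaves, the domain's pgsize_bitmap allows both, and
 * va, last_va and oa describe a 2M aligned region of at least 2M, the result
 * is ilog2(SZ_2M) == 21; otherwise it falls back to 12 (4K), or 0 if even the
 * smallest size does not fit.
 */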
137 
138 static __always_inline int __do_iova_to_phys(struct pt_range *range, void *arg,
139 					     unsigned int level,
140 					     struct pt_table_p *table,
141 					     pt_level_fn_t descend_fn)
142 {
143 	struct pt_state pts = pt_init(range, level, table);
144 	pt_oaddr_t *res = arg;
145 
146 	switch (pt_load_single_entry(&pts)) {
147 	case PT_ENTRY_EMPTY:
148 		return -ENOENT;
149 	case PT_ENTRY_TABLE:
150 		return pt_descend(&pts, arg, descend_fn);
151 	case PT_ENTRY_OA:
152 		*res = pt_entry_oa_exact(&pts);
153 		return 0;
154 	}
155 	return -ENOENT;
156 }
157 PT_MAKE_LEVELS(__iova_to_phys, __do_iova_to_phys);
158 
159 /**
160  * iova_to_phys() - Return the output address for the given IOVA
161  * @domain: Table to query
162  * @iova: IO virtual address to query
163  *
164  * Determine the output address from the given IOVA. @iova may have any
165  * alignment; the returned physical address is adjusted by any sub-page offset.
166  *
167  * Context: The caller must hold a read range lock that includes @iova.
168  *
169  * Return: The output address, or 0 if there is no translation for the given iova.
170  */
171 phys_addr_t DOMAIN_NS(iova_to_phys)(struct iommu_domain *domain,
172 				    dma_addr_t iova)
173 {
174 	struct pt_iommu *iommu_table =
175 		container_of(domain, struct pt_iommu, domain);
176 	struct pt_range range;
177 	pt_oaddr_t res;
178 	int ret;
179 
180 	ret = make_range(common_from_iommu(iommu_table), &range, iova, 1);
181 	if (ret)
182 		return ret;
183 
184 	ret = pt_walk_range(&range, __iova_to_phys, &res);
185 	/* PHYS_ADDR_MAX would be a better error code */
186 	if (ret)
187 		return 0;
188 	return res;
189 }
190 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(iova_to_phys), "GENERIC_PT_IOMMU");
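/*
 * Usage sketch through the core API (values are illustrative): a caller that
 * mapped a page at map_iova can look up the physical address behind an
 * unaligned offset inside it, and the sub-page offset is carried through:
 *
 *	phys_addr_t phys = iommu_iova_to_phys(domain, map_iova + 0x234);
 *
 * phys is the mapped paddr plus 0x234, or 0 if nothing is mapped there.
 */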
191 
192 struct pt_iommu_dirty_args {
193 	struct iommu_dirty_bitmap *dirty;
194 	unsigned int flags;
195 };
196 
197 static void record_dirty(struct pt_state *pts,
198 			 struct pt_iommu_dirty_args *dirty,
199 			 unsigned int num_contig_lg2)
200 {
201 	pt_vaddr_t dirty_len;
202 
203 	if (num_contig_lg2 != ilog2(1)) {
204 		unsigned int index = pts->index;
205 		unsigned int end_index = log2_set_mod_max_t(
206 			unsigned int, pts->index, num_contig_lg2);
207 
208 		/* Adjust for being contained inside a contiguous page */
209 		end_index = min(end_index, pts->end_index);
210 		dirty_len = (end_index - index) *
211 				log2_to_int(pt_table_item_lg2sz(pts));
212 	} else {
213 		dirty_len = log2_to_int(pt_table_item_lg2sz(pts));
214 	}
215 
216 	if (dirty->dirty->bitmap)
217 		iova_bitmap_set(dirty->dirty->bitmap, pts->range->va,
218 				dirty_len);
219 
220 	if (!(dirty->flags & IOMMU_DIRTY_NO_CLEAR)) {
221 		/*
222 		 * No write log required because DMA incoherence and atomic
223 		 * dirty tracking bits can't work together
224 		 */
225 		pt_entry_make_write_clean(pts);
226 		iommu_iotlb_gather_add_range(dirty->dirty->gather,
227 					     pts->range->va, dirty_len);
228 	}
229 }
230 
231 static inline int __read_and_clear_dirty(struct pt_range *range, void *arg,
232 					 unsigned int level,
233 					 struct pt_table_p *table)
234 {
235 	struct pt_state pts = pt_init(range, level, table);
236 	struct pt_iommu_dirty_args *dirty = arg;
237 	int ret;
238 
239 	for_each_pt_level_entry(&pts) {
240 		if (pts.type == PT_ENTRY_TABLE) {
241 			ret = pt_descend(&pts, arg, __read_and_clear_dirty);
242 			if (ret)
243 				return ret;
244 			continue;
245 		}
246 		if (pts.type == PT_ENTRY_OA && pt_entry_is_write_dirty(&pts))
247 			record_dirty(&pts, dirty,
248 				     pt_entry_num_contig_lg2(&pts));
249 	}
250 	return 0;
251 }
252 
253 /**
254  * read_and_clear_dirty() - Manipulate the HW set write dirty state
255  * @domain: Domain to manipulate
256  * @iova: IO virtual address to start
257  * @size: Length of the IOVA
258  * @flags: A bitmap of IOMMU_DIRTY_NO_CLEAR
259  * @dirty: Place to store the dirty bits
260  *
261  * Iterate over all the entries in the mapped range and record their write dirty
262  * status in iommu_dirty_bitmap. If IOMMU_DIRTY_NO_CLEAR is specified then the
263  * entries will be left dirty, otherwise they are returned to being not write
264  * dirty.
265  *
266  * Context: The caller must hold a read range lock that includes @iova.
267  *
268  * Returns: -ERRNO on failure, 0 on success.
269  */
270 int DOMAIN_NS(read_and_clear_dirty)(struct iommu_domain *domain,
271 				    unsigned long iova, size_t size,
272 				    unsigned long flags,
273 				    struct iommu_dirty_bitmap *dirty)
274 {
275 	struct pt_iommu *iommu_table =
276 		container_of(domain, struct pt_iommu, domain);
277 	struct pt_iommu_dirty_args dirty_args = {
278 		.dirty = dirty,
279 		.flags = flags,
280 	};
281 	struct pt_range range;
282 	int ret;
283 
284 #if !IS_ENABLED(CONFIG_IOMMUFD_DRIVER) || !defined(pt_entry_is_write_dirty)
285 	return -EOPNOTSUPP;
286 #endif
287 
288 	ret = make_range(common_from_iommu(iommu_table), &range, iova, size);
289 	if (ret)
290 		return ret;
291 
292 	ret = pt_walk_range(&range, __read_and_clear_dirty, &dirty_args);
293 	PT_WARN_ON(ret);
294 	return ret;
295 }
296 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(read_and_clear_dirty), "GENERIC_PT_IOMMU");
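/*
 * Usage sketch, assuming the caller already allocated a struct iova_bitmap
 * *bitmap covering the range (the bitmap setup is not shown here):
 *
 *	struct iommu_iotlb_gather gather;
 *	struct iommu_dirty_bitmap dirty;
 *	int ret;
 *
 *	iommu_iotlb_gather_init(&gather);
 *	iommu_dirty_bitmap_init(&dirty, bitmap, &gather);
 *	ret = domain->dirty_ops->read_and_clear_dirty(domain, iova, size, 0,
 *						      &dirty);
 *	iommu_iotlb_sync(domain, &gather);
 *
 * Passing IOMMU_DIRTY_NO_CLEAR in flags records the dirty bits without
 * clearing them, in which case no IOTLB flush is required.
 */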
297 
298 static inline int __set_dirty(struct pt_range *range, void *arg,
299 			      unsigned int level, struct pt_table_p *table)
300 {
301 	struct pt_state pts = pt_init(range, level, table);
302 
303 	switch (pt_load_single_entry(&pts)) {
304 	case PT_ENTRY_EMPTY:
305 		return -ENOENT;
306 	case PT_ENTRY_TABLE:
307 		return pt_descend(&pts, arg, __set_dirty);
308 	case PT_ENTRY_OA:
309 		if (!pt_entry_make_write_dirty(&pts))
310 			return -EAGAIN;
311 		return 0;
312 	}
313 	return -ENOENT;
314 }
315 
316 static int __maybe_unused NS(set_dirty)(struct pt_iommu *iommu_table,
317 					dma_addr_t iova)
318 {
319 	struct pt_range range;
320 	int ret;
321 
322 	ret = make_range(common_from_iommu(iommu_table), &range, iova, 1);
323 	if (ret)
324 		return ret;
325 
326 	/*
327 	 * Note: There is no locking here yet, if the test suite races this it
328 	 * can crash. It should use RCU locking eventually.
329 	 */
330 	return pt_walk_range(&range, __set_dirty, NULL);
331 }
332 
333 struct pt_iommu_collect_args {
334 	struct iommu_pages_list free_list;
335 	/* Fail if any OAs are within the range */
336 	u8 check_mapped : 1;
337 };
338 
339 static int __collect_tables(struct pt_range *range, void *arg,
340 			    unsigned int level, struct pt_table_p *table)
341 {
342 	struct pt_state pts = pt_init(range, level, table);
343 	struct pt_iommu_collect_args *collect = arg;
344 	int ret;
345 
346 	if (!collect->check_mapped && !pt_can_have_table(&pts))
347 		return 0;
348 
349 	for_each_pt_level_entry(&pts) {
350 		if (pts.type == PT_ENTRY_TABLE) {
351 			iommu_pages_list_add(&collect->free_list, pts.table_lower);
352 			ret = pt_descend(&pts, arg, __collect_tables);
353 			if (ret)
354 				return ret;
355 			continue;
356 		}
357 		if (pts.type == PT_ENTRY_OA && collect->check_mapped)
358 			return -EADDRINUSE;
359 	}
360 	return 0;
361 }
362 
363 enum alloc_mode { ALLOC_NORMAL, ALLOC_DEFER_COHERENT_FLUSH };
364 
365 /* Allocate a table; the empty table will be ready to be installed. */
366 static inline struct pt_table_p *_table_alloc(struct pt_common *common,
367 					      size_t lg2sz, gfp_t gfp,
368 					      enum alloc_mode mode)
369 {
370 	struct pt_iommu *iommu_table = iommu_from_common(common);
371 	struct pt_table_p *table_mem;
372 
373 	table_mem = iommu_alloc_pages_node_sz(iommu_table->nid, gfp,
374 					      log2_to_int(lg2sz));
375 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) &&
376 	    mode == ALLOC_NORMAL) {
377 		int ret = iommu_pages_start_incoherent(
378 			table_mem, iommu_table->iommu_device);
379 		if (ret) {
380 			iommu_free_pages(table_mem);
381 			return ERR_PTR(ret);
382 		}
383 	}
384 	return table_mem;
385 }
386 
387 static inline struct pt_table_p *table_alloc_top(struct pt_common *common,
388 						 uintptr_t top_of_table,
389 						 gfp_t gfp,
390 						 enum alloc_mode mode)
391 {
392 	/*
393 	 * The top doesn't need the free list or the other iommu pages machinery,
394 	 * so it technically doesn't need to use iommu pages. Use the API anyhow
395 	 * to keep things simple, as the top is usually not smaller than PAGE_SIZE.
396 	 */
397 	return _table_alloc(common, pt_top_memsize_lg2(common, top_of_table),
398 			    gfp, mode);
399 }
400 
401 /* Allocate an interior table */
402 static inline struct pt_table_p *table_alloc(const struct pt_state *parent_pts,
403 					     gfp_t gfp, enum alloc_mode mode)
404 {
405 	struct pt_state child_pts =
406 		pt_init(parent_pts->range, parent_pts->level - 1, NULL);
407 
408 	return _table_alloc(parent_pts->range->common,
409 			    pt_num_items_lg2(&child_pts) +
410 				    ilog2(PT_ITEM_WORD_SIZE),
411 			    gfp, mode);
412 }
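/*
 * Size arithmetic example with illustrative numbers: a child level that holds
 * 512 items of 8 bytes each has pt_num_items_lg2() == 9 and
 * ilog2(PT_ITEM_WORD_SIZE) == 3, so _table_alloc() is asked for
 * log2_to_int(12) == 4096 bytes.
 */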
413 
414 static inline int pt_iommu_new_table(struct pt_state *pts,
415 				     struct pt_write_attrs *attrs)
416 {
417 	struct pt_table_p *table_mem;
418 	phys_addr_t phys;
419 
420 	/* Given PA/VA/length can't be represented */
421 	if (PT_WARN_ON(!pt_can_have_table(pts)))
422 		return -ENXIO;
423 
424 	table_mem = table_alloc(pts, attrs->gfp, ALLOC_NORMAL);
425 	if (IS_ERR(table_mem))
426 		return PTR_ERR(table_mem);
427 
428 	phys = virt_to_phys(table_mem);
429 	if (!pt_install_table(pts, phys, attrs)) {
430 		iommu_pages_free_incoherent(
431 			table_mem,
432 			iommu_from_common(pts->range->common)->iommu_device);
433 		return -EAGAIN;
434 	}
435 
436 	if (pts_feature(pts, PT_FEAT_DMA_INCOHERENT)) {
437 		flush_writes_item(pts);
438 		pt_set_sw_bit_release(pts, SW_BIT_CACHE_FLUSH_DONE);
439 	}
440 
441 	if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) {
442 		/*
443 		 * The underlying table can't store the physical table address.
444 		 * This happens when kunit testing tables outside their normal
445 		 * environment where a CPU might be limited.
446 		 */
447 		pt_load_single_entry(pts);
448 		if (PT_WARN_ON(pt_table_pa(pts) != phys)) {
449 			pt_clear_entries(pts, ilog2(1));
450 			iommu_pages_free_incoherent(
451 				table_mem, iommu_from_common(pts->range->common)
452 						   ->iommu_device);
453 			return -EINVAL;
454 		}
455 	}
456 
457 	pts->table_lower = table_mem;
458 	return 0;
459 }
460 
461 struct pt_iommu_map_args {
462 	struct iommu_iotlb_gather *iotlb_gather;
463 	struct pt_write_attrs attrs;
464 	pt_oaddr_t oa;
465 	unsigned int leaf_pgsize_lg2;
466 	unsigned int leaf_level;
467 };
468 
469 /*
470  * This will recursively check any tables in the block to validate they are
471  * empty and then free them through the gather.
472  */
473 static int clear_contig(const struct pt_state *start_pts,
474 			struct iommu_iotlb_gather *iotlb_gather,
475 			unsigned int step, unsigned int pgsize_lg2)
476 {
477 	struct pt_iommu *iommu_table =
478 		iommu_from_common(start_pts->range->common);
479 	struct pt_range range = *start_pts->range;
480 	struct pt_state pts =
481 		pt_init(&range, start_pts->level, start_pts->table);
482 	struct pt_iommu_collect_args collect = { .check_mapped = true };
483 	int ret;
484 
485 	pts.index = start_pts->index;
486 	pts.end_index = start_pts->index + step;
487 	for (; _pt_iter_load(&pts); pt_next_entry(&pts)) {
488 		if (pts.type == PT_ENTRY_TABLE) {
489 			collect.free_list =
490 				IOMMU_PAGES_LIST_INIT(collect.free_list);
491 			ret = pt_walk_descend_all(&pts, __collect_tables,
492 						  &collect);
493 			if (ret)
494 				return ret;
495 
496 			/*
497 			 * The table item must be cleared before we can update
498 			 * the gather
499 			 */
500 			pt_clear_entries(&pts, ilog2(1));
501 			flush_writes_item(&pts);
502 
503 			iommu_pages_list_add(&collect.free_list,
504 					     pt_table_ptr(&pts));
505 			gather_range_pages(
506 				iotlb_gather, iommu_table, range.va,
507 				log2_to_int(pt_table_item_lg2sz(&pts)),
508 				&collect.free_list);
509 		} else if (pts.type != PT_ENTRY_EMPTY) {
510 			return -EADDRINUSE;
511 		}
512 	}
513 	return 0;
514 }
515 
516 static int __map_range_leaf(struct pt_range *range, void *arg,
517 			    unsigned int level, struct pt_table_p *table)
518 {
519 	struct pt_state pts = pt_init(range, level, table);
520 	struct pt_iommu_map_args *map = arg;
521 	unsigned int leaf_pgsize_lg2 = map->leaf_pgsize_lg2;
522 	unsigned int start_index;
523 	pt_oaddr_t oa = map->oa;
524 	unsigned int step;
525 	bool need_contig;
526 	int ret = 0;
527 
528 	PT_WARN_ON(map->leaf_level != level);
529 	PT_WARN_ON(!pt_can_have_leaf(&pts));
530 
531 	step = log2_to_int_t(unsigned int,
532 			     leaf_pgsize_lg2 - pt_table_item_lg2sz(&pts));
533 	need_contig = leaf_pgsize_lg2 != pt_table_item_lg2sz(&pts);
534 
535 	_pt_iter_first(&pts);
536 	start_index = pts.index;
537 	do {
538 		pts.type = pt_load_entry_raw(&pts);
539 		if (pts.type != PT_ENTRY_EMPTY || need_contig) {
540 			if (pts.index != start_index)
541 				pt_index_to_va(&pts);
542 			ret = clear_contig(&pts, map->iotlb_gather, step,
543 					   leaf_pgsize_lg2);
544 			if (ret)
545 				break;
546 		}
547 
548 		if (IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)) {
549 			pt_index_to_va(&pts);
550 			PT_WARN_ON(compute_best_pgsize(&pts, oa) !=
551 				   leaf_pgsize_lg2);
552 		}
553 		pt_install_leaf_entry(&pts, oa, leaf_pgsize_lg2, &map->attrs);
554 
555 		oa += log2_to_int(leaf_pgsize_lg2);
556 		pts.index += step;
557 	} while (pts.index < pts.end_index);
558 
559 	flush_writes_range(&pts, start_index, pts.index);
560 
561 	map->oa = oa;
562 	return ret;
563 }
564 
565 static int __map_range(struct pt_range *range, void *arg, unsigned int level,
566 		       struct pt_table_p *table)
567 {
568 	struct pt_state pts = pt_init(range, level, table);
569 	struct pt_iommu_map_args *map = arg;
570 	int ret;
571 
572 	PT_WARN_ON(map->leaf_level == level);
573 	PT_WARN_ON(!pt_can_have_table(&pts));
574 
575 	_pt_iter_first(&pts);
576 
577 	/* Descend to a child table */
578 	do {
579 		pts.type = pt_load_entry_raw(&pts);
580 
581 		if (pts.type != PT_ENTRY_TABLE) {
582 			if (pts.type != PT_ENTRY_EMPTY)
583 				return -EADDRINUSE;
584 			ret = pt_iommu_new_table(&pts, &map->attrs);
585 			if (ret) {
586 				/*
587 				 * Racing with another thread installing a table
588 				 */
589 				if (ret == -EAGAIN)
590 					continue;
591 				return ret;
592 			}
593 		} else {
594 			pts.table_lower = pt_table_ptr(&pts);
595 			/*
596 			 * Racing with a shared pt_iommu_new_table()? The other
597 			 * thread is still flushing the cache, so we have to
598 			 * also flush it to ensure that when our thread's map
599 			 * completes all the table items leading to our mapping
600 			 * are visible.
601 			 *
602 			 * This requires the pt_set_sw_bit_release() to be a
603 			 * release of the cache flush so that this can acquire
604 			 * visibility at the iommu.
605 			 */
606 			if (pts_feature(&pts, PT_FEAT_DMA_INCOHERENT) &&
607 			    !pt_test_sw_bit_acquire(&pts,
608 						    SW_BIT_CACHE_FLUSH_DONE))
609 				flush_writes_item(&pts);
610 		}
611 
612 		/*
613 		 * The already present table can possibly be shared with another
614 		 * concurrent map.
615 		 */
616 		if (map->leaf_level == level - 1)
617 			ret = pt_descend(&pts, arg, __map_range_leaf);
618 		else
619 			ret = pt_descend(&pts, arg, __map_range);
620 		if (ret)
621 			return ret;
622 
623 		pts.index++;
624 		pt_index_to_va(&pts);
625 		if (pts.index >= pts.end_index)
626 			break;
627 	} while (true);
628 	return 0;
629 }
630 
631 /*
632  * Fast path for the easy case of mapping a 4k page to an already allocated
633  * table. This is a common workload. If it returns -EAGAIN, run the full algorithm
634  * instead.
635  */
636 static __always_inline int __do_map_single_page(struct pt_range *range,
637 						void *arg, unsigned int level,
638 						struct pt_table_p *table,
639 						pt_level_fn_t descend_fn)
640 {
641 	struct pt_state pts = pt_init(range, level, table);
642 	struct pt_iommu_map_args *map = arg;
643 
644 	pts.type = pt_load_single_entry(&pts);
645 	if (level == 0) {
646 		if (pts.type != PT_ENTRY_EMPTY)
647 			return -EADDRINUSE;
648 		pt_install_leaf_entry(&pts, map->oa, PAGE_SHIFT,
649 				      &map->attrs);
650 		/* No flush, not used when incoherent */
651 		map->oa += PAGE_SIZE;
652 		return 0;
653 	}
654 	if (pts.type == PT_ENTRY_TABLE)
655 		return pt_descend(&pts, arg, descend_fn);
656 	/* Something else, use the slow path */
657 	return -EAGAIN;
658 }
659 PT_MAKE_LEVELS(__map_single_page, __do_map_single_page);
660 
661 /*
662  * Add a table to the top, increasing the top level as much as necessary to
663  * encompass range.
664  */
665 static int increase_top(struct pt_iommu *iommu_table, struct pt_range *range,
666 			struct pt_iommu_map_args *map)
667 {
668 	struct iommu_pages_list free_list = IOMMU_PAGES_LIST_INIT(free_list);
669 	struct pt_common *common = common_from_iommu(iommu_table);
670 	uintptr_t top_of_table = READ_ONCE(common->top_of_table);
671 	uintptr_t new_top_of_table = top_of_table;
672 	struct pt_table_p *table_mem;
673 	unsigned int new_level;
674 	spinlock_t *domain_lock;
675 	unsigned long flags;
676 	int ret;
677 
678 	while (true) {
679 		struct pt_range top_range =
680 			_pt_top_range(common, new_top_of_table);
681 		struct pt_state pts = pt_init_top(&top_range);
682 
683 		top_range.va = range->va;
684 		top_range.last_va = range->last_va;
685 
686 		if (!pt_check_range(&top_range) &&
687 		    map->leaf_level <= pts.level) {
688 			new_level = pts.level;
689 			break;
690 		}
691 
692 		pts.level++;
693 		if (pts.level > PT_MAX_TOP_LEVEL ||
694 		    pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2) {
695 			ret = -ERANGE;
696 			goto err_free;
697 		}
698 
699 		table_mem =
700 			table_alloc_top(common, _pt_top_set(NULL, pts.level),
701 					map->attrs.gfp, ALLOC_DEFER_COHERENT_FLUSH);
702 		if (IS_ERR(table_mem)) {
703 			ret = PTR_ERR(table_mem);
704 			goto err_free;
705 		}
706 		iommu_pages_list_add(&free_list, table_mem);
707 
708 		/* The new table links to the lower table always at index 0 */
709 		top_range.va = 0;
710 		top_range.top_level = pts.level;
711 		pts.table_lower = pts.table;
712 		pts.table = table_mem;
713 		pt_load_single_entry(&pts);
714 		PT_WARN_ON(pts.index != 0);
715 		pt_install_table(&pts, virt_to_phys(pts.table_lower),
716 				 &map->attrs);
717 		new_top_of_table = _pt_top_set(pts.table, pts.level);
718 	}
719 
720 	/*
721 	 * Avoid double flushing; flush it once after all the pt_install_table() calls
722 	 */
723 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
724 		ret = iommu_pages_start_incoherent_list(
725 			&free_list, iommu_table->iommu_device);
726 		if (ret)
727 			goto err_free;
728 	}
729 
730 	/*
731 	 * top_of_table is write locked by the spinlock, but readers can use
732 	 * READ_ONCE() to get the value. Since we encode both the level and the
733 	 * pointer in one quantum, the lockless reader will always see something
734 	 * valid. The HW must be updated to the new level under the spinlock
735 	 * before top_of_table is updated so that concurrent readers don't map
736 	 * into the new level until it is fully functional. If another thread
737 	 * already updated it while we were working then throw everything away
738 	 * and try again.
739 	 */
740 	domain_lock = iommu_table->driver_ops->get_top_lock(iommu_table);
741 	spin_lock_irqsave(domain_lock, flags);
742 	if (common->top_of_table != top_of_table ||
743 	    top_of_table == new_top_of_table) {
744 		spin_unlock_irqrestore(domain_lock, flags);
745 		ret = -EAGAIN;
746 		goto err_free;
747 	}
748 
749 	/*
750 	 * We do not issue any flushes for change_top on the expectation that
751 	 * any walk cache will not become a problem by adding another layer to
752 	 * the tree. Misses will rewalk from the updated top pointer, hits
753 	 * continue to be correct. Negative caching is fine too since all the
754 	 * new IOVA added by the new top is non-present.
755 	 */
756 	iommu_table->driver_ops->change_top(
757 		iommu_table, virt_to_phys(table_mem), new_level);
758 	WRITE_ONCE(common->top_of_table, new_top_of_table);
759 	spin_unlock_irqrestore(domain_lock, flags);
760 	return 0;
761 
762 err_free:
763 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
764 		iommu_pages_stop_incoherent_list(&free_list,
765 						 iommu_table->iommu_device);
766 	iommu_put_pages_list(&free_list);
767 	return ret;
768 }
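/*
 * Sketch of the driver side of the PT_FEAT_DYNAMIC_TOP contract used above.
 * The mydrv_ names and the to_mydrv_domain() helper are hypothetical; only the
 * two driver_ops callbacks themselves are real:
 *
 *	static spinlock_t *mydrv_get_top_lock(struct pt_iommu *iommu_table)
 *	{
 *		return &to_mydrv_domain(iommu_table)->top_lock;
 *	}
 *
 *	static void mydrv_change_top(struct pt_iommu *iommu_table,
 *				     phys_addr_t top_paddr,
 *				     unsigned int top_level)
 *	{
 *		... reprogram the HW visible context (eg the device table
 *		    entry) to point at top_paddr with top_level ...
 *	}
 *
 * increase_top() takes get_top_lock(), calls change_top(), and only then
 * publishes the new top_of_table, so lockless readers never walk a top the HW
 * does not yet know about.
 */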
769 
770 static int check_map_range(struct pt_iommu *iommu_table, struct pt_range *range,
771 			   struct pt_iommu_map_args *map)
772 {
773 	struct pt_common *common = common_from_iommu(iommu_table);
774 	int ret;
775 
776 	do {
777 		ret = pt_check_range(range);
778 		if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP))
779 			return ret;
780 
781 		if (!ret && map->leaf_level <= range->top_level)
782 			break;
783 
784 		ret = increase_top(iommu_table, range, map);
785 		if (ret && ret != -EAGAIN)
786 			return ret;
787 
788 		/* Reload the new top */
789 		*range = pt_make_range(common, range->va, range->last_va);
790 	} while (ret);
791 	PT_WARN_ON(pt_check_range(range));
792 	return 0;
793 }
794 
795 static int do_map(struct pt_range *range, struct pt_common *common,
796 		  bool single_page, struct pt_iommu_map_args *map)
797 {
798 	/*
799 	 * The __map_single_page() fast path does not support DMA_INCOHERENT
800 	 * flushing to keep its .text small.
801 	 */
802 	if (single_page && !pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
803 		int ret;
804 
805 		ret = pt_walk_range(range, __map_single_page, map);
806 		if (ret != -EAGAIN)
807 			return ret;
808 		/* EAGAIN falls through to the full path */
809 	}
810 
811 	if (map->leaf_level == range->top_level)
812 		return pt_walk_range(range, __map_range_leaf, map);
813 	return pt_walk_range(range, __map_range, map);
814 }
815 
816 /**
817  * map_pages() - Install translation for an IOVA range
818  * @domain: Domain to manipulate
819  * @iova: IO virtual address to start
820  * @paddr: Physical/Output address to start
821  * @pgsize: Length of each page
822  * @pgcount: Length of the range in pgsize units starting from @iova
823  * @prot: A bitmap of IOMMU_READ/WRITE/CACHE/NOEXEC/MMIO
824  * @gfp: GFP flags for any memory allocations
825  * @mapped: Total bytes successfully mapped
826  *
827  * The range starting at IOVA will have paddr installed into it. The caller
828  * must specify a valid pgsize and pgcount to segment the range into compatible
829  * blocks.
830  *
831  * On error the caller will probably want to invoke unmap on the range from iova
832  * up to the amount indicated by @mapped to return the table back to an
833  * unchanged state.
834  *
835  * Context: The caller must hold a write range lock that includes the whole
836  * range.
837  *
838  * Returns: -ERRNO on failure, 0 on success. The number of bytes of VA that were
839  * mapped are added to @mapped; @mapped is not zeroed first.
840  */
841 int DOMAIN_NS(map_pages)(struct iommu_domain *domain, unsigned long iova,
842 			 phys_addr_t paddr, size_t pgsize, size_t pgcount,
843 			 int prot, gfp_t gfp, size_t *mapped)
844 {
845 	struct pt_iommu *iommu_table =
846 		container_of(domain, struct pt_iommu, domain);
847 	pt_vaddr_t pgsize_bitmap = iommu_table->domain.pgsize_bitmap;
848 	struct pt_common *common = common_from_iommu(iommu_table);
849 	struct iommu_iotlb_gather iotlb_gather;
850 	pt_vaddr_t len = pgsize * pgcount;
851 	struct pt_iommu_map_args map = {
852 		.iotlb_gather = &iotlb_gather,
853 		.oa = paddr,
854 		.leaf_pgsize_lg2 = vaffs(pgsize),
855 	};
856 	bool single_page = false;
857 	struct pt_range range;
858 	int ret;
859 
860 	iommu_iotlb_gather_init(&iotlb_gather);
861 
862 	if (WARN_ON(!(prot & (IOMMU_READ | IOMMU_WRITE))))
863 		return -EINVAL;
864 
865 	/* Check the paddr doesn't exceed what the table can store */
866 	if ((sizeof(pt_oaddr_t) < sizeof(paddr) &&
867 	     (pt_vaddr_t)paddr > PT_VADDR_MAX) ||
868 	    (common->max_oasz_lg2 != PT_VADDR_MAX_LG2 &&
869 	     oalog2_div(paddr, common->max_oasz_lg2)))
870 		return -ERANGE;
871 
872 	ret = pt_iommu_set_prot(common, &map.attrs, prot);
873 	if (ret)
874 		return ret;
875 	map.attrs.gfp = gfp;
876 
877 	ret = make_range_no_check(common, &range, iova, len);
878 	if (ret)
879 		return ret;
880 
881 	/* Calculate target page size and level for the leaves */
882 	if (pt_has_system_page_size(common) && pgsize == PAGE_SIZE &&
883 	    pgcount == 1) {
884 		PT_WARN_ON(!(pgsize_bitmap & PAGE_SIZE));
885 		if (log2_mod(iova | paddr, PAGE_SHIFT))
886 			return -ENXIO;
887 		map.leaf_pgsize_lg2 = PAGE_SHIFT;
888 		map.leaf_level = 0;
889 		single_page = true;
890 	} else {
891 		map.leaf_pgsize_lg2 = pt_compute_best_pgsize(
892 			pgsize_bitmap, range.va, range.last_va, paddr);
893 		if (!map.leaf_pgsize_lg2)
894 			return -ENXIO;
895 		map.leaf_level =
896 			pt_pgsz_lg2_to_level(common, map.leaf_pgsize_lg2);
897 	}
898 
899 	ret = check_map_range(iommu_table, &range, &map);
900 	if (ret)
901 		return ret;
902 
903 	PT_WARN_ON(map.leaf_level > range.top_level);
904 
905 	ret = do_map(&range, common, single_page, &map);
906 
907 	/*
908 	 * Table levels were freed and replaced with large items; flush any walk
909 	 * cache that may refer to the freed levels.
910 	 */
911 	if (!iommu_pages_list_empty(&iotlb_gather.freelist))
912 		iommu_iotlb_sync(&iommu_table->domain, &iotlb_gather);
913 
914 	/* Bytes successfully mapped */
915 	PT_WARN_ON(!ret && map.oa - paddr != len);
916 	*mapped += map.oa - paddr;
917 	return ret;
918 }
919 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(map_pages), "GENERIC_PT_IOMMU");
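/*
 * Usage sketch following the error convention from the kernel-doc above
 * (sizes and prot are illustrative): on failure, unmap whatever was partially
 * installed before giving up.
 *
 *	size_t mapped = 0;
 *	int ret;
 *
 *	ret = domain->ops->map_pages(domain, iova, paddr, SZ_2M, 4,
 *				     IOMMU_READ | IOMMU_WRITE, GFP_KERNEL,
 *				     &mapped);
 *	if (ret && mapped)
 *		iommu_unmap(domain, iova, mapped);
 */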
920 
921 struct pt_unmap_args {
922 	struct iommu_pages_list free_list;
923 	pt_vaddr_t unmapped;
924 };
925 
926 static __maybe_unused int __unmap_range(struct pt_range *range, void *arg,
927 					unsigned int level,
928 					struct pt_table_p *table)
929 {
930 	struct pt_state pts = pt_init(range, level, table);
931 	struct pt_unmap_args *unmap = arg;
932 	unsigned int num_oas = 0;
933 	unsigned int start_index;
934 	int ret = 0;
935 
936 	_pt_iter_first(&pts);
937 	start_index = pts.index;
938 	pts.type = pt_load_entry_raw(&pts);
939 	/*
940 	 * A starting index is in the middle of a contiguous entry
941 	 *
942 	 * The IOMMU API does not require drivers to support unmapping parts of
943 	 * large pages. Long ago VFIO would try to split maps but the current
944 	 * version never does.
945 	 *
946 	 * Instead when unmap reaches a partial unmap of the start of a large
947 	 * IOPTE it should remove the entire IOPTE and return that size to the
948 	 * caller.
949 	 */
950 	if (pts.type == PT_ENTRY_OA) {
951 		if (log2_mod(range->va, pt_entry_oa_lg2sz(&pts)))
952 			return -EINVAL;
953 		/* Micro optimization */
954 		goto start_oa;
955 	}
956 
957 	do {
958 		if (pts.type != PT_ENTRY_OA) {
959 			bool fully_covered;
960 
961 			if (pts.type != PT_ENTRY_TABLE) {
962 				ret = -EINVAL;
963 				break;
964 			}
965 
966 			if (pts.index != start_index)
967 				pt_index_to_va(&pts);
968 			pts.table_lower = pt_table_ptr(&pts);
969 
970 			fully_covered = pt_entry_fully_covered(
971 				&pts, pt_table_item_lg2sz(&pts));
972 
973 			ret = pt_descend(&pts, arg, __unmap_range);
974 			if (ret)
975 				break;
976 
977 			/*
978 			 * If the unmapping range fully covers the table then we
979 			 * can free it as well. The clear is delayed until we
980 			 * succeed in clearing the lower table levels.
981 			 */
982 			if (fully_covered) {
983 				iommu_pages_list_add(&unmap->free_list,
984 						     pts.table_lower);
985 				pt_clear_entries(&pts, ilog2(1));
986 			}
987 			pts.index++;
988 		} else {
989 			unsigned int num_contig_lg2;
990 start_oa:
991 			/*
992 			 * If the caller requested a last address that falls within a
993 			 * single entry then the entire entry is unmapped and
994 			 * the length returned will be larger than requested.
995 			 */
996 			num_contig_lg2 = pt_entry_num_contig_lg2(&pts);
997 			pt_clear_entries(&pts, num_contig_lg2);
998 			num_oas += log2_to_int(num_contig_lg2);
999 			pts.index += log2_to_int(num_contig_lg2);
1000 		}
1001 		if (pts.index >= pts.end_index)
1002 			break;
1003 		pts.type = pt_load_entry_raw(&pts);
1004 	} while (true);
1005 
1006 	unmap->unmapped += log2_mul(num_oas, pt_table_item_lg2sz(&pts));
1007 	flush_writes_range(&pts, start_index, pts.index);
1008 
1009 	return ret;
1010 }
1011 
1012 /**
1013  * unmap_pages() - Make a range of IOVA empty/not present
1014  * @domain: Domain to manipulate
1015  * @iova: IO virtual address to start
1016  * @pgsize: Length of each page
1017  * @pgcount: Length of the range in pgsize units starting from @iova
1018  * @iotlb_gather: Gather struct that must be flushed on return
1019  *
1020  * unmap_pages() will remove a translation created by map_pages(). It cannot
1021  * subdivide a mapping created by map_pages(), so it should be called with IOVA
1022  * ranges that match those passed to map_pages(). The IOVA range can aggregate
1023  * contiguous map_pages() calls so long as no individual range is split.
1024  *
1025  * Context: The caller must hold a write range lock that includes
1026  * the whole range.
1027  *
1028  * Returns: Number of bytes of VA unmapped. iova + res will be the point at
1029  * which unmapping stopped.
1030  */
1031 size_t DOMAIN_NS(unmap_pages)(struct iommu_domain *domain, unsigned long iova,
1032 			      size_t pgsize, size_t pgcount,
1033 			      struct iommu_iotlb_gather *iotlb_gather)
1034 {
1035 	struct pt_iommu *iommu_table =
1036 		container_of(domain, struct pt_iommu, domain);
1037 	struct pt_unmap_args unmap = { .free_list = IOMMU_PAGES_LIST_INIT(
1038 					       unmap.free_list) };
1039 	pt_vaddr_t len = pgsize * pgcount;
1040 	struct pt_range range;
1041 	int ret;
1042 
1043 	ret = make_range(common_from_iommu(iommu_table), &range, iova, len);
1044 	if (ret)
1045 		return 0;
1046 
1047 	pt_walk_range(&range, __unmap_range, &unmap);
1048 
1049 	gather_range_pages(iotlb_gather, iommu_table, iova, len,
1050 			   &unmap.free_list);
1051 
1052 	return unmap.unmapped;
1053 }
1054 EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(unmap_pages), "GENERIC_PT_IOMMU");
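/*
 * Behavioural sketch of the "no subdividing" rule documented above, with
 * illustrative sizes (gather is an initialized struct iommu_iotlb_gather): if
 * map_pages() installed a single 2M leaf at iova, then
 *
 *	unmapped = domain->ops->unmap_pages(domain, iova, SZ_4K, 1, &gather);
 *
 * removes the whole 2M entry and returns SZ_2M, not SZ_4K. Callers should
 * therefore unmap on the same boundaries they mapped with.
 */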
1055 
1056 static void NS(get_info)(struct pt_iommu *iommu_table,
1057 			 struct pt_iommu_info *info)
1058 {
1059 	struct pt_common *common = common_from_iommu(iommu_table);
1060 	struct pt_range range = pt_top_range(common);
1061 	struct pt_state pts = pt_init_top(&range);
1062 	pt_vaddr_t pgsize_bitmap = 0;
1063 
1064 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP)) {
1065 		for (pts.level = 0; pts.level <= PT_MAX_TOP_LEVEL;
1066 		     pts.level++) {
1067 			if (pt_table_item_lg2sz(&pts) >= common->max_vasz_lg2)
1068 				break;
1069 			pgsize_bitmap |= pt_possible_sizes(&pts);
1070 		}
1071 	} else {
1072 		for (pts.level = 0; pts.level <= range.top_level; pts.level++)
1073 			pgsize_bitmap |= pt_possible_sizes(&pts);
1074 	}
1075 
1076 	/* Hide page sizes larger than the maximum OA */
1077 	info->pgsize_bitmap = oalog2_mod(pgsize_bitmap, common->max_oasz_lg2);
1078 }
1079 
1080 static void NS(deinit)(struct pt_iommu *iommu_table)
1081 {
1082 	struct pt_common *common = common_from_iommu(iommu_table);
1083 	struct pt_range range = pt_all_range(common);
1084 	struct pt_iommu_collect_args collect = {
1085 		.free_list = IOMMU_PAGES_LIST_INIT(collect.free_list),
1086 	};
1087 
1088 	iommu_pages_list_add(&collect.free_list, range.top_table);
1089 	pt_walk_range(&range, __collect_tables, &collect);
1090 
1091 	/*
1092 	 * The driver has to already have fenced the HW access to the page table
1093 	 * and invalidated any caching referring to this memory.
1094 	 */
1095 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT))
1096 		iommu_pages_stop_incoherent_list(&collect.free_list,
1097 						 iommu_table->iommu_device);
1098 	iommu_put_pages_list(&collect.free_list);
1099 }
1100 
1101 static const struct pt_iommu_ops NS(ops) = {
1102 #if IS_ENABLED(CONFIG_IOMMUFD_DRIVER) && defined(pt_entry_is_write_dirty) && \
1103 	IS_ENABLED(CONFIG_IOMMUFD_TEST) && defined(pt_entry_make_write_dirty)
1104 	.set_dirty = NS(set_dirty),
1105 #endif
1106 	.get_info = NS(get_info),
1107 	.deinit = NS(deinit),
1108 };
1109 
1110 static int pt_init_common(struct pt_common *common)
1111 {
1112 	struct pt_range top_range = pt_top_range(common);
1113 
1114 	if (PT_WARN_ON(top_range.top_level > PT_MAX_TOP_LEVEL))
1115 		return -EINVAL;
1116 
1117 	if (top_range.top_level == PT_MAX_TOP_LEVEL ||
1118 	    common->max_vasz_lg2 == top_range.max_vasz_lg2)
1119 		common->features &= ~BIT(PT_FEAT_DYNAMIC_TOP);
1120 
1121 	if (top_range.max_vasz_lg2 == PT_VADDR_MAX_LG2)
1122 		common->features |= BIT(PT_FEAT_FULL_VA);
1123 
1124 	/* Requested features must match features compiled into this format */
1125 	if ((common->features & ~(unsigned int)PT_SUPPORTED_FEATURES) ||
1126 	    (!IS_ENABLED(CONFIG_DEBUG_GENERIC_PT) &&
1127 	     (common->features & PT_FORCE_ENABLED_FEATURES) !=
1128 		     PT_FORCE_ENABLED_FEATURES))
1129 		return -EOPNOTSUPP;
1130 
1131 	/*
1132 	 * Check if the top level of the page table is too small to hold the
1133 	 * specified maxvasz.
1134 	 */
1135 	if (!pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
1136 	    top_range.top_level != PT_MAX_TOP_LEVEL) {
1137 		struct pt_state pts = { .range = &top_range,
1138 					.level = top_range.top_level };
1139 
1140 		if (common->max_vasz_lg2 >
1141 		    pt_num_items_lg2(&pts) + pt_table_item_lg2sz(&pts))
1142 			return -EOPNOTSUPP;
1143 	}
1144 
1145 	if (common->max_oasz_lg2 == 0)
1146 		common->max_oasz_lg2 = pt_max_oa_lg2(common);
1147 	else
1148 		common->max_oasz_lg2 = min(common->max_oasz_lg2,
1149 					   pt_max_oa_lg2(common));
1150 	return 0;
1151 }
1152 
1153 static int pt_iommu_init_domain(struct pt_iommu *iommu_table,
1154 				struct iommu_domain *domain)
1155 {
1156 	struct pt_common *common = common_from_iommu(iommu_table);
1157 	struct pt_iommu_info info;
1158 	struct pt_range range;
1159 
1160 	NS(get_info)(iommu_table, &info);
1161 
1162 	domain->type = __IOMMU_DOMAIN_PAGING;
1163 	domain->pgsize_bitmap = info.pgsize_bitmap;
1164 
1165 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP))
1166 		range = _pt_top_range(common,
1167 				      _pt_top_set(NULL, PT_MAX_TOP_LEVEL));
1168 	else
1169 		range = pt_top_range(common);
1170 
1171 	/* A 64-bit high address space table on a 32-bit system cannot work. */
1172 	domain->geometry.aperture_start = (unsigned long)range.va;
1173 	if ((pt_vaddr_t)domain->geometry.aperture_start != range.va)
1174 		return -EOVERFLOW;
1175 
1176 	/*
1177 	 * The aperture is limited to what the API can do after considering all
1178 	 * the different types dma_addr_t/unsigned long/pt_vaddr_t that are used
1179 	 * to store a VA. Set the aperture to something that is valid for all
1180 	 * cases. Saturate instead of truncate the end if the types are smaller
1181 	 * than the top range. aperture_end should be called aperture_last.
1182 	 */
1183 	domain->geometry.aperture_end = (unsigned long)range.last_va;
1184 	if ((pt_vaddr_t)domain->geometry.aperture_end != range.last_va) {
1185 		domain->geometry.aperture_end = ULONG_MAX;
1186 		domain->pgsize_bitmap &= ULONG_MAX;
1187 	}
1188 	domain->geometry.force_aperture = true;
1189 
1190 	return 0;
1191 }
1192 
1193 static void pt_iommu_zero(struct pt_iommu_table *fmt_table)
1194 {
1195 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1196 	struct pt_iommu cfg = *iommu_table;
1197 
1198 	static_assert(offsetof(struct pt_iommu_table, iommu.domain) == 0);
1199 	memset_after(fmt_table, 0, iommu.domain);
1200 
1201 	/* The caller can initialize some of these values */
1202 	iommu_table->iommu_device = cfg.iommu_device;
1203 	iommu_table->driver_ops = cfg.driver_ops;
1204 	iommu_table->nid = cfg.nid;
1205 }
1206 
1207 #define pt_iommu_table_cfg CONCATENATE(pt_iommu_table, _cfg)
1208 #define pt_iommu_init CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), init)
1209 
1210 int pt_iommu_init(struct pt_iommu_table *fmt_table,
1211 		  const struct pt_iommu_table_cfg *cfg, gfp_t gfp)
1212 {
1213 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1214 	struct pt_common *common = common_from_iommu(iommu_table);
1215 	struct pt_table_p *table_mem;
1216 	int ret;
1217 
1218 	if (cfg->common.hw_max_vasz_lg2 > PT_MAX_VA_ADDRESS_LG2 ||
1219 	    !cfg->common.hw_max_vasz_lg2 || !cfg->common.hw_max_oasz_lg2)
1220 		return -EINVAL;
1221 
1222 	pt_iommu_zero(fmt_table);
1223 	common->features = cfg->common.features;
1224 	common->max_vasz_lg2 = cfg->common.hw_max_vasz_lg2;
1225 	common->max_oasz_lg2 = cfg->common.hw_max_oasz_lg2;
1226 	ret = pt_iommu_fmt_init(fmt_table, cfg);
1227 	if (ret)
1228 		return ret;
1229 
1230 	if (cfg->common.hw_max_oasz_lg2 > pt_max_oa_lg2(common))
1231 		return -EINVAL;
1232 
1233 	ret = pt_init_common(common);
1234 	if (ret)
1235 		return ret;
1236 
1237 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
1238 	    WARN_ON(!iommu_table->driver_ops ||
1239 		    !iommu_table->driver_ops->change_top ||
1240 		    !iommu_table->driver_ops->get_top_lock))
1241 		return -EINVAL;
1242 
1243 	if (pt_feature(common, PT_FEAT_SIGN_EXTEND) &&
1244 	    (pt_feature(common, PT_FEAT_FULL_VA) ||
1245 	     pt_feature(common, PT_FEAT_DYNAMIC_TOP)))
1246 		return -EINVAL;
1247 
1248 	if (pt_feature(common, PT_FEAT_DMA_INCOHERENT) &&
1249 	    WARN_ON(!iommu_table->iommu_device))
1250 		return -EINVAL;
1251 
1252 	ret = pt_iommu_init_domain(iommu_table, &iommu_table->domain);
1253 	if (ret)
1254 		return ret;
1255 
1256 	table_mem = table_alloc_top(common, common->top_of_table, gfp,
1257 				    ALLOC_NORMAL);
1258 	if (IS_ERR(table_mem))
1259 		return PTR_ERR(table_mem);
1260 	pt_top_set(common, table_mem, pt_top_get_level(common));
1261 
1262 	/* Must be last, see pt_iommu_deinit() */
1263 	iommu_table->ops = &NS(ops);
1264 	return 0;
1265 }
1266 EXPORT_SYMBOL_NS_GPL(pt_iommu_init, "GENERIC_PT_IOMMU");
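/*
 * Initialization sketch for a driver. The amdv1 flavoured names are
 * illustrative; the real struct, cfg and init names come from the format
 * header that instantiates this template, and dev/mydrv_domain are assumed to
 * exist in the caller:
 *
 *	struct pt_iommu_amdv1 *table = &mydrv_domain->fmt_table;
 *	struct pt_iommu_amdv1_cfg cfg = {
 *		.common.features = BIT(PT_FEAT_DMA_INCOHERENT),
 *		.common.hw_max_vasz_lg2 = 48,
 *		.common.hw_max_oasz_lg2 = 52,
 *	};
 *	int ret;
 *
 *	table->iommu.nid = dev_to_node(dev);
 *	table->iommu.iommu_device = dev;
 *	ret = pt_iommu_amdv1_init(table, &cfg, GFP_KERNEL);
 *	if (ret)
 *		return ret;
 *
 * PT_FEAT_DYNAMIC_TOP additionally requires driver_ops with change_top() and
 * get_top_lock(), and PT_FEAT_DMA_INCOHERENT requires iommu_device to be set,
 * as checked above.
 */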
1267 
1268 #ifdef pt_iommu_fmt_hw_info
1269 #define pt_iommu_table_hw_info CONCATENATE(pt_iommu_table, _hw_info)
1270 #define pt_iommu_hw_info CONCATENATE(CONCATENATE(pt_iommu_, PTPFX), hw_info)
1271 void pt_iommu_hw_info(struct pt_iommu_table *fmt_table,
1272 		      struct pt_iommu_table_hw_info *info)
1273 {
1274 	struct pt_iommu *iommu_table = &fmt_table->iommu;
1275 	struct pt_common *common = common_from_iommu(iommu_table);
1276 	struct pt_range top_range = pt_top_range(common);
1277 
1278 	pt_iommu_fmt_hw_info(fmt_table, &top_range, info);
1279 }
1280 EXPORT_SYMBOL_NS_GPL(pt_iommu_hw_info, "GENERIC_PT_IOMMU");
1281 #endif
1282 
1283 MODULE_LICENSE("GPL");
1284 MODULE_DESCRIPTION("IOMMU Page table implementation for " __stringify(PTPFX_RAW));
1285 MODULE_IMPORT_NS("GENERIC_PT");
1286 /* For iommu_dirty_bitmap_record() */
1287 MODULE_IMPORT_NS("IOMMUFD");
1288 
1289 #endif  /* __GENERIC_PT_IOMMU_PT_H */
1290