/* SPDX-License-Identifier: GPL-2.0-only */
/*
 * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
 *
 * Iterators for Generic Page Table
 */
#ifndef __GENERIC_PT_PT_ITER_H
#define __GENERIC_PT_PT_ITER_H

#include "pt_common.h"

#include <linux/errno.h>

/*
 * Use to mangle symbols so that backtraces and the symbol table are
 * understandable. Any non-inlined function should get mangled like this.
 */
#define NS(fn) CONCATENATE(PTPFX, fn)

/**
 * pt_check_range() - Validate the range can be iterated
 * @range: Range to validate
 *
 * Check that VA and last_va fall within the permitted range of VAs. If the
 * format is using PT_FEAT_SIGN_EXTEND then this also checks the sign extension
 * is correct.
 *
 * Return: 0 on success, -ERANGE if either end of the range falls outside the
 * valid VA space.
 */
static inline int pt_check_range(struct pt_range *range)
{
	pt_vaddr_t prefix;

	PT_WARN_ON(!range->max_vasz_lg2);

	if (pt_feature(range->common, PT_FEAT_SIGN_EXTEND)) {
		PT_WARN_ON(range->common->max_vasz_lg2 != range->max_vasz_lg2);
		/*
		 * The required prefix is all ones or all zeros depending on
		 * the top bit of the non-extended VA.
		 */
		prefix = fvalog2_div(range->va, range->max_vasz_lg2 - 1) ?
				 PT_VADDR_MAX :
				 0;
	} else {
		prefix = pt_full_va_prefix(range->common);
	}

	if (!fvalog2_div_eq(range->va, prefix, range->max_vasz_lg2) ||
	    !fvalog2_div_eq(range->last_va, prefix, range->max_vasz_lg2))
		return -ERANGE;
	return 0;
}

/**
 * pt_index_to_va() - Update range->va to the current pts->index
 * @pts: Iteration State
 *
 * Adjust range->va to match the current index. This is done in a lazy manner
 * since computing the VA takes several instructions and is rarely required.
 */
static inline void pt_index_to_va(struct pt_state *pts)
{
	pt_vaddr_t lower_va;

	lower_va = log2_mul(pts->index, pt_table_item_lg2sz(pts));
	pts->range->va = fvalog2_set_mod(pts->range->va, lower_va,
					 pt_table_oa_lg2sz(pts));
}

/*
 * Add index_count_lg2 number of entries to pts's VA and index. The VA will be
 * adjusted to the end of the contiguous block if it is currently in the middle.
 */
static inline void _pt_advance(struct pt_state *pts,
			       unsigned int index_count_lg2)
{
	pts->index = log2_set_mod(pts->index + log2_to_int(index_count_lg2), 0,
				  index_count_lg2);
}

/**
 * pt_entry_fully_covered() - Check if the item or entry is entirely contained
 *                            within pts->range
 * @pts: Iteration State
 * @oasz_lg2: The size of the item to check, pt_table_item_lg2sz() or
 *            pt_entry_oa_lg2sz()
 *
 * Returns: true if the item is fully enclosed by the pts->range.
 */
static inline bool pt_entry_fully_covered(const struct pt_state *pts,
					  unsigned int oasz_lg2)
{
	struct pt_range *range = pts->range;

	/* Range begins at the start of the entry */
	if (log2_mod(pts->range->va, oasz_lg2))
		return false;

	/* Range ends past the end of the entry */
	if (!log2_div_eq(range->va, range->last_va, oasz_lg2))
		return true;

	/* Range ends at the end of the entry */
	return log2_mod_eq_max(range->last_va, oasz_lg2);
}

/**
 * pt_range_to_index() - Starting index for an iteration
 * @pts: Iteration State
 *
 * Return: the starting index for the iteration in pts.
 */
static inline unsigned int pt_range_to_index(const struct pt_state *pts)
{
	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);

	PT_WARN_ON(pts->level > pts->range->top_level);
	/* The top table may be smaller than a full table of this level */
	if (pts->range->top_level == pts->level)
		return log2_div(fvalog2_mod(pts->range->va,
					    pts->range->max_vasz_lg2),
				isz_lg2);
	return log2_mod(log2_div(pts->range->va, isz_lg2),
			pt_num_items_lg2(pts));
}

/**
 * pt_range_to_end_index() - Ending index for an iteration
 * @pts: Iteration State
 *
 * Return: the last index for the iteration in pts.
 */
static inline unsigned int pt_range_to_end_index(const struct pt_state *pts)
{
	unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
	struct pt_range *range = pts->range;
	unsigned int num_entries_lg2;

	/* Single-VA range, iterate exactly one entry */
	if (range->va == range->last_va)
		return pts->index + 1;

	if (pts->range->top_level == pts->level)
		return log2_div(fvalog2_mod(pts->range->last_va,
					    pts->range->max_vasz_lg2),
				isz_lg2) +
		       1;

	num_entries_lg2 = pt_num_items_lg2(pts);

	/* last_va falls within this table */
	if (log2_div_eq(range->va, range->last_va, num_entries_lg2 + isz_lg2))
		return log2_mod(log2_div(pts->range->last_va, isz_lg2),
				num_entries_lg2) +
		       1;

	/* Otherwise iterate to the end of the table */
	return log2_to_int(num_entries_lg2);
}

static inline void _pt_iter_first(struct pt_state *pts)
{
	pts->index = pt_range_to_index(pts);
	pts->end_index = pt_range_to_end_index(pts);
	PT_WARN_ON(pts->index > pts->end_index);
}

static inline bool _pt_iter_load(struct pt_state *pts)
{
	if (pts->index >= pts->end_index)
		return false;
	pt_load_entry(pts);
	return true;
}

/**
 * pt_next_entry() - Advance pts to the next entry
 * @pts: Iteration State
 *
 * Update pts to go to the next index at this level. If pts is pointing at a
 * contiguous entry then the index may advance by more than one.
 */
static inline void pt_next_entry(struct pt_state *pts)
{
	/*
	 * When the format's pt_entry_num_contig_lg2() is a compile-time
	 * constant (presumably 0, meaning no contiguous entry support) the
	 * first branch is removed at compile time - NOTE(review): confirm
	 * against the formats using this helper.
	 */
	if (pts->type == PT_ENTRY_OA &&
	    !__builtin_constant_p(pt_entry_num_contig_lg2(pts) == 0))
		_pt_advance(pts, pt_entry_num_contig_lg2(pts));
	else
		pts->index++;
	pt_index_to_va(pts);
}

/**
 * for_each_pt_level_entry() - For loop wrapper over entries in the range
 * @pts: Iteration State
 *
 * This is the basic iteration primitive. It iterates over all the entries in
 * pts->range that fall within the pts's current table level. Each step does
 * pt_load_entry(pts).
 */
#define for_each_pt_level_entry(pts) \
	for (_pt_iter_first(pts); _pt_iter_load(pts); pt_next_entry(pts))

/**
 * pt_load_single_entry() - Version of pt_load_entry() usable within a walker
 * @pts: Iteration State
 *
 * Alternative to for_each_pt_level_entry() if the walker function uses only a
 * single entry.
 *
 * Return: the type of the loaded entry.
 */
static inline enum pt_entry_type pt_load_single_entry(struct pt_state *pts)
{
	pts->index = pt_range_to_index(pts);
	pt_load_entry(pts);
	return pts->type;
}

/*
 * Build a pt_range from an already-read top_of_table value, decoding the
 * packed top table pointer and top level out of it.
 */
static __always_inline struct pt_range _pt_top_range(struct pt_common *common,
						     uintptr_t top_of_table)
{
	struct pt_range range = {
		.common = common,
		.top_table =
			(struct pt_table_p *)(top_of_table &
					      ~(uintptr_t)PT_TOP_LEVEL_MASK),
		.top_level = top_of_table % (1 << PT_TOP_LEVEL_BITS),
	};
	struct pt_state pts = { .range = &range, .level = range.top_level };
	unsigned int max_vasz_lg2;

	max_vasz_lg2 = common->max_vasz_lg2;
	/*
	 * With a dynamic top the VA space is limited by what the current top
	 * table can address, not the format's maximum.
	 */
	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
	    pts.level != PT_MAX_TOP_LEVEL)
		max_vasz_lg2 = min_t(unsigned int, common->max_vasz_lg2,
				     pt_num_items_lg2(&pts) +
					     pt_table_item_lg2sz(&pts));

	/*
	 * The top range will default to the lower region only with sign extend.
	 */
	range.max_vasz_lg2 = max_vasz_lg2;
	if (pt_feature(common, PT_FEAT_SIGN_EXTEND))
		max_vasz_lg2--;

	range.va = fvalog2_set_mod(pt_full_va_prefix(common), 0, max_vasz_lg2);
	range.last_va =
		fvalog2_set_mod_max(pt_full_va_prefix(common), max_vasz_lg2);
	return range;
}

/**
 * pt_top_range() - Return a range that spans part of the top level
 * @common: Table
 *
 * For PT_FEAT_SIGN_EXTEND this will return the lower range, and cover half the
 * total page table. Otherwise it returns the entire page table.
 */
static __always_inline struct pt_range pt_top_range(struct pt_common *common)
{
	/*
	 * The top pointer can change without locking. We capture the value and
	 * its level here and are safe to walk it so long as both values are
	 * captured without tearing.
	 */
	return _pt_top_range(common, READ_ONCE(common->top_of_table));
}

/**
 * pt_all_range() - Return a range that spans the entire page table
 * @common: Table
 *
 * The returned range spans the whole page table. Due to how PT_FEAT_SIGN_EXTEND
 * is supported range->va and range->last_va will be incorrect during the
 * iteration and must not be accessed.
 */
static inline struct pt_range pt_all_range(struct pt_common *common)
{
	struct pt_range range = pt_top_range(common);

	if (!pt_feature(common, PT_FEAT_SIGN_EXTEND))
		return range;

	/*
	 * Pretend the table is linear from 0 without a sign extension. This
	 * generates the correct indexes for iteration.
	 */
	range.last_va = fvalog2_set_mod_max(0, range.max_vasz_lg2);
	return range;
}

/**
 * pt_upper_range() - Return a range that spans part of the top level
 * @common: Table
 *
 * For PT_FEAT_SIGN_EXTEND this will return the upper range, and cover half the
 * total page table. Otherwise it returns the entire page table.
 */
static inline struct pt_range pt_upper_range(struct pt_common *common)
{
	struct pt_range range = pt_top_range(common);

	if (!pt_feature(common, PT_FEAT_SIGN_EXTEND))
		return range;

	range.va = fvalog2_set_mod(PT_VADDR_MAX, 0, range.max_vasz_lg2 - 1);
	range.last_va = PT_VADDR_MAX;
	return range;
}

/**
 * pt_make_range() - Return a range that spans part of the table
 * @common: Table
 * @va: Start address
 * @last_va: Last address
 *
 * The caller must validate the range with pt_check_range() before using it.
 */
static __always_inline struct pt_range
pt_make_range(struct pt_common *common, pt_vaddr_t va, pt_vaddr_t last_va)
{
	struct pt_range range =
		_pt_top_range(common, READ_ONCE(common->top_of_table));

	range.va = va;
	range.last_va = last_va;

	return range;
}

/*
 * Span a slice of the table starting at a lower table level from an active
 * walk.
 */
static __always_inline struct pt_range
pt_make_child_range(const struct pt_range *parent, pt_vaddr_t va,
		    pt_vaddr_t last_va)
{
	struct pt_range range = *parent;

	range.va = va;
	range.last_va = last_va;

	PT_WARN_ON(last_va < va);
	PT_WARN_ON(pt_check_range(&range));

	return range;
}

/**
 * pt_init() - Initialize a pt_state on the stack
 * @range: Range pointer to embed in the state
 * @level: Table level for the state
 * @table: Pointer to the table memory at level
 *
 * Helper to initialize the on-stack pt_state from walker arguments.
 */
static __always_inline struct pt_state
pt_init(struct pt_range *range, unsigned int level, struct pt_table_p *table)
{
	struct pt_state pts = {
		.range = range,
		.table = table,
		.level = level,
	};
	return pts;
}

/**
 * pt_init_top() - Initialize a pt_state on the stack
 * @range: Range pointer to embed in the state
 *
 * The pt_state points to the top most level.
 */
static __always_inline struct pt_state pt_init_top(struct pt_range *range)
{
	return pt_init(range, range->top_level, range->top_table);
}

/* Signature of the per-level walker functions invoked during a walk */
typedef int (*pt_level_fn_t)(struct pt_range *range, void *arg,
			     unsigned int level, struct pt_table_p *table);

/**
 * pt_descend() - Recursively invoke the walker for the lower level
 * @pts: Iteration State
 * @arg: Value to pass to the function
 * @fn: Walker function to call
 *
 * pts must point to a table item. Invoke fn as a walker on the table
 * pts points to.
 *
 * Return: -EINVAL if pts has no lower table, otherwise fn's return value.
 */
static __always_inline int pt_descend(struct pt_state *pts, void *arg,
				      pt_level_fn_t fn)
{
	int ret;

	if (PT_WARN_ON(!pts->table_lower))
		return -EINVAL;

	ret = (*fn)(pts->range, arg, pts->level - 1, pts->table_lower);
	return ret;
}

/**
 * pt_walk_range() - Walk over a VA range
 * @range: Range pointer
 * @fn: Walker function to call
 * @arg: Value to pass to the function
 *
 * Walk over a VA range. The caller should have done a validity check, at
 * least calling pt_check_range(), when building range. The walk will
 * start at the top most table.
 */
static __always_inline int pt_walk_range(struct pt_range *range,
					 pt_level_fn_t fn, void *arg)
{
	return fn(range, arg, range->top_level, range->top_table);
}

/*
 * pt_walk_descend() - Recursively invoke the walker for a slice of a lower
 *                     level
 * @pts: Iteration State
 * @va: Start address
 * @last_va: Last address
 * @fn: Walker function to call
 * @arg: Value to pass to the function
 *
 * With pts pointing at a table item this will descend into and walk over a
 * slice of the lower table. The caller must ensure that va/last_va are within
 * the table item. This creates a new walk and does not alter pts or
 * pts->range.
 */
static __always_inline int pt_walk_descend(const struct pt_state *pts,
					   pt_vaddr_t va, pt_vaddr_t last_va,
					   pt_level_fn_t fn, void *arg)
{
	struct pt_range range = pt_make_child_range(pts->range, va, last_va);

	if (PT_WARN_ON(!pt_can_have_table(pts)) ||
	    PT_WARN_ON(!pts->table_lower))
		return -EINVAL;

	return fn(&range, arg, pts->level - 1, pts->table_lower);
}

/*
 * pt_walk_descend_all() - Recursively invoke the walker for a table item
 * @parent_pts: Iteration State
 * @fn: Walker function to call
 * @arg: Value to pass to the function
 *
 * With pts pointing at a table item this will descend into and walk over the
 * entire lower table. This creates a new walk and does not alter pts or
 * pts->range.
 */
static __always_inline int
pt_walk_descend_all(const struct pt_state *parent_pts, pt_level_fn_t fn,
		    void *arg)
{
	unsigned int isz_lg2 = pt_table_item_lg2sz(parent_pts);

	return pt_walk_descend(parent_pts,
			       log2_set_mod(parent_pts->range->va, 0, isz_lg2),
			       log2_set_mod_max(parent_pts->range->va, isz_lg2),
			       fn, arg);
}

/**
 * pt_range_slice() - Return a range that spans indexes
 * @pts: Iteration State
 * @start_index: Starting index within pts
 * @end_index: Ending index within pts
 *
 * Create a range that spans an index range of the current table level
 * pt_state points at.
 */
static inline struct pt_range pt_range_slice(const struct pt_state *pts,
					     unsigned int start_index,
					     unsigned int end_index)
{
	unsigned int table_lg2sz = pt_table_oa_lg2sz(pts);
	pt_vaddr_t last_va;
	pt_vaddr_t va;

	va = fvalog2_set_mod(pts->range->va,
			     log2_mul(start_index, pt_table_item_lg2sz(pts)),
			     table_lg2sz);
	last_va = fvalog2_set_mod(
		pts->range->va,
		log2_mul(end_index, pt_table_item_lg2sz(pts)) - 1, table_lg2sz);
	return pt_make_child_range(pts->range, va, last_va);
}

/**
 * pt_top_memsize_lg2() - Size of the top table allocation
 * @common: Table
 * @top_of_table: Top of table value from _pt_top_set()
 *
 * Compute the allocation size of the top table. For PT_FEAT_DYNAMIC_TOP this
 * will compute the top size assuming the table will grow.
 */
static inline unsigned int pt_top_memsize_lg2(struct pt_common *common,
					      uintptr_t top_of_table)
{
	struct pt_range range = _pt_top_range(common, top_of_table);
	struct pt_state pts = pt_init_top(&range);
	unsigned int num_items_lg2;

	num_items_lg2 = common->max_vasz_lg2 - pt_table_item_lg2sz(&pts);
	if (range.top_level != PT_MAX_TOP_LEVEL &&
	    pt_feature(common, PT_FEAT_DYNAMIC_TOP))
		num_items_lg2 = min(num_items_lg2, pt_num_items_lg2(&pts));

	/* Round up the allocation size to the minimum alignment */
	return max(ffs_t(u64, PT_TOP_PHYS_MASK),
		   num_items_lg2 + ilog2(PT_ITEM_WORD_SIZE));
}

/**
 * pt_compute_best_pgsize() - Determine the best page size for leaf entries
 * @pgsz_bitmap: Permitted page sizes
 * @va: Starting virtual address for the leaf entry
 * @last_va: Last virtual address for the leaf entry, sets the max page size
 * @oa: Starting output address for the leaf entry
 *
 * Compute the largest page size for va, last_va, and oa together and return it
 * in lg2. The largest page size depends on the format's supported page sizes at
 * this level, and the relative alignment of the VA and OA addresses. 0 means
 * the OA cannot be stored with the provided pgsz_bitmap.
 */
static inline unsigned int pt_compute_best_pgsize(pt_vaddr_t pgsz_bitmap,
						  pt_vaddr_t va,
						  pt_vaddr_t last_va,
						  pt_oaddr_t oa)
{
	unsigned int best_pgsz_lg2;
	unsigned int pgsz_lg2;
	pt_vaddr_t len = last_va - va + 1;
	pt_vaddr_t mask;

	if (PT_WARN_ON(va >= last_va))
		return 0;

	/*
	 * Given a VA/OA pair the best page size is the largest page size
	 * where:
	 *
	 * 1) VA and OA start at the page. Bitwise this is the count of least
	 *    significant 0 bits.
	 *    This also implies that last_va/oa has the same prefix as va/oa.
	 */
	mask = va | oa;

	/*
	 * 2) The page size is not larger than the last_va (length). Since page
	 *    sizes are always power of two this can't be larger than the
	 *    largest power of two factor of the length.
	 */
	mask |= log2_to_int(vafls(len) - 1);

	best_pgsz_lg2 = vaffs(mask);

	/* Choose the highest bit <= best_pgsz_lg2 */
	if (best_pgsz_lg2 < PT_VADDR_MAX_LG2 - 1)
		pgsz_bitmap = log2_mod(pgsz_bitmap, best_pgsz_lg2 + 1);

	pgsz_lg2 = vafls(pgsz_bitmap);
	if (!pgsz_lg2)
		return 0;

	pgsz_lg2--;

	/* Sanity check the chosen size against all the inputs */
	PT_WARN_ON(log2_mod(va, pgsz_lg2) != 0);
	PT_WARN_ON(oalog2_mod(oa, pgsz_lg2) != 0);
	PT_WARN_ON(va + log2_to_int(pgsz_lg2) - 1 > last_va);
	PT_WARN_ON(!log2_div_eq(va, va + log2_to_int(pgsz_lg2) - 1, pgsz_lg2));
	PT_WARN_ON(
		!oalog2_div_eq(oa, oa + log2_to_int(pgsz_lg2) - 1, pgsz_lg2));
	return pgsz_lg2;
}

/*
 * Emit the runtime dispatch function for an unwound walker: branch on the
 * current level to the matching fixed-level function built by
 * PT_MAKE_LEVELS(). Levels above PT_MAX_TOP_LEVEL are unreachable and the
 * comparisons let the compiler prune them.
 */
#define _PT_MAKE_CALL_LEVEL(fn)                                        \
	static __always_inline int fn(struct pt_range *range, void *arg, \
				      unsigned int level,                \
				      struct pt_table_p *table)          \
	{                                                                \
		static_assert(PT_MAX_TOP_LEVEL <= 5);                    \
		if (level == 0)                                          \
			return CONCATENATE(fn, 0)(range, arg, 0, table); \
		if (level == 1 || PT_MAX_TOP_LEVEL == 1)                 \
			return CONCATENATE(fn, 1)(range, arg, 1, table); \
		if (level == 2 || PT_MAX_TOP_LEVEL == 2)                 \
			return CONCATENATE(fn, 2)(range, arg, 2, table); \
		if (level == 3 || PT_MAX_TOP_LEVEL == 3)                 \
			return CONCATENATE(fn, 3)(range, arg, 3, table); \
		if (level == 4 || PT_MAX_TOP_LEVEL == 4)                 \
			return CONCATENATE(fn, 4)(range, arg, 4, table); \
		return CONCATENATE(fn, 5)(range, arg, 5, table);         \
	}

/*
 * Terminator for the descend chain; level 0 has no lower level so descending
 * from it is a programming error.
 */
static inline int __pt_make_level_fn_err(struct pt_range *range, void *arg,
					 unsigned int unused_level,
					 struct pt_table_p *table)
{
	static_assert(PT_MAX_TOP_LEVEL <= 5);
	return -EPROTOTYPE;
}

#define __PT_MAKE_LEVEL_FN(fn, level, descend_fn, do_fn)       \
	static inline int fn(struct pt_range *range, void *arg, \
			     unsigned int unused_level,          \
			     struct pt_table_p *table)           \
	{                                                        \
		return do_fn(range, arg, level, table, descend_fn); \
	}

/**
 * PT_MAKE_LEVELS() - Build an unwound walker
 * @fn: Name of the walker function
 * @do_fn: Function to call at each level
 *
 * This builds a function call tree that can be fully inlined.
 * The caller must provide a function body in an __always_inline function::
 *
 *   static __always_inline int do_fn(struct pt_range *range, void *arg,
 *          unsigned int level, struct pt_table_p *table,
 *          pt_level_fn_t descend_fn)
 *
 * An inline function will be created for each table level that calls do_fn with
 * a compile time constant for level and a pointer to the next lower function.
 * This generates an optimally inlined walk where each of the functions sees a
 * constant level and can codegen the exact constants/etc for that level.
 *
 * Note this can produce a lot of code!
 */
#define PT_MAKE_LEVELS(fn, do_fn)                                             \
	__PT_MAKE_LEVEL_FN(CONCATENATE(fn, 0), 0, __pt_make_level_fn_err,     \
			   do_fn);                                            \
	__PT_MAKE_LEVEL_FN(CONCATENATE(fn, 1), 1, CONCATENATE(fn, 0), do_fn); \
	__PT_MAKE_LEVEL_FN(CONCATENATE(fn, 2), 2, CONCATENATE(fn, 1), do_fn); \
	__PT_MAKE_LEVEL_FN(CONCATENATE(fn, 3), 3, CONCATENATE(fn, 2), do_fn); \
	__PT_MAKE_LEVEL_FN(CONCATENATE(fn, 4), 4, CONCATENATE(fn, 3), do_fn); \
	__PT_MAKE_LEVEL_FN(CONCATENATE(fn, 5), 5, CONCATENATE(fn, 4), do_fn); \
	_PT_MAKE_CALL_LEVEL(fn)

#endif