1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (c) 2025, Microsoft Corporation.
4 *
5 * Memory region management for mshv_root module.
6 *
7 * Authors: Microsoft Linux virtualization team
8 */
9
#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/overflow.h>
#include <linux/vmalloc.h>
15
16 #include <asm/mshyperv.h>
17
18 #include "mshv_root.h"
19
/*
 * Size (in pages) and alignment of the window faulted in and mapped on each
 * GPA intercept. It equals PTRS_PER_PMD, which is also the alignment
 * mshv_chunk_stride() requires before it will use a huge page stride.
 */
#define MSHV_MAP_FAULT_IN_PAGES PTRS_PER_PMD
21
22 /**
23 * mshv_chunk_stride - Compute stride for mapping guest memory
24 * @page : The page to check for huge page backing
25 * @gfn : Guest frame number for the mapping
26 * @page_count: Total number of pages in the mapping
27 *
28 * Determines the appropriate stride (in pages) for mapping guest memory.
29 * Uses huge page stride if the backing page is huge and the guest mapping
30 * is properly aligned; otherwise falls back to single page stride.
31 *
32 * Return: Stride in pages, or -EINVAL if page order is unsupported.
33 */
mshv_chunk_stride(struct page * page,u64 gfn,u64 page_count)34 static int mshv_chunk_stride(struct page *page,
35 u64 gfn, u64 page_count)
36 {
37 unsigned int page_order;
38
39 /*
40 * Use single page stride by default. For huge page stride, the
41 * page must be compound and point to the head of the compound
42 * page, and both gfn and page_count must be huge-page aligned.
43 */
44 if (!PageCompound(page) || !PageHead(page) ||
45 !IS_ALIGNED(gfn, PTRS_PER_PMD) ||
46 !IS_ALIGNED(page_count, PTRS_PER_PMD))
47 return 1;
48
49 page_order = folio_order(page_folio(page));
50 /* The hypervisor only supports 2M huge page */
51 if (page_order != PMD_ORDER)
52 return -EINVAL;
53
54 return 1 << page_order;
55 }
56
57 /**
58 * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
59 * in a region.
60 * @region : Pointer to the memory region structure.
61 * @flags : Flags to pass to the handler.
62 * @page_offset: Offset into the region's pages array to start processing.
63 * @page_count : Number of pages to process.
64 * @handler : Callback function to handle the chunk.
65 *
66 * This function scans the region's pages starting from @page_offset,
67 * checking for contiguous present pages of the same size (normal or huge).
68 * It invokes @handler for the chunk of contiguous pages found. Returns the
69 * number of pages handled, or a negative error code if the first page is
70 * not present or the handler fails.
71 *
72 * Note: The @handler callback must be able to handle both normal and huge
73 * pages.
74 *
75 * Return: Number of pages handled, or negative error code.
76 */
mshv_region_process_chunk(struct mshv_mem_region * region,u32 flags,u64 page_offset,u64 page_count,int (* handler)(struct mshv_mem_region * region,u32 flags,u64 page_offset,u64 page_count,bool huge_page))77 static long mshv_region_process_chunk(struct mshv_mem_region *region,
78 u32 flags,
79 u64 page_offset, u64 page_count,
80 int (*handler)(struct mshv_mem_region *region,
81 u32 flags,
82 u64 page_offset,
83 u64 page_count,
84 bool huge_page))
85 {
86 u64 gfn = region->start_gfn + page_offset;
87 u64 count;
88 struct page *page;
89 int stride, ret;
90
91 page = region->pages[page_offset];
92 if (!page)
93 return -EINVAL;
94
95 stride = mshv_chunk_stride(page, gfn, page_count);
96 if (stride < 0)
97 return stride;
98
99 /* Start at stride since the first stride is validated */
100 for (count = stride; count < page_count; count += stride) {
101 page = region->pages[page_offset + count];
102
103 /* Break if current page is not present */
104 if (!page)
105 break;
106
107 /* Break if stride size changes */
108 if (stride != mshv_chunk_stride(page, gfn + count,
109 page_count - count))
110 break;
111 }
112
113 ret = handler(region, flags, page_offset, count, stride > 1);
114 if (ret)
115 return ret;
116
117 return count;
118 }
119
120 /**
121 * mshv_region_process_range - Processes a range of memory pages in a
122 * region.
123 * @region : Pointer to the memory region structure.
124 * @flags : Flags to pass to the handler.
125 * @page_offset: Offset into the region's pages array to start processing.
126 * @page_count : Number of pages to process.
127 * @handler : Callback function to handle each chunk of contiguous
128 * pages.
129 *
130 * Iterates over the specified range of pages in @region, skipping
131 * non-present pages. For each contiguous chunk of present pages, invokes
132 * @handler via mshv_region_process_chunk.
133 *
134 * Note: The @handler callback must be able to handle both normal and huge
135 * pages.
136 *
137 * Returns 0 on success, or a negative error code on failure.
138 */
mshv_region_process_range(struct mshv_mem_region * region,u32 flags,u64 page_offset,u64 page_count,int (* handler)(struct mshv_mem_region * region,u32 flags,u64 page_offset,u64 page_count,bool huge_page))139 static int mshv_region_process_range(struct mshv_mem_region *region,
140 u32 flags,
141 u64 page_offset, u64 page_count,
142 int (*handler)(struct mshv_mem_region *region,
143 u32 flags,
144 u64 page_offset,
145 u64 page_count,
146 bool huge_page))
147 {
148 long ret;
149
150 if (page_offset + page_count > region->nr_pages)
151 return -EINVAL;
152
153 while (page_count) {
154 /* Skip non-present pages */
155 if (!region->pages[page_offset]) {
156 page_offset++;
157 page_count--;
158 continue;
159 }
160
161 ret = mshv_region_process_chunk(region, flags,
162 page_offset,
163 page_count,
164 handler);
165 if (ret < 0)
166 return ret;
167
168 page_offset += ret;
169 page_count -= ret;
170 }
171
172 return 0;
173 }
174
/**
 * mshv_region_create - Allocate and initialize a guest memory region
 * @guest_pfn: First guest frame number covered by the region
 * @nr_pages : Number of pages in the region
 * @uaddr    : Userspace address backing the first page of the region
 * @flags    : Bitmask tested against BIT(MSHV_SET_MEM_BIT_WRITABLE) and
 *             BIT(MSHV_SET_MEM_BIT_EXECUTABLE)
 *
 * The region and its trailing page pointer array are allocated zeroed
 * with vzalloc(). The mapping flags always include HV_MAP_GPA_READABLE |
 * HV_MAP_GPA_ADJUSTABLE. HV_MAP_GPA_WRITABLE and HV_MAP_GPA_EXECUTABLE are
 * added according to @flags. The reference count is initialized with
 * kref_init().
 *
 * Return: The new region, or ERR_PTR(-ENOMEM) if allocation fails.
 */
struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
					   u64 uaddr, u32 flags)
{
	struct mshv_mem_region *region;

	/*
	 * struct_size() saturates on overflow, so an absurd @nr_pages makes
	 * the allocation fail instead of yielding an undersized buffer.
	 */
	region = vzalloc(struct_size(region, pages, nr_pages));
	if (!region)
		return ERR_PTR(-ENOMEM);

	region->nr_pages = nr_pages;
	region->start_gfn = guest_pfn;
	region->start_uaddr = uaddr;
	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;

	kref_init(&region->refcount);

	return region;
}
197
/*
 * mshv_region_chunk_share - process_range() handler requesting
 * HV_MAP_GPA_READABLE | HV_MAP_GPA_WRITABLE host access for a chunk of
 * region pages. Huge page chunks also carry
 * HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE.
 */
static int mshv_region_chunk_share(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	struct page **chunk = &region->pages[page_offset];
	u32 op_flags = flags;

	if (huge_page)
		op_flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      chunk, page_count,
					      HV_MAP_GPA_READABLE | HV_MAP_GPA_WRITABLE,
					      op_flags, true);
}
213
mshv_region_share(struct mshv_mem_region * region)214 int mshv_region_share(struct mshv_mem_region *region)
215 {
216 u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
217
218 return mshv_region_process_range(region, flags,
219 0, region->nr_pages,
220 mshv_region_chunk_share);
221 }
222
/*
 * mshv_region_chunk_unshare - process_range() handler that passes an empty
 * host access mask for a chunk of region pages. Huge page chunks also
 * carry HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE.
 */
static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     bool huge_page)
{
	u64 pt_id = region->partition->pt_id;
	u32 op_flags = flags;

	if (huge_page)
		op_flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(pt_id,
					      &region->pages[page_offset],
					      page_count, 0, op_flags, false);
}
236
mshv_region_unshare(struct mshv_mem_region * region)237 int mshv_region_unshare(struct mshv_mem_region *region)
238 {
239 u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
240
241 return mshv_region_process_range(region, flags,
242 0, region->nr_pages,
243 mshv_region_chunk_unshare);
244 }
245
/*
 * mshv_region_chunk_remap - process_range() handler mapping a chunk of
 * region pages at the matching guest frames with @flags. Huge page chunks
 * also carry HV_MAP_GPA_LARGE_PAGE.
 */
static int mshv_region_chunk_remap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	u64 first_gfn = region->start_gfn + page_offset;
	u32 map_flags = flags;

	if (huge_page)
		map_flags |= HV_MAP_GPA_LARGE_PAGE;

	return hv_call_map_gpa_pages(region->partition->pt_id, first_gfn,
				     page_count, map_flags,
				     &region->pages[page_offset]);
}
259
/*
 * mshv_region_remap_pages - Map the present pages of a sub-range of
 * @region into the guest with @map_flags. Non-present pages are skipped.
 * Returns 0 or a negative error code.
 */
static int mshv_region_remap_pages(struct mshv_mem_region *region,
				   u32 map_flags,
				   u64 page_offset, u64 page_count)
{
	int ret;

	ret = mshv_region_process_range(region, map_flags, page_offset,
					page_count, mshv_region_chunk_remap);
	return ret;
}
268
mshv_region_map(struct mshv_mem_region * region)269 int mshv_region_map(struct mshv_mem_region *region)
270 {
271 u32 map_flags = region->hv_map_flags;
272
273 return mshv_region_remap_pages(region, map_flags,
274 0, region->nr_pages);
275 }
276
/*
 * mshv_region_invalidate_pages - Forget the page pointers of a sub-range.
 * For MSHV_REGION_TYPE_MEM_PINNED regions the pages are unpinned first.
 * For other region types the pointers are just cleared.
 */
static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
					 u64 page_offset, u64 page_count)
{
	struct page **first = &region->pages[page_offset];

	if (region->type == MSHV_REGION_TYPE_MEM_PINNED)
		unpin_user_pages(first, page_count);

	memset(first, 0, sizeof(struct page *) * page_count);
}
286
mshv_region_invalidate(struct mshv_mem_region * region)287 void mshv_region_invalidate(struct mshv_mem_region *region)
288 {
289 mshv_region_invalidate_pages(region, 0, region->nr_pages);
290 }
291
mshv_region_pin(struct mshv_mem_region * region)292 int mshv_region_pin(struct mshv_mem_region *region)
293 {
294 u64 done_count, nr_pages;
295 struct page **pages;
296 __u64 userspace_addr;
297 int ret;
298
299 for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
300 pages = region->pages + done_count;
301 userspace_addr = region->start_uaddr +
302 done_count * HV_HYP_PAGE_SIZE;
303 nr_pages = min(region->nr_pages - done_count,
304 MSHV_PIN_PAGES_BATCH_SIZE);
305
306 /*
307 * Pinning assuming 4k pages works for large pages too.
308 * All page structs within the large page are returned.
309 *
310 * Pin requests are batched because pin_user_pages_fast
311 * with the FOLL_LONGTERM flag does a large temporary
312 * allocation of contiguous memory.
313 */
314 ret = pin_user_pages_fast(userspace_addr, nr_pages,
315 FOLL_WRITE | FOLL_LONGTERM,
316 pages);
317 if (ret < 0)
318 goto release_pages;
319 }
320
321 return 0;
322
323 release_pages:
324 mshv_region_invalidate_pages(region, 0, done_count);
325 return ret;
326 }
327
/*
 * mshv_region_chunk_unmap - process_range() handler unmapping the guest
 * frames that back a chunk of region pages. Huge page chunks also carry
 * HV_UNMAP_GPA_LARGE_PAGE.
 */
static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	u64 first_gfn = region->start_gfn + page_offset;
	u32 unmap_flags = flags;

	if (huge_page)
		unmap_flags |= HV_UNMAP_GPA_LARGE_PAGE;

	return hv_call_unmap_gpa_pages(region->partition->pt_id, first_gfn,
				       page_count, unmap_flags);
}
340
mshv_region_unmap(struct mshv_mem_region * region)341 static int mshv_region_unmap(struct mshv_mem_region *region)
342 {
343 return mshv_region_process_range(region, 0,
344 0, region->nr_pages,
345 mshv_region_chunk_unmap);
346 }
347
/**
 * mshv_region_destroy - kref release callback for a memory region
 * @ref: The region's embedded refcount
 *
 * Teardown runs in this order:
 *  - movable regions drop their MMU interval notifier;
 *  - encrypted partitions get shared host access back via
 *    mshv_region_share();
 *  - the region is unmapped from the guest;
 *  - the page pointers are invalidated (unpinned for pinned regions);
 *  - the region memory is freed.
 */
static void mshv_region_destroy(struct kref *ref)
{
	struct mshv_mem_region *region =
		container_of(ref, struct mshv_mem_region, refcount);
	struct mshv_partition *partition = region->partition;
	int ret;

	if (region->type == MSHV_REGION_TYPE_MEM_MOVABLE)
		mshv_region_movable_fini(region);

	if (mshv_partition_encrypted(partition)) {
		ret = mshv_region_share(region);
		if (ret) {
			/*
			 * Deliberately leak the region rather than unpin pages
			 * the host cannot access (see the message below).
			 */
			pt_err(partition,
			       "Failed to regain access to memory, unpinning user pages will fail and crash the host error: %d\n",
			       ret);
			return;
		}
	}

	/* NOTE(review): the unmap result is ignored; teardown is best-effort. */
	mshv_region_unmap(region);

	mshv_region_invalidate(region);

	vfree(region);
}
374
mshv_region_put(struct mshv_mem_region * region)375 void mshv_region_put(struct mshv_mem_region *region)
376 {
377 kref_put(®ion->refcount, mshv_region_destroy);
378 }
379
mshv_region_get(struct mshv_mem_region * region)380 int mshv_region_get(struct mshv_mem_region *region)
381 {
382 return kref_get_unless_zero(®ion->refcount);
383 }
384
385 /**
386 * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
387 * @region: Pointer to the memory region structure
388 * @range: Pointer to the HMM range structure
389 *
390 * This function performs the following steps:
391 * 1. Reads the notifier sequence for the HMM range.
392 * 2. Acquires a read lock on the memory map.
393 * 3. Handles HMM faults for the specified range.
394 * 4. Releases the read lock on the memory map.
395 * 5. If successful, locks the memory region mutex.
396 * 6. Verifies if the notifier sequence has changed during the operation.
397 * If it has, releases the mutex and returns -EBUSY to match with
398 * hmm_range_fault() return code for repeating.
399 *
400 * Return: 0 on success, a negative error code otherwise.
401 */
mshv_region_hmm_fault_and_lock(struct mshv_mem_region * region,struct hmm_range * range)402 static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
403 struct hmm_range *range)
404 {
405 int ret;
406
407 range->notifier_seq = mmu_interval_read_begin(range->notifier);
408 mmap_read_lock(region->mni.mm);
409 ret = hmm_range_fault(range);
410 mmap_read_unlock(region->mni.mm);
411 if (ret)
412 return ret;
413
414 mutex_lock(®ion->mutex);
415
416 if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
417 mutex_unlock(®ion->mutex);
418 cond_resched();
419 return -EBUSY;
420 }
421
422 return 0;
423 }
424
425 /**
426 * mshv_region_range_fault - Handle memory range faults for a given region.
427 * @region: Pointer to the memory region structure.
428 * @page_offset: Offset of the page within the region.
429 * @page_count: Number of pages to handle.
430 *
431 * This function resolves memory faults for a specified range of pages
432 * within a memory region. It uses HMM (Heterogeneous Memory Management)
433 * to fault in the required pages and updates the region's page array.
434 *
435 * Return: 0 on success, negative error code on failure.
436 */
mshv_region_range_fault(struct mshv_mem_region * region,u64 page_offset,u64 page_count)437 static int mshv_region_range_fault(struct mshv_mem_region *region,
438 u64 page_offset, u64 page_count)
439 {
440 struct hmm_range range = {
441 .notifier = ®ion->mni,
442 .default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
443 };
444 unsigned long *pfns;
445 int ret;
446 u64 i;
447
448 pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
449 if (!pfns)
450 return -ENOMEM;
451
452 range.hmm_pfns = pfns;
453 range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
454 range.end = range.start + page_count * HV_HYP_PAGE_SIZE;
455
456 do {
457 ret = mshv_region_hmm_fault_and_lock(region, &range);
458 } while (ret == -EBUSY);
459
460 if (ret)
461 goto out;
462
463 for (i = 0; i < page_count; i++)
464 region->pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);
465
466 ret = mshv_region_remap_pages(region, region->hv_map_flags,
467 page_offset, page_count);
468
469 mutex_unlock(®ion->mutex);
470 out:
471 kfree(pfns);
472 return ret;
473 }
474
mshv_region_handle_gfn_fault(struct mshv_mem_region * region,u64 gfn)475 bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
476 {
477 u64 page_offset, page_count;
478 int ret;
479
480 /* Align the page offset to the nearest MSHV_MAP_FAULT_IN_PAGES. */
481 page_offset = ALIGN_DOWN(gfn - region->start_gfn,
482 MSHV_MAP_FAULT_IN_PAGES);
483
484 /* Map more pages than requested to reduce the number of faults. */
485 page_count = min(region->nr_pages - page_offset,
486 MSHV_MAP_FAULT_IN_PAGES);
487
488 ret = mshv_region_range_fault(region, page_offset, page_count);
489
490 WARN_ONCE(ret,
491 "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
492 region->partition->pt_id, region->start_uaddr,
493 region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
494 gfn, page_offset, page_count);
495
496 return !ret;
497 }
498
499 /**
500 * mshv_region_interval_invalidate - Invalidate a range of memory region
501 * @mni: Pointer to the mmu_interval_notifier structure
502 * @range: Pointer to the mmu_notifier_range structure
503 * @cur_seq: Current sequence number for the interval notifier
504 *
505 * This function invalidates a memory region by remapping its pages with
506 * no access permissions. It locks the region's mutex to ensure thread safety
507 * and updates the sequence number for the interval notifier. If the range
508 * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
509 * lock and returns false if unsuccessful.
510 *
511 * NOTE: Failure to invalidate a region is a serious error, as the pages will
512 * be considered freed while they are still mapped by the hypervisor.
513 * Any attempt to access such pages will likely crash the system.
514 *
515 * Return: true if the region was successfully invalidated, false otherwise.
516 */
mshv_region_interval_invalidate(struct mmu_interval_notifier * mni,const struct mmu_notifier_range * range,unsigned long cur_seq)517 static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
518 const struct mmu_notifier_range *range,
519 unsigned long cur_seq)
520 {
521 struct mshv_mem_region *region = container_of(mni,
522 struct mshv_mem_region,
523 mni);
524 u64 page_offset, page_count;
525 unsigned long mstart, mend;
526 int ret = -EPERM;
527
528 mstart = max(range->start, region->start_uaddr);
529 mend = min(range->end, region->start_uaddr +
530 (region->nr_pages << HV_HYP_PAGE_SHIFT));
531
532 page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
533 page_count = HVPFN_DOWN(mend - mstart);
534
535 if (mmu_notifier_range_blockable(range))
536 mutex_lock(®ion->mutex);
537 else if (!mutex_trylock(®ion->mutex))
538 goto out_fail;
539
540 mmu_interval_set_seq(mni, cur_seq);
541
542 ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
543 page_offset, page_count);
544 if (ret)
545 goto out_unlock;
546
547 mshv_region_invalidate_pages(region, page_offset, page_count);
548
549 mutex_unlock(®ion->mutex);
550
551 return true;
552
553 out_unlock:
554 mutex_unlock(®ion->mutex);
555 out_fail:
556 WARN_ONCE(ret,
557 "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
558 region->start_uaddr,
559 region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
560 range->start, range->end, range->event,
561 page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
562 return false;
563 }
564
/*
 * Interval notifier callbacks, registered by mshv_region_movable_init() for
 * movable regions.
 */
static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
	.invalidate = mshv_region_interval_invalidate,
};
568
mshv_region_movable_fini(struct mshv_mem_region * region)569 void mshv_region_movable_fini(struct mshv_mem_region *region)
570 {
571 mmu_interval_notifier_remove(®ion->mni);
572 }
573
mshv_region_movable_init(struct mshv_mem_region * region)574 bool mshv_region_movable_init(struct mshv_mem_region *region)
575 {
576 int ret;
577
578 ret = mmu_interval_notifier_insert(®ion->mni, current->mm,
579 region->start_uaddr,
580 region->nr_pages << HV_HYP_PAGE_SHIFT,
581 &mshv_region_mni_ops);
582 if (ret)
583 return false;
584
585 mutex_init(®ion->mutex);
586
587 return true;
588 }
589