// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2025, Microsoft Corporation.
 *
 * Memory region management for mshv_root module.
 *
 * Authors: Microsoft Linux virtualization team
 */

#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/vmalloc.h>

#include <asm/mshyperv.h>

#include "mshv_root.h"

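/*
 * Number of pages faulted in per GPA fault: PTRS_PER_PMD 4K pages, i.e. one
 * PMD-sized range (2M with 4K base pages). Faulting in a batch this large
 * reduces the number of GPA intercepts taken for adjacent pages.
 */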
#define MSHV_MAP_FAULT_IN_PAGES PTRS_PER_PMD

/**
 * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
 * in a region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle the chunk.
 *
 * This function scans the region's pages starting from @page_offset,
 * checking for contiguous present pages of the same size (normal or huge).
 * It invokes @handler for the chunk of contiguous pages found. Returns the
 * number of pages handled, or a negative error code if the first page is
 * not present or the handler fails.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: Number of pages handled, or negative error code.
 */
static long mshv_region_process_chunk(struct mshv_mem_region *region,
                                      u32 flags,
                                      u64 page_offset, u64 page_count,
                                      int (*handler)(struct mshv_mem_region *region,
                                                     u32 flags,
                                                     u64 page_offset,
                                                     u64 page_count))
{
        u64 count, stride;
        unsigned int page_order;
        struct page *page;
        int ret;

        page = region->pages[page_offset];
        if (!page)
                return -EINVAL;

        page_order = folio_order(page_folio(page));
        /* The hypervisor only supports 4K and 2M page sizes */
        if (page_order && page_order != HPAGE_PMD_ORDER)
                return -EINVAL;

        stride = 1 << page_order;

        /* Start at stride since the first page is validated */
        for (count = stride; count < page_count; count += stride) {
                page = region->pages[page_offset + count];

                /* Break if current page is not present */
                if (!page)
                        break;

                /* Break if page size changes */
                if (page_order != folio_order(page_folio(page)))
                        break;
        }

        ret = handler(region, flags, page_offset, count);
        if (ret)
                return ret;

        return count;
}

/**
 * mshv_region_process_range - Processes a range of memory pages in a
 * region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle each chunk of contiguous
 *               pages.
 *
 * Iterates over the specified range of pages in @region, skipping
 * non-present pages. For each contiguous chunk of present pages, invokes
 * @handler via mshv_region_process_chunk.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: 0 on success, or a negative error code on failure.
 */
static int mshv_region_process_range(struct mshv_mem_region *region,
                                     u32 flags,
                                     u64 page_offset, u64 page_count,
                                     int (*handler)(struct mshv_mem_region *region,
                                                    u32 flags,
                                                    u64 page_offset,
                                                    u64 page_count))
{
        long ret;

        if (page_offset + page_count > region->nr_pages)
                return -EINVAL;

        while (page_count) {
                /* Skip non-present pages */
                if (!region->pages[page_offset]) {
                        page_offset++;
                        page_count--;
                        continue;
                }

                ret = mshv_region_process_chunk(region, flags,
                                                page_offset,
                                                page_count,
                                                handler);
                if (ret < 0)
                        return ret;

                page_offset += ret;
                page_count -= ret;
        }

        return 0;
}

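/*
 * Allocate and initialize a region covering @nr_pages guest pages starting
 * at @guest_pfn, backed by the userspace range starting at @uaddr. The
 * hv_map_flags start as HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE; write
 * and execute access are added based on @flags. The pages array is left
 * zeroed and is populated later, e.g. by mshv_region_pin() or the GPA fault
 * path. The returned region holds one reference; release it with
 * mshv_region_put().
 */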
struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
                                           u64 uaddr, u32 flags)
{
        struct mshv_mem_region *region;

        region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
        if (!region)
                return ERR_PTR(-ENOMEM);

        region->nr_pages = nr_pages;
        region->start_gfn = guest_pfn;
        region->start_uaddr = uaddr;
        region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
        if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
                region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
        if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
                region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;

        kref_init(&region->refcount);

        return region;
}

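/*
 * Sharing restores host access to pages that the hypervisor has made
 * exclusive to the guest (e.g. for an encrypted partition) so the host can
 * touch them again. The chunk handlers below operate on runs of same-sized
 * pages and set the LARGE_PAGE flag when a chunk is backed by huge pages.
 */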
static int mshv_region_chunk_share(struct mshv_mem_region *region,
                                   u32 flags,
                                   u64 page_offset, u64 page_count)
{
        struct page *page = region->pages[page_offset];

        if (PageHuge(page) || PageTransCompound(page))
                flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

        return hv_call_modify_spa_host_access(region->partition->pt_id,
                                              region->pages + page_offset,
                                              page_count,
                                              HV_MAP_GPA_READABLE |
                                              HV_MAP_GPA_WRITABLE,
                                              flags, true);
}

int mshv_region_share(struct mshv_mem_region *region)
{
        u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;

        return mshv_region_process_range(region, flags,
                                         0, region->nr_pages,
                                         mshv_region_chunk_share);
}

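/*
 * Unsharing revokes the host's access so the pages can be made exclusive to
 * the guest; see mshv_region_chunk_share() above for the counterpart.
 */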
static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
                                     u32 flags,
                                     u64 page_offset, u64 page_count)
{
        struct page *page = region->pages[page_offset];

        if (PageHuge(page) || PageTransCompound(page))
                flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

        return hv_call_modify_spa_host_access(region->partition->pt_id,
                                              region->pages + page_offset,
                                              page_count, 0,
                                              flags, false);
}

int mshv_region_unshare(struct mshv_mem_region *region)
{
        u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;

        return mshv_region_process_range(region, flags,
                                         0, region->nr_pages,
                                         mshv_region_chunk_unshare);
}

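/*
 * Map one chunk of present pages into the partition's GPA space with the
 * given access flags, marking the mapping as large-page when the chunk is
 * backed by huge pages.
 */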
static int mshv_region_chunk_remap(struct mshv_mem_region *region,
                                   u32 flags,
                                   u64 page_offset, u64 page_count)
{
        struct page *page = region->pages[page_offset];

        if (PageHuge(page) || PageTransCompound(page))
                flags |= HV_MAP_GPA_LARGE_PAGE;

        return hv_call_map_gpa_pages(region->partition->pt_id,
                                     region->start_gfn + page_offset,
                                     page_count, flags,
                                     region->pages + page_offset);
}

static int mshv_region_remap_pages(struct mshv_mem_region *region,
                                   u32 map_flags,
                                   u64 page_offset, u64 page_count)
{
        return mshv_region_process_range(region, map_flags,
                                         page_offset, page_count,
                                         mshv_region_chunk_remap);
}

int mshv_region_map(struct mshv_mem_region *region)
{
        u32 map_flags = region->hv_map_flags;

        return mshv_region_remap_pages(region, map_flags,
                                       0, region->nr_pages);
}

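/*
 * Drop the region's references to a range of pages: unpin them if the
 * region pins its memory, then clear the corresponding page pointers so
 * the range reads as not-present.
 */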
static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
                                         u64 page_offset, u64 page_count)
{
        if (region->type == MSHV_REGION_TYPE_MEM_PINNED)
                unpin_user_pages(region->pages + page_offset, page_count);

        memset(region->pages + page_offset, 0,
               page_count * sizeof(struct page *));
}

void mshv_region_invalidate(struct mshv_mem_region *region)
{
        mshv_region_invalidate_pages(region, 0, region->nr_pages);
}

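/*
 * Pin the whole userspace range backing the region and record the page
 * structs in region->pages. On failure, any pages pinned so far are
 * released again.
 */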
int mshv_region_pin(struct mshv_mem_region *region)
{
        u64 done_count, nr_pages;
        struct page **pages;
        __u64 userspace_addr;
        int ret;

        for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
                pages = region->pages + done_count;
                userspace_addr = region->start_uaddr +
                                 done_count * HV_HYP_PAGE_SIZE;
                nr_pages = min(region->nr_pages - done_count,
                               MSHV_PIN_PAGES_BATCH_SIZE);

                /*
                 * Pinning at 4K granularity works for large pages too:
                 * all page structs within the large page are returned.
                 *
                 * Pin requests are batched because pin_user_pages_fast()
                 * with the FOLL_LONGTERM flag does a large temporary
                 * allocation of contiguous memory.
                 */
                ret = pin_user_pages_fast(userspace_addr, nr_pages,
                                          FOLL_WRITE | FOLL_LONGTERM,
                                          pages);
                if (ret < 0)
                        goto release_pages;
        }

        return 0;

release_pages:
        mshv_region_invalidate_pages(region, 0, done_count);
        return ret;
}

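/*
 * Unmapping removes the region's GPA mappings from the partition; the
 * backing pages themselves are released separately via
 * mshv_region_invalidate().
 */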
static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
                                   u32 flags,
                                   u64 page_offset, u64 page_count)
{
        struct page *page = region->pages[page_offset];

        if (PageHuge(page) || PageTransCompound(page))
                flags |= HV_UNMAP_GPA_LARGE_PAGE;

        return hv_call_unmap_gpa_pages(region->partition->pt_id,
                                       region->start_gfn + page_offset,
                                       page_count, flags);
}

static int mshv_region_unmap(struct mshv_mem_region *region)
{
        return mshv_region_process_range(region, 0,
                                         0, region->nr_pages,
                                         mshv_region_chunk_unmap);
}

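/*
 * Final teardown, called when the last reference is dropped. For encrypted
 * partitions the pages must be shared back to the host first, because
 * unpinning pages the host cannot access would crash the host. Only then
 * are the GPA mappings removed and the pages released.
 */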
static void mshv_region_destroy(struct kref *ref)
{
        struct mshv_mem_region *region =
                container_of(ref, struct mshv_mem_region, refcount);
        struct mshv_partition *partition = region->partition;
        int ret;

        if (region->type == MSHV_REGION_TYPE_MEM_MOVABLE)
                mshv_region_movable_fini(region);

        if (mshv_partition_encrypted(partition)) {
                ret = mshv_region_share(region);
                if (ret) {
                        pt_err(partition,
                               "Failed to regain access to memory, unpinning user pages will fail and crash the host, error: %d\n",
                               ret);
                        return;
                }
        }

        mshv_region_unmap(region);

        mshv_region_invalidate(region);

        vfree(region);
}

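/* Drop a reference; the region is destroyed when the last reference goes. */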
void mshv_region_put(struct mshv_mem_region *region)
{
        kref_put(&region->refcount, mshv_region_destroy);
}

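/* Take a reference unless the count has already dropped to zero. */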
int mshv_region_get(struct mshv_mem_region *region)
{
        return kref_get_unless_zero(&region->refcount);
}

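/*
 * Illustrative lifecycle of a region, pieced together from the helpers in
 * this file rather than copied from an actual caller; callers are expected
 * to set region->partition and region->type before using the region:
 *
 *	region = mshv_region_create(gfn, nr_pages, uaddr, mem_flags);
 *	region->partition = partition;
 *	region->type = MSHV_REGION_TYPE_MEM_PINNED;
 *	ret = mshv_region_pin(region);
 *	ret = mshv_region_map(region);
 *	...
 *	mshv_region_put(region);
 *
 * A movable region would instead call mshv_region_movable_init() and rely
 * on mshv_region_handle_gfn_fault() to populate and map pages on demand;
 * mshv_region_put() tears everything down once the last reference is gone.
 */
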
/**
 * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
 * @region: Pointer to the memory region structure
 * @range: Pointer to the HMM range structure
 *
 * This function performs the following steps:
 * 1. Reads the notifier sequence for the HMM range.
 * 2. Acquires a read lock on the memory map.
 * 3. Handles HMM faults for the specified range.
 * 4. Releases the read lock on the memory map.
 * 5. If successful, locks the memory region mutex.
 * 6. Verifies whether the notifier sequence changed during the operation.
 *    If it has, releases the mutex and returns -EBUSY, matching the
 *    hmm_range_fault() return code that tells the caller to retry.
 *
 * Return: 0 on success, a negative error code otherwise.
 */
static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
                                          struct hmm_range *range)
{
        int ret;

        range->notifier_seq = mmu_interval_read_begin(range->notifier);
        mmap_read_lock(region->mni.mm);
        ret = hmm_range_fault(range);
        mmap_read_unlock(region->mni.mm);
        if (ret)
                return ret;

        mutex_lock(&region->mutex);

        if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
                mutex_unlock(&region->mutex);
                cond_resched();
                return -EBUSY;
        }

        return 0;
}

/**
 * mshv_region_range_fault - Handle memory range faults for a given region.
 * @region: Pointer to the memory region structure.
 * @page_offset: Offset of the page within the region.
 * @page_count: Number of pages to handle.
 *
 * This function resolves memory faults for a specified range of pages
 * within a memory region. It uses HMM (Heterogeneous Memory Management)
 * to fault in the required pages and updates the region's page array.
 *
 * Return: 0 on success, negative error code on failure.
 */
static int mshv_region_range_fault(struct mshv_mem_region *region,
                                   u64 page_offset, u64 page_count)
{
        struct hmm_range range = {
                .notifier = &region->mni,
                .default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
        };
        unsigned long *pfns;
        int ret;
        u64 i;

        pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
        if (!pfns)
                return -ENOMEM;

        range.hmm_pfns = pfns;
        range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
        range.end = range.start + page_count * HV_HYP_PAGE_SIZE;

        do {
                ret = mshv_region_hmm_fault_and_lock(region, &range);
        } while (ret == -EBUSY);

        if (ret)
                goto out;

        for (i = 0; i < page_count; i++)
                region->pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);

        ret = mshv_region_remap_pages(region, region->hv_map_flags,
                                      page_offset, page_count);

        mutex_unlock(&region->mutex);
out:
        kfree(pfns);
        return ret;
}

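/*
 * Handle a GPA intercept for @gfn: fault in a MSHV_MAP_FAULT_IN_PAGES
 * aligned window of pages around the faulting address and map it into the
 * partition. Returns true if the fault was resolved, false otherwise.
 */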
bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
{
        u64 page_offset, page_count;
        int ret;

        /* Align the page offset down to a multiple of MSHV_MAP_FAULT_IN_PAGES. */
        page_offset = ALIGN_DOWN(gfn - region->start_gfn,
                                 MSHV_MAP_FAULT_IN_PAGES);

        /* Map more pages than requested to reduce the number of faults. */
        page_count = min(region->nr_pages - page_offset,
                         MSHV_MAP_FAULT_IN_PAGES);

        ret = mshv_region_range_fault(region, page_offset, page_count);

        WARN_ONCE(ret,
                  "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
                  region->partition->pt_id, region->start_uaddr,
                  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
                  gfn, page_offset, page_count);

        return !ret;
}

/**
 * mshv_region_interval_invalidate - Invalidate a range of a memory region
 * @mni: Pointer to the mmu_interval_notifier structure
 * @range: Pointer to the mmu_notifier_range structure
 * @cur_seq: Current sequence number for the interval notifier
 *
 * This function invalidates a memory region by remapping its pages with
 * no access permissions. It locks the region's mutex to ensure thread safety
 * and updates the sequence number for the interval notifier. If the range
 * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
 * lock and returns false if unsuccessful.
 *
 * NOTE: Failure to invalidate a region is a serious error, as the pages will
 * be considered freed while they are still mapped by the hypervisor.
 * Any attempt to access such pages will likely crash the system.
 *
 * Return: true if the region was successfully invalidated, false otherwise.
 */
static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
                                            const struct mmu_notifier_range *range,
                                            unsigned long cur_seq)
{
        struct mshv_mem_region *region = container_of(mni,
                                                      struct mshv_mem_region,
                                                      mni);
        u64 page_offset = 0, page_count = 0;
        unsigned long mstart, mend;
        int ret = -EPERM;

        if (mmu_notifier_range_blockable(range))
                mutex_lock(&region->mutex);
        else if (!mutex_trylock(&region->mutex))
                goto out_fail;

        mmu_interval_set_seq(mni, cur_seq);

        mstart = max(range->start, region->start_uaddr);
        mend = min(range->end, region->start_uaddr +
                   (region->nr_pages << HV_HYP_PAGE_SHIFT));

        page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
        page_count = HVPFN_DOWN(mend - mstart);

        ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
                                      page_offset, page_count);
        if (ret)
                goto out_fail;

        mshv_region_invalidate_pages(region, page_offset, page_count);

        mutex_unlock(&region->mutex);

        return true;

out_fail:
        WARN_ONCE(ret,
                  "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
                  region->start_uaddr,
                  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
                  range->start, range->end, range->event,
                  page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
        return false;
}

static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
        .invalidate = mshv_region_interval_invalidate,
};

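/* Unregister the region's mmu interval notifier. */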
void mshv_region_movable_fini(struct mshv_mem_region *region)
{
        mmu_interval_notifier_remove(&region->mni);
}

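/*
 * Register an mmu interval notifier covering the region's userspace range
 * and initialize the mutex that serializes fault handling against notifier
 * invalidation. Returns true on success, false otherwise.
 */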
bool mshv_region_movable_init(struct mshv_mem_region *region)
{
        int ret;

        ret = mmu_interval_notifier_insert(&region->mni, current->mm,
                                           region->start_uaddr,
                                           region->nr_pages << HV_HYP_PAGE_SHIFT,
                                           &mshv_region_mni_ops);
        if (ret)
                return false;

        mutex_init(&region->mutex);

        return true;
}