xref: /linux/drivers/hv/mshv_regions.c (revision 98e0fc32e53dd62cd38a0d67eaf5846ae20078cc)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2025, Microsoft Corporation.
4  *
5  * Memory region management for mshv_root module.
6  *
7  * Authors: Microsoft Linux virtualization team
8  */
9 
#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/overflow.h>
#include <linux/vmalloc.h>

#include <asm/mshyperv.h>

#include "mshv_root.h"
19 
/*
 * Size, in pages, of the window that mshv_region_handle_gfn_fault() faults
 * in and maps for a single GPA intercept. It is one PMD's worth of pages
 * (PTRS_PER_PMD, e.g. 512 pages = 2M on x86-64 with 4K pages). This size
 * lets the window be mapped with a huge-page stride when
 * mshv_chunk_stride() allows it.
 */
#define MSHV_MAP_FAULT_IN_PAGES				PTRS_PER_PMD
21 
22 /**
23  * mshv_chunk_stride - Compute stride for mapping guest memory
24  * @page      : The page to check for huge page backing
25  * @gfn       : Guest frame number for the mapping
26  * @page_count: Total number of pages in the mapping
27  *
28  * Determines the appropriate stride (in pages) for mapping guest memory.
29  * Uses huge page stride if the backing page is huge and the guest mapping
30  * is properly aligned; otherwise falls back to single page stride.
31  *
32  * Return: Stride in pages.
33  */
34 static unsigned int mshv_chunk_stride(struct page *page, u64 gfn,
35 				      u64 page_count)
36 {
37 	unsigned int page_order = folio_order(page_folio(page));
38 
39 	/*
40 	 * Use single page stride by default. For huge page stride, the
41 	 * folio order must be at least PMD_ORDER, the page's PFN must be
42 	 * 2M-aligned (so that a 2M-aligned tail page of a larger folio is
43 	 * acceptable), and both gfn and page_count must be 2M-aligned.
44 	 */
45 	if (page_order < PMD_ORDER ||
46 	    !IS_ALIGNED(page_to_pfn(page), PTRS_PER_PMD) ||
47 	    !IS_ALIGNED(gfn, PTRS_PER_PMD) ||
48 	    !IS_ALIGNED(page_count, PTRS_PER_PMD))
49 		return 1;
50 
51 	/* Use 2M stride always i.e. process 1G folios as 2M chunks */
52 	return 1 << PMD_ORDER;
53 }
54 
55 /**
56  * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
57  *                             in a region.
58  * @region     : Pointer to the memory region structure.
59  * @flags      : Flags to pass to the handler.
60  * @page_offset: Offset into the region's pages array to start processing.
61  * @page_count : Number of pages to process.
62  * @handler    : Callback function to handle the chunk.
63  *
64  * This function scans the region's pages starting from @page_offset,
65  * checking for contiguous present pages of the same size (normal or huge).
66  * It invokes @handler for the chunk of contiguous pages found. Returns the
67  * number of pages handled, or a negative error code if the first page is
68  * not present or the handler fails.
69  *
70  * Note: The @handler callback must be able to handle both normal and huge
71  * pages.
72  *
73  * Return: Number of pages handled, or negative error code.
74  */
75 static long mshv_region_process_chunk(struct mshv_mem_region *region,
76 				      u32 flags,
77 				      u64 page_offset, u64 page_count,
78 				      int (*handler)(struct mshv_mem_region *region,
79 						     u32 flags,
80 						     u64 page_offset,
81 						     u64 page_count,
82 						     bool huge_page))
83 {
84 	u64 gfn = region->start_gfn + page_offset;
85 	u64 count;
86 	struct page *page;
87 	unsigned int stride;
88 	int ret;
89 
90 	page = region->mreg_pages[page_offset];
91 	if (!page)
92 		return -EINVAL;
93 
94 	stride = mshv_chunk_stride(page, gfn, page_count);
95 
96 	/* Start at stride since the first stride is validated */
97 	for (count = stride; count < page_count; count += stride) {
98 		page = region->mreg_pages[page_offset + count];
99 
100 		/* Break if current page is not present */
101 		if (!page)
102 			break;
103 
104 		/* Break if stride size changes */
105 		if (stride != mshv_chunk_stride(page, gfn + count,
106 						page_count - count))
107 			break;
108 	}
109 
110 	ret = handler(region, flags, page_offset, count, stride > 1);
111 	if (ret)
112 		return ret;
113 
114 	return count;
115 }
116 
117 /**
118  * mshv_region_process_range - Processes a range of memory pages in a
119  *                             region.
120  * @region     : Pointer to the memory region structure.
121  * @flags      : Flags to pass to the handler.
122  * @page_offset: Offset into the region's pages array to start processing.
123  * @page_count : Number of pages to process.
124  * @handler    : Callback function to handle each chunk of contiguous
125  *               pages.
126  *
127  * Iterates over the specified range of pages in @region, skipping
128  * non-present pages. For each contiguous chunk of present pages, invokes
129  * @handler via mshv_region_process_chunk.
130  *
131  * Note: The @handler callback must be able to handle both normal and huge
132  * pages.
133  *
134  * Returns 0 on success, or a negative error code on failure.
135  */
136 static int mshv_region_process_range(struct mshv_mem_region *region,
137 				     u32 flags,
138 				     u64 page_offset, u64 page_count,
139 				     int (*handler)(struct mshv_mem_region *region,
140 						    u32 flags,
141 						    u64 page_offset,
142 						    u64 page_count,
143 						    bool huge_page))
144 {
145 	long ret;
146 
147 	if (page_offset + page_count > region->nr_pages)
148 		return -EINVAL;
149 
150 	while (page_count) {
151 		/* Skip non-present pages */
152 		if (!region->mreg_pages[page_offset]) {
153 			page_offset++;
154 			page_count--;
155 			continue;
156 		}
157 
158 		ret = mshv_region_process_chunk(region, flags,
159 						page_offset,
160 						page_count,
161 						handler);
162 		if (ret < 0)
163 			return ret;
164 
165 		page_offset += ret;
166 		page_count -= ret;
167 	}
168 
169 	return 0;
170 }
171 
172 struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
173 					   u64 uaddr, u32 flags)
174 {
175 	struct mshv_mem_region *region;
176 
177 	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
178 	if (!region)
179 		return ERR_PTR(-ENOMEM);
180 
181 	region->nr_pages = nr_pages;
182 	region->start_gfn = guest_pfn;
183 	region->start_uaddr = uaddr;
184 	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
185 	if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
186 		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
187 	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
188 		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;
189 
190 	kref_init(&region->mreg_refcount);
191 
192 	return region;
193 }
194 
195 static int mshv_region_chunk_share(struct mshv_mem_region *region,
196 				   u32 flags,
197 				   u64 page_offset, u64 page_count,
198 				   bool huge_page)
199 {
200 	if (huge_page)
201 		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
202 
203 	return hv_call_modify_spa_host_access(region->partition->pt_id,
204 					      region->mreg_pages + page_offset,
205 					      page_count,
206 					      HV_MAP_GPA_READABLE |
207 					      HV_MAP_GPA_WRITABLE,
208 					      flags, true);
209 }
210 
211 int mshv_region_share(struct mshv_mem_region *region)
212 {
213 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
214 
215 	return mshv_region_process_range(region, flags,
216 					 0, region->nr_pages,
217 					 mshv_region_chunk_share);
218 }
219 
220 static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
221 				     u32 flags,
222 				     u64 page_offset, u64 page_count,
223 				     bool huge_page)
224 {
225 	if (huge_page)
226 		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
227 
228 	return hv_call_modify_spa_host_access(region->partition->pt_id,
229 					      region->mreg_pages + page_offset,
230 					      page_count, 0,
231 					      flags, false);
232 }
233 
234 int mshv_region_unshare(struct mshv_mem_region *region)
235 {
236 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
237 
238 	return mshv_region_process_range(region, flags,
239 					 0, region->nr_pages,
240 					 mshv_region_chunk_unshare);
241 }
242 
243 static int mshv_region_chunk_remap(struct mshv_mem_region *region,
244 				   u32 flags,
245 				   u64 page_offset, u64 page_count,
246 				   bool huge_page)
247 {
248 	if (huge_page)
249 		flags |= HV_MAP_GPA_LARGE_PAGE;
250 
251 	return hv_call_map_gpa_pages(region->partition->pt_id,
252 				     region->start_gfn + page_offset,
253 				     page_count, flags,
254 				     region->mreg_pages + page_offset);
255 }
256 
257 static int mshv_region_remap_pages(struct mshv_mem_region *region,
258 				   u32 map_flags,
259 				   u64 page_offset, u64 page_count)
260 {
261 	return mshv_region_process_range(region, map_flags,
262 					 page_offset, page_count,
263 					 mshv_region_chunk_remap);
264 }
265 
266 int mshv_region_map(struct mshv_mem_region *region)
267 {
268 	u32 map_flags = region->hv_map_flags;
269 
270 	return mshv_region_remap_pages(region, map_flags,
271 				       0, region->nr_pages);
272 }
273 
274 static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
275 					 u64 page_offset, u64 page_count)
276 {
277 	if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
278 		unpin_user_pages(region->mreg_pages + page_offset, page_count);
279 
280 	memset(region->mreg_pages + page_offset, 0,
281 	       page_count * sizeof(struct page *));
282 }
283 
284 void mshv_region_invalidate(struct mshv_mem_region *region)
285 {
286 	mshv_region_invalidate_pages(region, 0, region->nr_pages);
287 }
288 
289 int mshv_region_pin(struct mshv_mem_region *region)
290 {
291 	u64 done_count, nr_pages;
292 	struct page **pages;
293 	__u64 userspace_addr;
294 	int ret;
295 
296 	for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
297 		pages = region->mreg_pages + done_count;
298 		userspace_addr = region->start_uaddr +
299 				 done_count * HV_HYP_PAGE_SIZE;
300 		nr_pages = min(region->nr_pages - done_count,
301 			       MSHV_PIN_PAGES_BATCH_SIZE);
302 
303 		/*
304 		 * Pinning assuming 4k pages works for large pages too.
305 		 * All page structs within the large page are returned.
306 		 *
307 		 * Pin requests are batched because pin_user_pages_fast
308 		 * with the FOLL_LONGTERM flag does a large temporary
309 		 * allocation of contiguous memory.
310 		 */
311 		ret = pin_user_pages_fast(userspace_addr, nr_pages,
312 					  FOLL_WRITE | FOLL_LONGTERM,
313 					  pages);
314 		if (ret != nr_pages)
315 			goto release_pages;
316 	}
317 
318 	return 0;
319 
320 release_pages:
321 	if (ret > 0)
322 		done_count += ret;
323 	mshv_region_invalidate_pages(region, 0, done_count);
324 	return ret < 0 ? ret : -ENOMEM;
325 }
326 
327 static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
328 				   u32 flags,
329 				   u64 page_offset, u64 page_count,
330 				   bool huge_page)
331 {
332 	if (huge_page)
333 		flags |= HV_UNMAP_GPA_LARGE_PAGE;
334 
335 	return hv_call_unmap_gpa_pages(region->partition->pt_id,
336 				       region->start_gfn + page_offset,
337 				       page_count, flags);
338 }
339 
340 static int mshv_region_unmap(struct mshv_mem_region *region)
341 {
342 	return mshv_region_process_range(region, 0,
343 					 0, region->nr_pages,
344 					 mshv_region_chunk_unmap);
345 }
346 
347 static void mshv_region_destroy(struct kref *ref)
348 {
349 	struct mshv_mem_region *region =
350 		container_of(ref, struct mshv_mem_region, mreg_refcount);
351 	struct mshv_partition *partition = region->partition;
352 	int ret;
353 
354 	if (region->mreg_type == MSHV_REGION_TYPE_MEM_MOVABLE)
355 		mshv_region_movable_fini(region);
356 
357 	if (mshv_partition_encrypted(partition)) {
358 		ret = mshv_region_share(region);
359 		if (ret) {
360 			pt_err(partition,
361 			       "Failed to regain access to memory, unpinning user pages will fail and crash the host error: %d\n",
362 			       ret);
363 			return;
364 		}
365 	}
366 
367 	mshv_region_unmap(region);
368 
369 	mshv_region_invalidate(region);
370 
371 	vfree(region);
372 }
373 
374 void mshv_region_put(struct mshv_mem_region *region)
375 {
376 	kref_put(&region->mreg_refcount, mshv_region_destroy);
377 }
378 
379 int mshv_region_get(struct mshv_mem_region *region)
380 {
381 	return kref_get_unless_zero(&region->mreg_refcount);
382 }
383 
384 /**
385  * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
386  * @region: Pointer to the memory region structure
387  * @range: Pointer to the HMM range structure
388  *
389  * This function performs the following steps:
390  * 1. Reads the notifier sequence for the HMM range.
391  * 2. Acquires a read lock on the memory map.
392  * 3. Handles HMM faults for the specified range.
393  * 4. Releases the read lock on the memory map.
394  * 5. If successful, locks the memory region mutex.
395  * 6. Verifies if the notifier sequence has changed during the operation.
396  *    If it has, releases the mutex and returns -EBUSY to match with
397  *    hmm_range_fault() return code for repeating.
398  *
399  * Return: 0 on success, a negative error code otherwise.
400  */
401 static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
402 					  struct hmm_range *range)
403 {
404 	int ret;
405 
406 	range->notifier_seq = mmu_interval_read_begin(range->notifier);
407 	mmap_read_lock(region->mreg_mni.mm);
408 	ret = hmm_range_fault(range);
409 	mmap_read_unlock(region->mreg_mni.mm);
410 	if (ret)
411 		return ret;
412 
413 	mutex_lock(&region->mreg_mutex);
414 
415 	if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
416 		mutex_unlock(&region->mreg_mutex);
417 		cond_resched();
418 		return -EBUSY;
419 	}
420 
421 	return 0;
422 }
423 
424 /**
425  * mshv_region_range_fault - Handle memory range faults for a given region.
426  * @region: Pointer to the memory region structure.
427  * @page_offset: Offset of the page within the region.
428  * @page_count: Number of pages to handle.
429  *
430  * This function resolves memory faults for a specified range of pages
431  * within a memory region. It uses HMM (Heterogeneous Memory Management)
432  * to fault in the required pages and updates the region's page array.
433  *
434  * Return: 0 on success, negative error code on failure.
435  */
436 static int mshv_region_range_fault(struct mshv_mem_region *region,
437 				   u64 page_offset, u64 page_count)
438 {
439 	struct hmm_range range = {
440 		.notifier = &region->mreg_mni,
441 		.default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
442 	};
443 	unsigned long *pfns;
444 	int ret;
445 	u64 i;
446 
447 	pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
448 	if (!pfns)
449 		return -ENOMEM;
450 
451 	range.hmm_pfns = pfns;
452 	range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
453 	range.end = range.start + page_count * HV_HYP_PAGE_SIZE;
454 
455 	do {
456 		ret = mshv_region_hmm_fault_and_lock(region, &range);
457 	} while (ret == -EBUSY);
458 
459 	if (ret)
460 		goto out;
461 
462 	for (i = 0; i < page_count; i++)
463 		region->mreg_pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);
464 
465 	ret = mshv_region_remap_pages(region, region->hv_map_flags,
466 				      page_offset, page_count);
467 
468 	mutex_unlock(&region->mreg_mutex);
469 out:
470 	kfree(pfns);
471 	return ret;
472 }
473 
474 bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
475 {
476 	u64 page_offset, page_count;
477 	int ret;
478 
479 	/* Align the page offset to the nearest MSHV_MAP_FAULT_IN_PAGES. */
480 	page_offset = ALIGN_DOWN(gfn - region->start_gfn,
481 				 MSHV_MAP_FAULT_IN_PAGES);
482 
483 	/* Map more pages than requested to reduce the number of faults. */
484 	page_count = min(region->nr_pages - page_offset,
485 			 MSHV_MAP_FAULT_IN_PAGES);
486 
487 	ret = mshv_region_range_fault(region, page_offset, page_count);
488 
489 	WARN_ONCE(ret,
490 		  "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
491 		  region->partition->pt_id, region->start_uaddr,
492 		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
493 		  gfn, page_offset, page_count);
494 
495 	return !ret;
496 }
497 
498 /**
499  * mshv_region_interval_invalidate - Invalidate a range of memory region
500  * @mni: Pointer to the mmu_interval_notifier structure
501  * @range: Pointer to the mmu_notifier_range structure
502  * @cur_seq: Current sequence number for the interval notifier
503  *
504  * This function invalidates a memory region by remapping its pages with
505  * no access permissions. It locks the region's mutex to ensure thread safety
506  * and updates the sequence number for the interval notifier. If the range
507  * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
508  * lock and returns false if unsuccessful.
509  *
510  * NOTE: Failure to invalidate a region is a serious error, as the pages will
511  * be considered freed while they are still mapped by the hypervisor.
512  * Any attempt to access such pages will likely crash the system.
513  *
514  * Return: true if the region was successfully invalidated, false otherwise.
515  */
516 static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
517 					    const struct mmu_notifier_range *range,
518 					    unsigned long cur_seq)
519 {
520 	struct mshv_mem_region *region = container_of(mni,
521 						      struct mshv_mem_region,
522 						      mreg_mni);
523 	u64 page_offset, page_count;
524 	unsigned long mstart, mend;
525 	int ret = -EPERM;
526 
527 	mstart = max(range->start, region->start_uaddr);
528 	mend = min(range->end, region->start_uaddr +
529 		   (region->nr_pages << HV_HYP_PAGE_SHIFT));
530 
531 	page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
532 	page_count = HVPFN_DOWN(mend - mstart);
533 
534 	if (mmu_notifier_range_blockable(range))
535 		mutex_lock(&region->mreg_mutex);
536 	else if (!mutex_trylock(&region->mreg_mutex))
537 		goto out_fail;
538 
539 	mmu_interval_set_seq(mni, cur_seq);
540 
541 	ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
542 				      page_offset, page_count);
543 	if (ret)
544 		goto out_unlock;
545 
546 	mshv_region_invalidate_pages(region, page_offset, page_count);
547 
548 	mutex_unlock(&region->mreg_mutex);
549 
550 	return true;
551 
552 out_unlock:
553 	mutex_unlock(&region->mreg_mutex);
554 out_fail:
555 	WARN_ONCE(ret,
556 		  "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
557 		  region->start_uaddr,
558 		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
559 		  range->start, range->end, range->event,
560 		  page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
561 	return false;
562 }
563 
564 static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
565 	.invalidate = mshv_region_interval_invalidate,
566 };
567 
568 void mshv_region_movable_fini(struct mshv_mem_region *region)
569 {
570 	mmu_interval_notifier_remove(&region->mreg_mni);
571 }
572 
573 bool mshv_region_movable_init(struct mshv_mem_region *region)
574 {
575 	int ret;
576 
577 	ret = mmu_interval_notifier_insert(&region->mreg_mni, current->mm,
578 					   region->start_uaddr,
579 					   region->nr_pages << HV_HYP_PAGE_SHIFT,
580 					   &mshv_region_mni_ops);
581 	if (ret)
582 		return false;
583 
584 	mutex_init(&region->mreg_mutex);
585 
586 	return true;
587 }
588