// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2025, Microsoft Corporation.
 *
 * Memory region management for mshv_root module.
 *
 * Authors: Microsoft Linux virtualization team
 */

#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/vmalloc.h>

#include <asm/mshyperv.h>

#include "mshv_root.h"

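/*
 * Number of pages faulted in and mapped per GPA intercept. Faulting in a
 * PMD-aligned, PMD-sized batch reduces the number of intercepts and keeps
 * the mapping eligible for huge pages when the backing memory allows it.
 */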
#define MSHV_MAP_FAULT_IN_PAGES				PTRS_PER_PMD

/**
 * mshv_chunk_stride - Compute stride for mapping guest memory
 * @page      : The page to check for huge page backing
 * @gfn       : Guest frame number for the mapping
 * @page_count: Total number of pages in the mapping
 *
 * Determines the appropriate stride (in pages) for mapping guest memory.
 * Uses huge page stride if the backing page is huge and the guest mapping
 * is properly aligned; otherwise falls back to single page stride.
 *
 * Return: Stride in pages, or -EINVAL if page order is unsupported.
 */
static int mshv_chunk_stride(struct page *page,
			     u64 gfn, u64 page_count)
{
	unsigned int page_order;

	/*
	 * Use single page stride by default. For huge page stride, the
	 * page must be compound and point to the head of the compound
	 * page, and both gfn and page_count must be huge-page aligned.
	 */
	if (!PageCompound(page) || !PageHead(page) ||
	    !IS_ALIGNED(gfn, PTRS_PER_PMD) ||
	    !IS_ALIGNED(page_count, PTRS_PER_PMD))
		return 1;

	page_order = folio_order(page_folio(page));
	/* The hypervisor only supports 2M huge pages */
	if (page_order != PMD_ORDER)
		return -EINVAL;

	return 1 << page_order;
}

/**
 * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
 *                             in a region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle the chunk.
 *
 * This function scans the region's pages starting from @page_offset,
 * checking for contiguous present pages of the same size (normal or huge).
 * It invokes @handler for the chunk of contiguous pages found. Returns the
 * number of pages handled, or a negative error code if the first page is
 * not present or the handler fails.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: Number of pages handled, or negative error code.
 */
static long mshv_region_process_chunk(struct mshv_mem_region *region,
				      u32 flags,
				      u64 page_offset, u64 page_count,
				      int (*handler)(struct mshv_mem_region *region,
						     u32 flags,
						     u64 page_offset,
						     u64 page_count,
						     bool huge_page))
{
	u64 gfn = region->start_gfn + page_offset;
	u64 count;
	struct page *page;
	int stride, ret;

	page = region->pages[page_offset];
	if (!page)
		return -EINVAL;

	stride = mshv_chunk_stride(page, gfn, page_count);
	if (stride < 0)
		return stride;

	/* Start at stride since the first stride is validated */
	for (count = stride; count < page_count; count += stride) {
		page = region->pages[page_offset + count];

		/* Break if current page is not present */
		if (!page)
			break;

		/* Break if stride size changes */
		if (stride != mshv_chunk_stride(page, gfn + count,
						page_count - count))
			break;
	}

	ret = handler(region, flags, page_offset, count, stride > 1);
	if (ret)
		return ret;

	return count;
}

/**
 * mshv_region_process_range - Processes a range of memory pages in a
 *                             region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle each chunk of contiguous
 *               pages.
 *
 * Iterates over the specified range of pages in @region, skipping
 * non-present pages. For each contiguous chunk of present pages, invokes
 * @handler via mshv_region_process_chunk.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: 0 on success, or a negative error code on failure.
 */
static int mshv_region_process_range(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     int (*handler)(struct mshv_mem_region *region,
						    u32 flags,
						    u64 page_offset,
						    u64 page_count,
						    bool huge_page))
{
	long ret;

	if (page_offset + page_count > region->nr_pages)
		return -EINVAL;

	while (page_count) {
		/* Skip non-present pages */
		if (!region->pages[page_offset]) {
			page_offset++;
			page_count--;
			continue;
		}

		ret = mshv_region_process_chunk(region, flags,
						page_offset,
						page_count,
						handler);
		if (ret < 0)
			return ret;

		page_offset += ret;
		page_count -= ret;
	}

	return 0;
}

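/**
 * mshv_region_create - Allocate and initialize a memory region
 * @guest_pfn: First guest frame number covered by the region
 * @nr_pages : Number of guest pages in the region
 * @uaddr    : Userspace address backing the region
 * @flags    : MSHV_SET_MEM_BIT_* flags describing the requested access
 *
 * Allocates a region tracking @nr_pages pages, translates the userspace
 * flags into hypervisor map flags and initializes the reference count.
 * The pages array starts out empty and is populated later by pinning or
 * faulting.
 *
 * Return: Pointer to the new region, or ERR_PTR(-ENOMEM) on allocation
 * failure.
 */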
struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
					   u64 uaddr, u32 flags)
{
	struct mshv_mem_region *region;

	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
	if (!region)
		return ERR_PTR(-ENOMEM);

	region->nr_pages = nr_pages;
	region->start_gfn = guest_pfn;
	region->start_uaddr = uaddr;
	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;

	kref_init(&region->refcount);

	return region;
}

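/*
 * mshv_region_process_range() handler: share one contiguous chunk of
 * pages with the host, using the large-page variant of the hypercall
 * when the chunk is backed by a huge page.
 */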
static int mshv_region_chunk_share(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	if (huge_page)
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      region->pages + page_offset,
					      page_count,
					      HV_MAP_GPA_READABLE |
					      HV_MAP_GPA_WRITABLE,
					      flags, true);
}

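/*
 * Make every present page in the region accessible to the host again.
 * Required before unpinning pages of an encrypted partition.
 */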
int mshv_region_share(struct mshv_mem_region *region)
{
	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;

	return mshv_region_process_range(region, flags,
					 0, region->nr_pages,
					 mshv_region_chunk_share);
}

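/*
 * mshv_region_process_range() handler: revoke host access to one
 * contiguous chunk of pages, using the large-page variant of the
 * hypercall when the chunk is backed by a huge page.
 */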
static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     bool huge_page)
{
	if (huge_page)
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      region->pages + page_offset,
					      page_count, 0,
					      flags, false);
}

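/*
 * Make every present page in the region exclusive to the guest,
 * removing host access.
 */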
int mshv_region_unshare(struct mshv_mem_region *region)
{
	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;

	return mshv_region_process_range(region, flags,
					 0, region->nr_pages,
					 mshv_region_chunk_unshare);
}

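/*
 * mshv_region_process_range() handler: map one contiguous chunk of pages
 * into the guest physical address space, setting the large-page map flag
 * when the chunk is backed by a huge page.
 */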
static int mshv_region_chunk_remap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	if (huge_page)
		flags |= HV_MAP_GPA_LARGE_PAGE;

	return hv_call_map_gpa_pages(region->partition->pt_id,
				     region->start_gfn + page_offset,
				     page_count, flags,
				     region->pages + page_offset);
}

static int mshv_region_remap_pages(struct mshv_mem_region *region,
				   u32 map_flags,
				   u64 page_offset, u64 page_count)
{
	return mshv_region_process_range(region, map_flags,
					 page_offset, page_count,
					 mshv_region_chunk_remap);
}

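/*
 * Map all present pages of the region into the guest physical address
 * space with the access rights recorded at region creation.
 */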
int mshv_region_map(struct mshv_mem_region *region)
{
	u32 map_flags = region->hv_map_flags;

	return mshv_region_remap_pages(region, map_flags,
				       0, region->nr_pages);
}

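/*
 * Drop the region's references to a range of pages: unpin them if the
 * region pinned them, then clear the corresponding page pointers.
 */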
static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
					 u64 page_offset, u64 page_count)
{
	if (region->type == MSHV_REGION_TYPE_MEM_PINNED)
		unpin_user_pages(region->pages + page_offset, page_count);

	memset(region->pages + page_offset, 0,
	       page_count * sizeof(struct page *));
}

void mshv_region_invalidate(struct mshv_mem_region *region)
{
	mshv_region_invalidate_pages(region, 0, region->nr_pages);
}

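/**
 * mshv_region_pin - Pin the userspace pages backing a region
 * @region: Pointer to the memory region structure.
 *
 * Pins the region's userspace pages in batches of
 * MSHV_PIN_PAGES_BATCH_SIZE and records them in the region's pages
 * array. On failure, any pages pinned so far are released.
 *
 * Return: 0 on success, or a negative error code from
 * pin_user_pages_fast() on failure.
 */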
int mshv_region_pin(struct mshv_mem_region *region)
{
	u64 done_count, nr_pages;
	struct page **pages;
	__u64 userspace_addr;
	int ret;

	for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
		pages = region->pages + done_count;
		userspace_addr = region->start_uaddr +
				 done_count * HV_HYP_PAGE_SIZE;
		nr_pages = min(region->nr_pages - done_count,
			       MSHV_PIN_PAGES_BATCH_SIZE);

		/*
		 * Pinning with 4k page granularity works for large pages
		 * too: all page structs within the large page are returned.
		 *
		 * Pin requests are batched because pin_user_pages_fast
		 * with the FOLL_LONGTERM flag does a large temporary
		 * allocation of contiguous memory.
		 */
		ret = pin_user_pages_fast(userspace_addr, nr_pages,
					  FOLL_WRITE | FOLL_LONGTERM,
					  pages);
		if (ret < 0)
			goto release_pages;
	}

	return 0;

release_pages:
	mshv_region_invalidate_pages(region, 0, done_count);
	return ret;
}

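/*
 * mshv_region_process_range() handler: unmap one contiguous chunk of
 * pages from the guest physical address space, setting the large-page
 * unmap flag when the chunk is backed by a huge page.
 */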
static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count,
				   bool huge_page)
{
	if (huge_page)
		flags |= HV_UNMAP_GPA_LARGE_PAGE;

	return hv_call_unmap_gpa_pages(region->partition->pt_id,
				       region->start_gfn + page_offset,
				       page_count, flags);
}

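/*
 * Unmap all present pages of the region from the guest physical
 * address space.
 */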
static int mshv_region_unmap(struct mshv_mem_region *region)
{
	return mshv_region_process_range(region, 0,
					 0, region->nr_pages,
					 mshv_region_chunk_unmap);
}

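/*
 * kref release callback: tear down a region once the last reference is
 * dropped. For encrypted partitions the pages must be shared back with
 * the host first; if that fails, the region is intentionally leaked
 * rather than risking a host crash when the pages are unpinned.
 */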
static void mshv_region_destroy(struct kref *ref)
{
	struct mshv_mem_region *region =
		container_of(ref, struct mshv_mem_region, refcount);
	struct mshv_partition *partition = region->partition;
	int ret;

	if (region->type == MSHV_REGION_TYPE_MEM_MOVABLE)
		mshv_region_movable_fini(region);

	if (mshv_partition_encrypted(partition)) {
		ret = mshv_region_share(region);
		if (ret) {
			pt_err(partition,
			       "Failed to regain access to memory, unpinning user pages will fail and crash the host, error: %d\n",
			       ret);
			return;
		}
	}

	mshv_region_unmap(region);

	mshv_region_invalidate(region);

	vfree(region);
}

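/* Drop a reference to the region, destroying it when the count hits zero. */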
void mshv_region_put(struct mshv_mem_region *region)
{
	kref_put(&region->refcount, mshv_region_destroy);
}

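/*
 * Take a reference to the region. Returns non-zero on success, or 0 if
 * the refcount had already dropped to zero.
 */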
int mshv_region_get(struct mshv_mem_region *region)
{
	return kref_get_unless_zero(&region->refcount);
}

/**
 * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
 * @region: Pointer to the memory region structure
 * @range: Pointer to the HMM range structure
 *
 * This function performs the following steps:
 * 1. Reads the notifier sequence for the HMM range.
 * 2. Acquires a read lock on the memory map.
 * 3. Handles HMM faults for the specified range.
 * 4. Releases the read lock on the memory map.
 * 5. If successful, locks the memory region mutex.
 * 6. Verifies whether the notifier sequence changed during the operation.
 *    If it has, releases the mutex and returns -EBUSY, matching the
 *    hmm_range_fault() return code, so that the caller repeats the fault.
 *
 * Return: 0 on success, a negative error code otherwise.
 */
static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
					  struct hmm_range *range)
{
	int ret;

	range->notifier_seq = mmu_interval_read_begin(range->notifier);
	mmap_read_lock(region->mni.mm);
	ret = hmm_range_fault(range);
	mmap_read_unlock(region->mni.mm);
	if (ret)
		return ret;

	mutex_lock(&region->mutex);

	if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
		mutex_unlock(&region->mutex);
		cond_resched();
		return -EBUSY;
	}

	return 0;
}

/**
 * mshv_region_range_fault - Handle memory range faults for a given region.
 * @region: Pointer to the memory region structure.
 * @page_offset: Offset of the first page to handle within the region.
 * @page_count: Number of pages to handle.
 *
 * This function resolves memory faults for a specified range of pages
 * within a memory region. It uses HMM (Heterogeneous Memory Management)
 * to fault in the required pages and updates the region's page array.
 *
 * Return: 0 on success, negative error code on failure.
 */
static int mshv_region_range_fault(struct mshv_mem_region *region,
				   u64 page_offset, u64 page_count)
{
	struct hmm_range range = {
		.notifier = &region->mni,
		.default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
	};
	unsigned long *pfns;
	int ret;
	u64 i;

	pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
	if (!pfns)
		return -ENOMEM;

	range.hmm_pfns = pfns;
	range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
	range.end = range.start + page_count * HV_HYP_PAGE_SIZE;

	do {
		ret = mshv_region_hmm_fault_and_lock(region, &range);
	} while (ret == -EBUSY);

	if (ret)
		goto out;

	for (i = 0; i < page_count; i++)
		region->pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);

	ret = mshv_region_remap_pages(region, region->hv_map_flags,
				      page_offset, page_count);

	mutex_unlock(&region->mutex);
out:
	kfree(pfns);
	return ret;
}

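/**
 * mshv_region_handle_gfn_fault - Handle a GPA intercept on a memory region
 * @region: Pointer to the memory region structure.
 * @gfn: Faulting guest frame number.
 *
 * Faults in and maps a batch of up to MSHV_MAP_FAULT_IN_PAGES pages
 * around @gfn so that nearby accesses do not fault again immediately.
 *
 * Return: true if the fault was resolved, false otherwise.
 */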
bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
{
	u64 page_offset, page_count;
	int ret;

	/* Align the page offset down to a MSHV_MAP_FAULT_IN_PAGES boundary. */
	page_offset = ALIGN_DOWN(gfn - region->start_gfn,
				 MSHV_MAP_FAULT_IN_PAGES);

	/* Map more pages than requested to reduce the number of faults. */
	page_count = min(region->nr_pages - page_offset,
			 MSHV_MAP_FAULT_IN_PAGES);

	ret = mshv_region_range_fault(region, page_offset, page_count);

	WARN_ONCE(ret,
		  "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
		  region->partition->pt_id, region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  gfn, page_offset, page_count);

	return !ret;
}

/**
 * mshv_region_interval_invalidate - Invalidate a range of a memory region
 * @mni: Pointer to the mmu_interval_notifier structure
 * @range: Pointer to the mmu_notifier_range structure
 * @cur_seq: Current sequence number for the interval notifier
 *
 * This function invalidates a memory region by remapping its pages with
 * no access permissions. It locks the region's mutex to ensure thread safety
 * and updates the sequence number for the interval notifier. If the range
 * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
 * lock and returns false if unsuccessful.
 *
 * NOTE: Failure to invalidate a region is a serious error, as the pages will
 * be considered freed while they are still mapped by the hypervisor.
 * Any attempt to access such pages will likely crash the system.
 *
 * Return: true if the region was successfully invalidated, false otherwise.
 */
static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
					    const struct mmu_notifier_range *range,
					    unsigned long cur_seq)
{
	struct mshv_mem_region *region = container_of(mni,
						      struct mshv_mem_region,
						      mni);
	u64 page_offset, page_count;
	unsigned long mstart, mend;
	int ret = -EPERM;

	mstart = max(range->start, region->start_uaddr);
	mend = min(range->end, region->start_uaddr +
		   (region->nr_pages << HV_HYP_PAGE_SHIFT));

	page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
	page_count = HVPFN_DOWN(mend - mstart);

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&region->mutex);
	else if (!mutex_trylock(&region->mutex))
		goto out_fail;

	mmu_interval_set_seq(mni, cur_seq);

	ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
				      page_offset, page_count);
	if (ret)
		goto out_unlock;

	mshv_region_invalidate_pages(region, page_offset, page_count);

	mutex_unlock(&region->mutex);

	return true;

out_unlock:
	mutex_unlock(&region->mutex);
out_fail:
	WARN_ONCE(ret,
		  "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
		  region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  range->start, range->end, range->event,
		  page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
	return false;
}

static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
	.invalidate = mshv_region_interval_invalidate,
};

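/* Stop tracking the region's userspace range for invalidations. */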
void mshv_region_movable_fini(struct mshv_mem_region *region)
{
	mmu_interval_notifier_remove(&region->mni);
}

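/*
 * Register an MMU interval notifier covering the region's userspace
 * range so the region's pages can be invalidated when the mm changes,
 * and initialize the mutex that serializes fault-in and invalidation.
 * Returns true on success, false if the notifier could not be inserted.
 */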
bool mshv_region_movable_init(struct mshv_mem_region *region)
{
	int ret;

	ret = mmu_interval_notifier_insert(&region->mni, current->mm,
					   region->start_uaddr,
					   region->nr_pages << HV_HYP_PAGE_SHIFT,
					   &mshv_region_mni_ops);
	if (ret)
		return false;

	mutex_init(&region->mutex);

	return true;
}
589