// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2025, Microsoft Corporation.
 *
 * Memory region management for mshv_root module.
 *
 * Authors: Microsoft Linux virtualization team
 */

#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/vmalloc.h>

#include <asm/mshyperv.h>

#include "mshv_root.h"

#define MSHV_MAP_FAULT_IN_PAGES				PTRS_PER_PMD

/**
 * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
 *                             in a region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle the chunk.
 *
 * This function scans the region's pages starting from @page_offset,
 * checking for contiguous present pages of the same size (normal or huge).
 * It invokes @handler for the chunk of contiguous pages found. Returns the
 * number of pages handled, or a negative error code if the first page is
 * not present or the handler fails.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: Number of pages handled, or negative error code.
 */
static long mshv_region_process_chunk(struct mshv_mem_region *region,
				      u32 flags,
				      u64 page_offset, u64 page_count,
				      int (*handler)(struct mshv_mem_region *region,
						     u32 flags,
						     u64 page_offset,
						     u64 page_count))
{
	u64 count, stride;
	unsigned int page_order;
	struct page *page;
	int ret;

	page = region->pages[page_offset];
	if (!page)
		return -EINVAL;

	page_order = folio_order(page_folio(page));
	/* The hypervisor only supports 4K and 2M page sizes */
	if (page_order && page_order != HPAGE_PMD_ORDER)
		return -EINVAL;

	stride = 1 << page_order;

	/* Start at stride since the first page is validated */
	for (count = stride; count < page_count; count += stride) {
		page = region->pages[page_offset + count];

		/* Break if current page is not present */
		if (!page)
			break;

		/* Break if page size changes */
		if (page_order != folio_order(page_folio(page)))
			break;
	}

	ret = handler(region, flags, page_offset, count);
	if (ret)
		return ret;

	return count;
}

/**
 * mshv_region_process_range - Processes a range of memory pages in a
 *                             region.
 * @region     : Pointer to the memory region structure.
 * @flags      : Flags to pass to the handler.
 * @page_offset: Offset into the region's pages array to start processing.
 * @page_count : Number of pages to process.
 * @handler    : Callback function to handle each chunk of contiguous
 *               pages.
 *
 * Iterates over the specified range of pages in @region, skipping
 * non-present pages. For each contiguous chunk of present pages, invokes
 * @handler via mshv_region_process_chunk.
 *
 * Note: The @handler callback must be able to handle both normal and huge
 * pages.
 *
 * Return: 0 on success, or a negative error code on failure.
 */
static int mshv_region_process_range(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count,
				     int (*handler)(struct mshv_mem_region *region,
						    u32 flags,
						    u64 page_offset,
						    u64 page_count))
{
	long ret;

	if (page_offset + page_count > region->nr_pages)
		return -EINVAL;

	while (page_count) {
		/* Skip non-present pages */
		if (!region->pages[page_offset]) {
			page_offset++;
			page_count--;
			continue;
		}

		ret = mshv_region_process_chunk(region, flags,
						page_offset,
						page_count,
						handler);
		if (ret < 0)
			return ret;

		page_offset += ret;
		page_count -= ret;
	}

	return 0;
}

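/**
 * mshv_region_create - Allocate and initialize a guest memory region.
 * @guest_pfn: First guest page frame number covered by the region.
 * @nr_pages: Number of guest pages in the region.
 * @uaddr: Userspace address backing the region.
 * @flags: MSHV_SET_MEM_BIT_* flags describing the requested access.
 *
 * The pages array is zero-initialized, so no page is considered present
 * until it is pinned or faulted in. The MSHV_SET_MEM_BIT_* flags are
 * translated into the HV_MAP_GPA_* flags used when mapping the region.
 * The region is returned with a single reference held; fields not set
 * here, such as the owning partition and the region type, are expected
 * to be filled in by the caller.
 *
 * Return: Pointer to the new region, or ERR_PTR(-ENOMEM) on allocation
 * failure.
 */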
struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
					   u64 uaddr, u32 flags)
{
	struct mshv_mem_region *region;

	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
	if (!region)
		return ERR_PTR(-ENOMEM);

	region->nr_pages = nr_pages;
	region->start_gfn = guest_pfn;
	region->start_uaddr = uaddr;
	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;

	kref_init(&region->refcount);

	return region;
}

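/*
 * Chunk handler for mshv_region_process_range(): share a run of contiguous,
 * same-sized pages with the host. The first page of the chunk determines
 * whether the hypervisor call uses the large-page flag;
 * mshv_region_process_chunk() guarantees that a chunk never mixes page
 * sizes. The same pattern applies to the other chunk handlers below.
 */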
static int mshv_region_chunk_share(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count)
{
	struct page *page = region->pages[page_offset];

	if (PageHuge(page) || PageTransCompound(page))
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      region->pages + page_offset,
					      page_count,
					      HV_MAP_GPA_READABLE |
					      HV_MAP_GPA_WRITABLE,
					      flags, true);
}

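/**
 * mshv_region_share - Restore host access to a region's pages.
 * @region: Pointer to the memory region structure.
 *
 * Marks all present pages in the region as shared so the host regains
 * read/write access to them. For encrypted partitions this must be done
 * before the pages can be unpinned and released.
 *
 * Return: 0 on success, or a negative error code on failure.
 */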
int mshv_region_share(struct mshv_mem_region *region)
{
	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;

	return mshv_region_process_range(region, flags,
					 0, region->nr_pages,
					 mshv_region_chunk_share);
}

static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
				     u32 flags,
				     u64 page_offset, u64 page_count)
{
	struct page *page = region->pages[page_offset];

	if (PageHuge(page) || PageTransCompound(page))
		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;

	return hv_call_modify_spa_host_access(region->partition->pt_id,
					      region->pages + page_offset,
					      page_count, 0,
					      flags, false);
}

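/**
 * mshv_region_unshare - Revoke host access to a region's pages.
 * @region: Pointer to the memory region structure.
 *
 * Marks all present pages in the region as exclusive to the partition,
 * removing the host's access to them.
 *
 * Return: 0 on success, or a negative error code on failure.
 */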
int mshv_region_unshare(struct mshv_mem_region *region)
{
	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;

	return mshv_region_process_range(region, flags,
					 0, region->nr_pages,
					 mshv_region_chunk_unshare);
}

static int mshv_region_chunk_remap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count)
{
	struct page *page = region->pages[page_offset];

	if (PageHuge(page) || PageTransCompound(page))
		flags |= HV_MAP_GPA_LARGE_PAGE;

	return hv_call_map_gpa_pages(region->partition->pt_id,
				     region->start_gfn + page_offset,
				     page_count, flags,
				     region->pages + page_offset);
}

static int mshv_region_remap_pages(struct mshv_mem_region *region,
				   u32 map_flags,
				   u64 page_offset, u64 page_count)
{
	return mshv_region_process_range(region, map_flags,
					 page_offset, page_count,
					 mshv_region_chunk_remap);
}

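/**
 * mshv_region_map - Map a region into the partition's GPA space.
 * @region: Pointer to the memory region structure.
 *
 * Maps all present pages of the region with the access flags that were
 * computed when the region was created.
 *
 * Return: 0 on success, or a negative error code on failure.
 */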
int mshv_region_map(struct mshv_mem_region *region)
{
	u32 map_flags = region->hv_map_flags;

	return mshv_region_remap_pages(region, map_flags,
				       0, region->nr_pages);
}

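/*
 * Drop the region's references to a range of pages: unpin them if the
 * region was pinned, then clear the page pointers so the range is treated
 * as not present.
 */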
static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
					 u64 page_offset, u64 page_count)
{
	if (region->type == MSHV_REGION_TYPE_MEM_PINNED)
		unpin_user_pages(region->pages + page_offset, page_count);

	memset(region->pages + page_offset, 0,
	       page_count * sizeof(struct page *));
}

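/**
 * mshv_region_invalidate - Invalidate all pages of a region.
 * @region: Pointer to the memory region structure.
 *
 * Releases every page the region currently holds; afterwards no page in
 * the region is considered present.
 */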
void mshv_region_invalidate(struct mshv_mem_region *region)
{
	mshv_region_invalidate_pages(region, 0, region->nr_pages);
}

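/**
 * mshv_region_pin - Pin the userspace memory backing a region.
 * @region: Pointer to the memory region structure.
 *
 * Long-term pins the region's userspace pages in batches of
 * MSHV_PIN_PAGES_BATCH_SIZE and records them in the pages array. On
 * failure, any pages pinned so far are released again.
 *
 * Return: 0 on success, or a negative error code on failure.
 */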
int mshv_region_pin(struct mshv_mem_region *region)
{
	u64 done_count, nr_pages;
	struct page **pages;
	__u64 userspace_addr;
	int ret;

	for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
		pages = region->pages + done_count;
		userspace_addr = region->start_uaddr +
				 done_count * HV_HYP_PAGE_SIZE;
		nr_pages = min(region->nr_pages - done_count,
			       MSHV_PIN_PAGES_BATCH_SIZE);

		/*
		 * Pinning at 4K granularity works for large pages too:
		 * all page structs within the large page are returned.
		 *
		 * Pin requests are batched because pin_user_pages_fast
		 * with the FOLL_LONGTERM flag does a large temporary
		 * allocation of contiguous memory.
		 */
		ret = pin_user_pages_fast(userspace_addr, nr_pages,
					  FOLL_WRITE | FOLL_LONGTERM,
					  pages);
		if (ret < 0)
			goto release_pages;
	}

	return 0;

release_pages:
	mshv_region_invalidate_pages(region, 0, done_count);
	return ret;
}

static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
				   u32 flags,
				   u64 page_offset, u64 page_count)
{
	struct page *page = region->pages[page_offset];

	if (PageHuge(page) || PageTransCompound(page))
		flags |= HV_UNMAP_GPA_LARGE_PAGE;

	return hv_call_unmap_gpa_pages(region->partition->pt_id,
				       region->start_gfn + page_offset,
				       page_count, flags);
}

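/* Unmap every present page of the region from the partition's GPA space. */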
static int mshv_region_unmap(struct mshv_mem_region *region)
{
	return mshv_region_process_range(region, 0,
					 0, region->nr_pages,
					 mshv_region_chunk_unmap);
}

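/*
 * kref release callback: tear down a region once the last reference is
 * dropped. For encrypted partitions the pages must be shared back with the
 * host first, otherwise unpinning them would crash the host. The region is
 * then unmapped from the partition, its pages are released, and the region
 * itself is freed.
 */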
static void mshv_region_destroy(struct kref *ref)
{
	struct mshv_mem_region *region =
		container_of(ref, struct mshv_mem_region, refcount);
	struct mshv_partition *partition = region->partition;
	int ret;

	if (region->type == MSHV_REGION_TYPE_MEM_MOVABLE)
		mshv_region_movable_fini(region);

	if (mshv_partition_encrypted(partition)) {
		ret = mshv_region_share(region);
		if (ret) {
			pt_err(partition,
			       "Failed to regain access to memory, unpinning user pages will fail and crash the host, error: %d\n",
			       ret);
			return;
		}
	}

	mshv_region_unmap(region);

	mshv_region_invalidate(region);

	vfree(region);
}

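/**
 * mshv_region_put - Drop a reference to a region.
 * @region: Pointer to the memory region structure.
 *
 * Releases one reference; when the last reference is dropped the region is
 * torn down and freed via mshv_region_destroy().
 */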
void mshv_region_put(struct mshv_mem_region *region)
{
	kref_put(&region->refcount, mshv_region_destroy);
}

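/**
 * mshv_region_get - Try to acquire a reference to a region.
 * @region: Pointer to the memory region structure.
 *
 * Return: Non-zero if a reference was acquired, 0 if the region's refcount
 * had already dropped to zero.
 */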
int mshv_region_get(struct mshv_mem_region *region)
{
	return kref_get_unless_zero(&region->refcount);
}

/**
 * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
 * @region: Pointer to the memory region structure
 * @range: Pointer to the HMM range structure
 *
 * This function performs the following steps:
 * 1. Reads the notifier sequence for the HMM range.
 * 2. Acquires a read lock on the memory map.
 * 3. Handles HMM faults for the specified range.
 * 4. Releases the read lock on the memory map.
 * 5. If successful, locks the memory region mutex.
 * 6. Checks whether the notifier sequence changed during the operation.
 *    If it did, releases the mutex and returns -EBUSY, matching the
 *    hmm_range_fault() convention so the caller retries.
 *
 * Return: 0 on success, a negative error code otherwise.
 */
static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
					  struct hmm_range *range)
{
	int ret;

	range->notifier_seq = mmu_interval_read_begin(range->notifier);
	mmap_read_lock(region->mni.mm);
	ret = hmm_range_fault(range);
	mmap_read_unlock(region->mni.mm);
	if (ret)
		return ret;

	mutex_lock(&region->mutex);

	if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
		mutex_unlock(&region->mutex);
		cond_resched();
		return -EBUSY;
	}

	return 0;
}

/**
 * mshv_region_range_fault - Handle memory range faults for a given region.
 * @region: Pointer to the memory region structure.
 * @page_offset: Offset of the first page within the region.
 * @page_count: Number of pages to handle.
 *
 * This function resolves memory faults for a specified range of pages
 * within a memory region. It uses HMM (Heterogeneous Memory Management)
 * to fault in the required pages and updates the region's page array.
 *
 * Return: 0 on success, negative error code on failure.
 */
static int mshv_region_range_fault(struct mshv_mem_region *region,
				   u64 page_offset, u64 page_count)
{
	struct hmm_range range = {
		.notifier = &region->mni,
		.default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
	};
	unsigned long *pfns;
	int ret;
	u64 i;

	pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
	if (!pfns)
		return -ENOMEM;

	range.hmm_pfns = pfns;
	range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
	range.end = range.start + page_count * HV_HYP_PAGE_SIZE;

	do {
		ret = mshv_region_hmm_fault_and_lock(region, &range);
	} while (ret == -EBUSY);

	if (ret)
		goto out;

	for (i = 0; i < page_count; i++)
		region->pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);

	ret = mshv_region_remap_pages(region, region->hv_map_flags,
				      page_offset, page_count);

	mutex_unlock(&region->mutex);
out:
	kfree(pfns);
	return ret;
}

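/**
 * mshv_region_handle_gfn_fault - Fault in pages around a faulting GFN.
 * @region: Pointer to the memory region structure.
 * @gfn: Guest page frame number that triggered the fault.
 *
 * Faults in a batch of up to MSHV_MAP_FAULT_IN_PAGES pages, starting at
 * the aligned offset containing @gfn, and maps them into the partition.
 * A warning is emitted if the fault cannot be resolved.
 *
 * Return: true if the fault was handled, false otherwise.
 */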
bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
{
	u64 page_offset, page_count;
	int ret;

	/* Align the page offset down to a MSHV_MAP_FAULT_IN_PAGES boundary. */
	page_offset = ALIGN_DOWN(gfn - region->start_gfn,
				 MSHV_MAP_FAULT_IN_PAGES);

	/* Map more pages than requested to reduce the number of faults. */
	page_count = min(region->nr_pages - page_offset,
			 MSHV_MAP_FAULT_IN_PAGES);

	ret = mshv_region_range_fault(region, page_offset, page_count);

	WARN_ONCE(ret,
		  "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
		  region->partition->pt_id, region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  gfn, page_offset, page_count);

	return !ret;
}

/**
 * mshv_region_interval_invalidate - Invalidate a range of a memory region
 * @mni: Pointer to the mmu_interval_notifier structure
 * @range: Pointer to the mmu_notifier_range structure
 * @cur_seq: Current sequence number for the interval notifier
 *
 * This function invalidates a memory region by remapping its pages with
 * no access permissions. It locks the region's mutex to ensure thread safety
 * and updates the sequence number for the interval notifier. If the range
 * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
 * lock and returns false if unsuccessful.
 *
 * NOTE: Failure to invalidate a region is a serious error, as the pages will
 * be considered freed while they are still mapped by the hypervisor.
 * Any attempt to access such pages will likely crash the system.
 *
 * Return: true if the region was successfully invalidated, false otherwise.
 */
static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
					    const struct mmu_notifier_range *range,
					    unsigned long cur_seq)
{
	struct mshv_mem_region *region = container_of(mni,
						      struct mshv_mem_region,
						      mni);
	u64 page_offset, page_count;
	unsigned long mstart, mend;
	int ret = -EPERM;

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&region->mutex);
	else if (!mutex_trylock(&region->mutex))
		goto out_fail;

	mmu_interval_set_seq(mni, cur_seq);

	mstart = max(range->start, region->start_uaddr);
	mend = min(range->end, region->start_uaddr +
		   (region->nr_pages << HV_HYP_PAGE_SHIFT));

	page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
	page_count = HVPFN_DOWN(mend - mstart);

	ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
				      page_offset, page_count);
	if (ret)
		goto out_fail;

	mshv_region_invalidate_pages(region, page_offset, page_count);

	mutex_unlock(&region->mutex);

	return true;

out_fail:
	WARN_ONCE(ret,
		  "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
		  region->start_uaddr,
		  region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
		  range->start, range->end, range->event,
		  page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
	return false;
}

static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
	.invalidate = mshv_region_interval_invalidate,
};

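/**
 * mshv_region_movable_fini - Tear down movable-region tracking.
 * @region: Pointer to the memory region structure.
 *
 * Removes the mmu interval notifier that was registered by
 * mshv_region_movable_init().
 */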
void mshv_region_movable_fini(struct mshv_mem_region *region)
{
	mmu_interval_notifier_remove(&region->mni);
}

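/**
 * mshv_region_movable_init - Set up movable-region tracking.
 * @region: Pointer to the memory region structure.
 *
 * Registers an mmu interval notifier covering the region's userspace range
 * in the current process's address space and initializes the mutex that
 * serializes faulting against invalidation.
 *
 * Return: true on success, false if the notifier could not be inserted.
 */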
bool mshv_region_movable_init(struct mshv_mem_region *region)
{
	int ret;

	ret = mmu_interval_notifier_insert(&region->mni, current->mm,
					   region->start_uaddr,
					   region->nr_pages << HV_HYP_PAGE_SHIFT,
					   &mshv_region_mni_ops);
	if (ret)
		return false;

	mutex_init(&region->mutex);

	return true;
}