xref: /linux/drivers/hv/mshv_root_main.c (revision 7fc2cd2e4b398c57c9cf961cfea05eadbf34c05c)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2024, Microsoft Corporation.
4  *
5  * The main part of the mshv_root module, providing APIs to create
6  * and manage guest partitions.
7  *
8  * Authors: Microsoft Linux virtualization team
9  */
10 
11 #include <linux/entry-virt.h>
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/fs.h>
15 #include <linux/miscdevice.h>
16 #include <linux/slab.h>
17 #include <linux/file.h>
18 #include <linux/anon_inodes.h>
19 #include <linux/mm.h>
20 #include <linux/io.h>
21 #include <linux/cpuhotplug.h>
22 #include <linux/random.h>
23 #include <asm/mshyperv.h>
24 #include <linux/hyperv.h>
25 #include <linux/notifier.h>
26 #include <linux/reboot.h>
27 #include <linux/kexec.h>
28 #include <linux/page-flags.h>
29 #include <linux/crash_dump.h>
30 #include <linux/panic_notifier.h>
31 #include <linux/vmalloc.h>
32 #include <linux/rseq.h>
33 
34 #include "mshv_eventfd.h"
35 #include "mshv.h"
36 #include "mshv_root.h"
37 
38 MODULE_AUTHOR("Microsoft");
39 MODULE_LICENSE("GPL");
40 MODULE_DESCRIPTION("Microsoft Hyper-V root partition VMM interface /dev/mshv");
41 
42 /* TODO move this to another file when debugfs code is added */
43 enum hv_stats_vp_counters {			/* HV_THREAD_COUNTER */
44 #if defined(CONFIG_X86)
45 	VpRootDispatchThreadBlocked			= 201,
46 #elif defined(CONFIG_ARM64)
47 	VpRootDispatchThreadBlocked			= 94,
48 #endif
49 	VpStatsMaxCounter
50 };
51 
52 struct hv_stats_page {
53 	union {
54 		u64 vp_cntrs[VpStatsMaxCounter];		/* VP counters */
55 		u8 data[HV_HYP_PAGE_SIZE];
56 	};
57 } __packed;
58 
59 struct mshv_root mshv_root;
60 
61 enum hv_scheduler_type hv_scheduler_type;
62 
63 /* Once the fast extended hypercall ABI is implemented, these can go away. */
64 static void * __percpu *root_scheduler_input;
65 static void * __percpu *root_scheduler_output;
66 
67 static long mshv_dev_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg);
68 static int mshv_dev_open(struct inode *inode, struct file *filp);
69 static int mshv_dev_release(struct inode *inode, struct file *filp);
70 static int mshv_vp_release(struct inode *inode, struct file *filp);
71 static long mshv_vp_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg);
72 static int mshv_partition_release(struct inode *inode, struct file *filp);
73 static long mshv_partition_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg);
74 static int mshv_vp_mmap(struct file *file, struct vm_area_struct *vma);
75 static vm_fault_t mshv_vp_fault(struct vm_fault *vmf);
76 static int mshv_init_async_handler(struct mshv_partition *partition);
77 static void mshv_async_hvcall_handler(void *data, u64 *status);
78 
79 static const union hv_input_vtl input_vtl_zero;
80 static const union hv_input_vtl input_vtl_normal = {
81 	.target_vtl = HV_NORMAL_VTL,
82 	.use_target_vtl = 1,
83 };
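/*
 * input_vtl_zero targets VTL0 implicitly; input_vtl_normal explicitly targets
 * HV_NORMAL_VTL and is used below only for mapping and unmapping the GHCB
 * state page of an encrypted partition.
 */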
84 
85 static const struct vm_operations_struct mshv_vp_vm_ops = {
86 	.fault = mshv_vp_fault,
87 };
88 
89 static const struct file_operations mshv_vp_fops = {
90 	.owner = THIS_MODULE,
91 	.release = mshv_vp_release,
92 	.unlocked_ioctl = mshv_vp_ioctl,
93 	.llseek = noop_llseek,
94 	.mmap = mshv_vp_mmap,
95 };
96 
97 static const struct file_operations mshv_partition_fops = {
98 	.owner = THIS_MODULE,
99 	.release = mshv_partition_release,
100 	.unlocked_ioctl = mshv_partition_ioctl,
101 	.llseek = noop_llseek,
102 };
103 
104 static const struct file_operations mshv_dev_fops = {
105 	.owner = THIS_MODULE,
106 	.open = mshv_dev_open,
107 	.release = mshv_dev_release,
108 	.unlocked_ioctl = mshv_dev_ioctl,
109 	.llseek = noop_llseek,
110 };
111 
112 static struct miscdevice mshv_dev = {
113 	.minor = MISC_DYNAMIC_MINOR,
114 	.name = "mshv",
115 	.fops = &mshv_dev_fops,
116 	.mode = 0600,
117 };
118 
119 /*
120  * Only allow hypercalls that have a u64 partition id as the first member of
121  * the input structure.
122  * These are sorted by value.
123  */
124 static u16 mshv_passthru_hvcalls[] = {
125 	HVCALL_GET_PARTITION_PROPERTY,
126 	HVCALL_SET_PARTITION_PROPERTY,
127 	HVCALL_INSTALL_INTERCEPT,
128 	HVCALL_GET_VP_REGISTERS,
129 	HVCALL_SET_VP_REGISTERS,
130 	HVCALL_TRANSLATE_VIRTUAL_ADDRESS,
131 	HVCALL_CLEAR_VIRTUAL_INTERRUPT,
132 	HVCALL_REGISTER_INTERCEPT_RESULT,
133 	HVCALL_ASSERT_VIRTUAL_INTERRUPT,
134 	HVCALL_GET_GPA_PAGES_ACCESS_STATES,
135 	HVCALL_SIGNAL_EVENT_DIRECT,
136 	HVCALL_POST_MESSAGE_DIRECT,
137 	HVCALL_GET_VP_CPUID_VALUES,
138 };
139 
140 static bool mshv_hvcall_is_async(u16 code)
141 {
142 	switch (code) {
143 	case HVCALL_SET_PARTITION_PROPERTY:
144 		return true;
145 	default:
146 		break;
147 	}
148 	return false;
149 }
150 
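/*
 * Forward an allowed hypercall from the VMM to the hypervisor on behalf of
 * the partition. The partition id is patched into the first 8 bytes of the
 * input page, which is why only the hypercalls listed in
 * mshv_passthru_hvcalls[] are accepted.
 *
 * Illustrative VMM-side sketch (not part of this file; field names follow
 * struct mshv_root_hvcall as used below, see the uapi header for the exact
 * layout):
 *
 *	struct mshv_root_hvcall args = {
 *		.code    = HVCALL_GET_VP_REGISTERS,
 *		.in_ptr  = (__u64)in_buf,	// starts with a u64 partition id
 *		.in_sz   = in_bytes,
 *		.out_ptr = (__u64)out_buf,
 *		.out_sz  = out_bytes,
 *		.reps    = 1,			// rep hypercall: one register
 *	};
 *	ioctl(fd, MSHV_ROOT_HVCALL, &args);	// vp fd or partition fd
 *	// args.status and args.reps are updated on return.
 */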
151 static int mshv_ioctl_passthru_hvcall(struct mshv_partition *partition,
152 				      bool partition_locked,
153 				      void __user *user_args)
154 {
155 	u64 status;
156 	int ret = 0, i;
157 	bool is_async;
158 	struct mshv_root_hvcall args;
159 	struct page *page;
160 	unsigned int pages_order;
161 	void *input_pg = NULL;
162 	void *output_pg = NULL;
163 
164 	if (copy_from_user(&args, user_args, sizeof(args)))
165 		return -EFAULT;
166 
167 	if (args.status || !args.in_ptr || args.in_sz < sizeof(u64) ||
168 	    mshv_field_nonzero(args, rsvd) || args.in_sz > HV_HYP_PAGE_SIZE)
169 		return -EINVAL;
170 
171 	if (args.out_ptr && (!args.out_sz || args.out_sz > HV_HYP_PAGE_SIZE))
172 		return -EINVAL;
173 
174 	for (i = 0; i < ARRAY_SIZE(mshv_passthru_hvcalls); ++i)
175 		if (args.code == mshv_passthru_hvcalls[i])
176 			break;
177 
178 	if (i >= ARRAY_SIZE(mshv_passthru_hvcalls))
179 		return -EINVAL;
180 
181 	is_async = mshv_hvcall_is_async(args.code);
182 	if (is_async) {
183 		/* async hypercalls can only be called from partition fd */
184 		if (!partition_locked)
185 			return -EINVAL;
186 		ret = mshv_init_async_handler(partition);
187 		if (ret)
188 			return ret;
189 	}
190 
191 	pages_order = args.out_ptr ? 1 : 0;
192 	page = alloc_pages(GFP_KERNEL, pages_order);
193 	if (!page)
194 		return -ENOMEM;
195 	input_pg = page_address(page);
196 
197 	if (args.out_ptr)
198 		output_pg = (char *)input_pg + PAGE_SIZE;
199 	else
200 		output_pg = NULL;
201 
202 	if (copy_from_user(input_pg, (void __user *)args.in_ptr,
203 			   args.in_sz)) {
204 		ret = -EFAULT;
205 		goto free_pages_out;
206 	}
207 
208 	/*
209 	 * NOTE: This only works because all the allowed hypercalls' input
210 	 * structs begin with a u64 partition_id field.
211 	 */
212 	*(u64 *)input_pg = partition->pt_id;
213 
214 	if (args.reps)
215 		status = hv_do_rep_hypercall(args.code, args.reps, 0,
216 					     input_pg, output_pg);
217 	else
218 		status = hv_do_hypercall(args.code, input_pg, output_pg);
219 
220 	if (hv_result(status) == HV_STATUS_CALL_PENDING) {
221 		if (is_async) {
222 			mshv_async_hvcall_handler(partition, &status);
223 		} else { /* Paranoia check. This shouldn't happen! */
224 			ret = -EBADFD;
225 			goto free_pages_out;
226 		}
227 	}
228 
229 	if (hv_result(status) == HV_STATUS_INSUFFICIENT_MEMORY) {
230 		ret = hv_call_deposit_pages(NUMA_NO_NODE, partition->pt_id, 1);
231 		if (!ret)
232 			ret = -EAGAIN;
233 	} else if (!hv_result_success(status)) {
234 		ret = hv_result_to_errno(status);
235 	}
236 
237 	/*
238 	 * Always return the status and output data regardless of result.
239 	 * The VMM may need it to determine how to proceed. E.g. the status may
240 	 * contain the number of reps completed if a rep hypercall partially
241 	 * succeeded.
242 	 */
243 	args.status = hv_result(status);
244 	args.reps = args.reps ? hv_repcomp(status) : 0;
245 	if (copy_to_user(user_args, &args, sizeof(args)))
246 		ret = -EFAULT;
247 
248 	if (output_pg &&
249 	    copy_to_user((void __user *)args.out_ptr, output_pg, args.out_sz))
250 		ret = -EFAULT;
251 
252 free_pages_out:
253 	free_pages((unsigned long)input_pg, pages_order);
254 
255 	return ret;
256 }
257 
258 static inline bool is_ghcb_mapping_available(void)
259 {
260 #if IS_ENABLED(CONFIG_X86_64)
261 	return ms_hyperv.ext_features & HV_VP_GHCB_ROOT_MAPPING_AVAILABLE;
262 #else
263 	return false;
264 #endif
265 }
266 
267 static int mshv_get_vp_registers(u32 vp_index, u64 partition_id, u16 count,
268 				 struct hv_register_assoc *registers)
269 {
270 	return hv_call_get_vp_registers(vp_index, partition_id,
271 					count, input_vtl_zero, registers);
272 }
273 
274 static int mshv_set_vp_registers(u32 vp_index, u64 partition_id, u16 count,
275 				 struct hv_register_assoc *registers)
276 {
277 	return hv_call_set_vp_registers(vp_index, partition_id,
278 					count, input_vtl_zero, registers);
279 }
280 
281 /*
282  * Explicit guest vCPU suspend is asynchronous by nature (it is requested by
283  * a dom0 vCPU on behalf of a guest vCPU) and thus can race with "intercept"
284  * suspend, which is done by the hypervisor.
285  * "Intercept" suspend leads to asynchronous message delivery to dom0, which
286  * must be awaited to keep the VP loop consistent (i.e. no message pending
287  * upon VP resume).
288  * VP intercept suspend can't be done when the VP is already explicitly
289  * suspended, so there are only two possible race scenarios:
290  *   1. intercept suspend bit set -> explicit suspend bit set -> message sent
291  *   2. intercept suspend bit set -> message sent -> explicit suspend bit set
292  * Checking whether the intercept suspend bit is set after the explicit
293  * suspend request has succeeded covers either case and reliably tells us
294  * whether there is a message to receive and deliver to the VMM.
295  */
296 static int
297 mshv_suspend_vp(const struct mshv_vp *vp, bool *message_in_flight)
298 {
299 	struct hv_register_assoc explicit_suspend = {
300 		.name = HV_REGISTER_EXPLICIT_SUSPEND
301 	};
302 	struct hv_register_assoc intercept_suspend = {
303 		.name = HV_REGISTER_INTERCEPT_SUSPEND
304 	};
305 	union hv_explicit_suspend_register *es =
306 		&explicit_suspend.value.explicit_suspend;
307 	union hv_intercept_suspend_register *is =
308 		&intercept_suspend.value.intercept_suspend;
309 	int ret;
310 
311 	es->suspended = 1;
312 
313 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
314 				    1, &explicit_suspend);
315 	if (ret) {
316 		vp_err(vp, "Failed to explicitly suspend vCPU\n");
317 		return ret;
318 	}
319 
320 	ret = mshv_get_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
321 				    1, &intercept_suspend);
322 	if (ret) {
323 		vp_err(vp, "Failed to get intercept suspend state\n");
324 		return ret;
325 	}
326 
327 	*message_in_flight = is->suspended;
328 
329 	return 0;
330 }
331 
332 /*
333  * This function is used when VPs are scheduled by the hypervisor's
334  * scheduler.
335  *
336  * The caller has to make sure the register array contains cleared
337  * HV_REGISTER_INTERCEPT_SUSPEND and HV_REGISTER_EXPLICIT_SUSPEND values,
338  * exactly in this order (the hypervisor clears them sequentially), to avoid
339  * wrongly clearing a newly arrived HV_REGISTER_INTERCEPT_SUSPEND after the
340  * VP has been released from HV_REGISTER_EXPLICIT_SUSPEND, which could happen
341  * with the opposite order.
342  */
343 static long mshv_run_vp_with_hyp_scheduler(struct mshv_vp *vp)
344 {
345 	long ret;
346 	struct hv_register_assoc suspend_regs[2] = {
347 			{ .name = HV_REGISTER_INTERCEPT_SUSPEND },
348 			{ .name = HV_REGISTER_EXPLICIT_SUSPEND }
349 	};
350 	size_t count = ARRAY_SIZE(suspend_regs);
351 
352 	/* Resume VP execution */
353 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
354 				    count, suspend_regs);
355 	if (ret) {
356 		vp_err(vp, "Failed to resume vp execution. %lx\n", ret);
357 		return ret;
358 	}
359 
360 	ret = wait_event_interruptible(vp->run.vp_suspend_queue,
361 				       vp->run.kicked_by_hv == 1);
362 	if (ret) {
363 		bool message_in_flight;
364 
365 		/*
366 		 * The wait was interrupted by a signal: suspend the vCPU
367 		 * explicitly and check for a message in flight (if any).
368 		 */
369 		ret = mshv_suspend_vp(vp, &message_in_flight);
370 		if (ret)
371 			return ret;
372 
373 		/* Return if no message in flight */
374 		if (!message_in_flight)
375 			return -EINTR;
376 
377 		/* Wait for the message in flight. */
378 		wait_event(vp->run.vp_suspend_queue, vp->run.kicked_by_hv == 1);
379 	}
380 
381 	/*
382 	 * Reset the flag to make the wait_event call above work
383 	 * next time.
384 	 */
385 	vp->run.kicked_by_hv = 0;
386 
387 	return 0;
388 }
389 
390 static int
391 mshv_vp_dispatch(struct mshv_vp *vp, u32 flags,
392 		 struct hv_output_dispatch_vp *res)
393 {
394 	struct hv_input_dispatch_vp *input;
395 	struct hv_output_dispatch_vp *output;
396 	u64 status;
397 
398 	preempt_disable();
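	/*
	 * Preemption stays disabled so this CPU's root scheduler input/output
	 * pages are not handed to another task, and the output is read on the
	 * same CPU that issued HVCALL_DISPATCH_VP.
	 */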
399 	input = *this_cpu_ptr(root_scheduler_input);
400 	output = *this_cpu_ptr(root_scheduler_output);
401 
402 	memset(input, 0, sizeof(*input));
403 	memset(output, 0, sizeof(*output));
404 
405 	input->partition_id = vp->vp_partition->pt_id;
406 	input->vp_index = vp->vp_index;
407 	input->time_slice = 0; /* Run forever until something happens */
408 	input->spec_ctrl = 0; /* TODO: set sensible flags */
409 	input->flags = flags;
410 
411 	vp->run.flags.root_sched_dispatched = 1;
412 	status = hv_do_hypercall(HVCALL_DISPATCH_VP, input, output);
413 	vp->run.flags.root_sched_dispatched = 0;
414 
415 	*res = *output;
416 	preempt_enable();
417 
418 	if (!hv_result_success(status))
419 		vp_err(vp, "%s: status %s\n", __func__,
420 		       hv_result_to_string(status));
421 
422 	return hv_result_to_errno(status);
423 }
424 
425 static int
426 mshv_vp_clear_explicit_suspend(struct mshv_vp *vp)
427 {
428 	struct hv_register_assoc explicit_suspend = {
429 		.name = HV_REGISTER_EXPLICIT_SUSPEND,
430 		.value.explicit_suspend.suspended = 0,
431 	};
432 	int ret;
433 
434 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
435 				    1, &explicit_suspend);
436 
437 	if (ret)
438 		vp_err(vp, "Failed to unsuspend\n");
439 
440 	return ret;
441 }
442 
443 #if IS_ENABLED(CONFIG_X86_64)
444 static u64 mshv_vp_interrupt_pending(struct mshv_vp *vp)
445 {
446 	if (!vp->vp_register_page)
447 		return 0;
448 	return vp->vp_register_page->interrupt_vectors.as_uint64;
449 }
450 #else
451 static u64 mshv_vp_interrupt_pending(struct mshv_vp *vp)
452 {
453 	return 0;
454 }
455 #endif
456 
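/*
 * Report whether the hypervisor currently considers this VP's root dispatch
 * thread blocked, preferring the self stats page and falling back to the
 * parent one.
 */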
457 static bool mshv_vp_dispatch_thread_blocked(struct mshv_vp *vp)
458 {
459 	struct hv_stats_page **stats = vp->vp_stats_pages;
460 	u64 *self_vp_cntrs = stats[HV_STATS_AREA_SELF]->vp_cntrs;
461 	u64 *parent_vp_cntrs = stats[HV_STATS_AREA_PARENT]->vp_cntrs;
462 
463 	if (self_vp_cntrs[VpRootDispatchThreadBlocked])
464 		return self_vp_cntrs[VpRootDispatchThreadBlocked];
465 	return parent_vp_cntrs[VpRootDispatchThreadBlocked];
466 }
467 
468 static int
469 mshv_vp_wait_for_hv_kick(struct mshv_vp *vp)
470 {
471 	int ret;
472 
473 	ret = wait_event_interruptible(vp->run.vp_suspend_queue,
474 				       (vp->run.kicked_by_hv == 1 &&
475 					!mshv_vp_dispatch_thread_blocked(vp)) ||
476 				       mshv_vp_interrupt_pending(vp));
477 	if (ret)
478 		return -EINTR;
479 
480 	vp->run.flags.root_sched_blocked = 0;
481 	vp->run.kicked_by_hv = 0;
482 
483 	return 0;
484 }
485 
486 /* Must be called with interrupts enabled */
487 static long mshv_run_vp_with_root_scheduler(struct mshv_vp *vp)
488 {
489 	long ret;
490 
491 	if (vp->run.flags.root_sched_blocked) {
492 		/*
493 		 * Dispatch state of this VP is blocked. Need to wait
494 		 * for the hypervisor to clear the blocked state before
495 		 * dispatching it.
496 		 */
497 		ret = mshv_vp_wait_for_hv_kick(vp);
498 		if (ret)
499 			return ret;
500 	}
501 
502 	do {
503 		u32 flags = 0;
504 		struct hv_output_dispatch_vp output;
505 
506 		if (__xfer_to_guest_mode_work_pending()) {
507 			ret = xfer_to_guest_mode_handle_work();
508 			if (ret)
509 				break;
510 		}
511 
512 		if (vp->run.flags.intercept_suspend)
513 			flags |= HV_DISPATCH_VP_FLAG_CLEAR_INTERCEPT_SUSPEND;
514 
515 		if (mshv_vp_interrupt_pending(vp))
516 			flags |= HV_DISPATCH_VP_FLAG_SCAN_INTERRUPT_INJECTION;
517 
518 		ret = mshv_vp_dispatch(vp, flags, &output);
519 		if (ret)
520 			break;
521 
522 		vp->run.flags.intercept_suspend = 0;
523 
524 		if (output.dispatch_state == HV_VP_DISPATCH_STATE_BLOCKED) {
525 			if (output.dispatch_event ==
526 						HV_VP_DISPATCH_EVENT_SUSPEND) {
527 				/*
528 				 * TODO: remove the warning once VP canceling
529 				 *	 is supported
530 				 */
531 				WARN_ONCE(atomic64_read(&vp->run.vp_signaled_count),
532 					  "%s: vp#%d: unexpected explicit suspend\n",
533 					  __func__, vp->vp_index);
534 				/*
535 				 * Need to clear explicit suspend before
536 				 * dispatching.
537 				 * Explicit suspend is either:
538 				 * - set right after the first VP dispatch or
539 				 * - set explicitly via hypercall
540 				 * Since the latter case is not yet supported,
541 				 * simply clear it here.
542 				 */
543 				ret = mshv_vp_clear_explicit_suspend(vp);
544 				if (ret)
545 					break;
546 
547 				ret = mshv_vp_wait_for_hv_kick(vp);
548 				if (ret)
549 					break;
550 			} else {
551 				vp->run.flags.root_sched_blocked = 1;
552 				ret = mshv_vp_wait_for_hv_kick(vp);
553 				if (ret)
554 					break;
555 			}
556 		} else {
557 			/* HV_VP_DISPATCH_STATE_READY */
558 			if (output.dispatch_event ==
559 						HV_VP_DISPATCH_EVENT_INTERCEPT)
560 				vp->run.flags.intercept_suspend = 1;
561 		}
562 	} while (!vp->run.flags.intercept_suspend);
563 
564 	rseq_virt_userspace_exit();
565 
566 	return ret;
567 }
568 
569 static_assert(sizeof(struct hv_message) <= MSHV_RUN_VP_BUF_SZ,
570 	      "sizeof(struct hv_message) must not exceed MSHV_RUN_VP_BUF_SZ");
571 
572 static long mshv_vp_ioctl_run_vp(struct mshv_vp *vp, void __user *ret_msg)
573 {
574 	long rc;
575 
576 	if (hv_scheduler_type == HV_SCHEDULER_TYPE_ROOT)
577 		rc = mshv_run_vp_with_root_scheduler(vp);
578 	else
579 		rc = mshv_run_vp_with_hyp_scheduler(vp);
580 
581 	if (rc)
582 		return rc;
583 
584 	if (copy_to_user(ret_msg, vp->vp_intercept_msg_page,
585 			 sizeof(struct hv_message)))
586 		rc = -EFAULT;
587 
588 	return rc;
589 }
590 
591 static int
592 mshv_vp_ioctl_get_set_state_pfn(struct mshv_vp *vp,
593 				struct hv_vp_state_data state_data,
594 				unsigned long user_pfn, size_t page_count,
595 				bool is_set)
596 {
597 	int completed, ret = 0;
598 	unsigned long check;
599 	struct page **pages;
600 
601 	if (page_count > INT_MAX)
602 		return -EINVAL;
603 	/*
604 	 * Check the arithmetic for wraparound/overflow.
605 	 * The last page address in the buffer is:
606 	 * (user_pfn + (page_count - 1)) * PAGE_SIZE
607 	 */
608 	if (check_add_overflow(user_pfn, (page_count - 1), &check))
609 		return -EOVERFLOW;
610 	if (check_mul_overflow(check, PAGE_SIZE, &check))
611 		return -EOVERFLOW;
612 
613 	/* Pin user pages so hypervisor can copy directly to them */
614 	pages = kcalloc(page_count, sizeof(struct page *), GFP_KERNEL);
615 	if (!pages)
616 		return -ENOMEM;
617 
618 	for (completed = 0; completed < page_count; completed += ret) {
619 		unsigned long user_addr = (user_pfn + completed) * PAGE_SIZE;
620 		int remaining = page_count - completed;
621 
622 		ret = pin_user_pages_fast(user_addr, remaining, FOLL_WRITE,
623 					  &pages[completed]);
624 		if (ret < 0) {
625 			vp_err(vp, "%s: Failed to pin user pages error %i\n",
626 			       __func__, ret);
627 			goto unpin_pages;
628 		}
629 	}
630 
631 	if (is_set)
632 		ret = hv_call_set_vp_state(vp->vp_index,
633 					   vp->vp_partition->pt_id,
634 					   state_data, page_count, pages,
635 					   0, NULL);
636 	else
637 		ret = hv_call_get_vp_state(vp->vp_index,
638 					   vp->vp_partition->pt_id,
639 					   state_data, page_count, pages,
640 					   NULL);
641 
642 unpin_pages:
643 	unpin_user_pages(pages, completed);
644 	kfree(pages);
645 	return ret;
646 }
647 
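/*
 * Get or set one piece of VP state on behalf of the VMM. State types flagged
 * HV_GET_SET_VP_STATE_TYPE_PFN are transferred through pinned user pages in
 * mshv_vp_ioctl_get_set_state_pfn(); the remaining types are bounced through
 * the on-stack hv_output_get_vp_state union.
 */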
648 static long
649 mshv_vp_ioctl_get_set_state(struct mshv_vp *vp,
650 			    struct mshv_get_set_vp_state __user *user_args,
651 			    bool is_set)
652 {
653 	struct mshv_get_set_vp_state args;
654 	long ret = 0;
655 	union hv_output_get_vp_state vp_state;
656 	u32 data_sz;
657 	struct hv_vp_state_data state_data = {};
658 
659 	if (copy_from_user(&args, user_args, sizeof(args)))
660 		return -EFAULT;
661 
662 	if (args.type >= MSHV_VP_STATE_COUNT || mshv_field_nonzero(args, rsvd) ||
663 	    !args.buf_sz || !PAGE_ALIGNED(args.buf_sz) ||
664 	    !PAGE_ALIGNED(args.buf_ptr))
665 		return -EINVAL;
666 
667 	if (!access_ok((void __user *)args.buf_ptr, args.buf_sz))
668 		return -EFAULT;
669 
670 	switch (args.type) {
671 	case MSHV_VP_STATE_LAPIC:
672 		state_data.type = HV_GET_SET_VP_STATE_LAPIC_STATE;
673 		data_sz = HV_HYP_PAGE_SIZE;
674 		break;
675 	case MSHV_VP_STATE_XSAVE:
676 	{
677 		u64 data_sz_64;
678 
679 		ret = hv_call_get_partition_property(vp->vp_partition->pt_id,
680 						     HV_PARTITION_PROPERTY_XSAVE_STATES,
681 						     &state_data.xsave.states.as_uint64);
682 		if (ret)
683 			return ret;
684 
685 		ret = hv_call_get_partition_property(vp->vp_partition->pt_id,
686 						     HV_PARTITION_PROPERTY_MAX_XSAVE_DATA_SIZE,
687 						     &data_sz_64);
688 		if (ret)
689 			return ret;
690 
691 		data_sz = (u32)data_sz_64;
692 		state_data.xsave.flags = 0;
693 		/* Always request legacy states */
694 		state_data.xsave.states.legacy_x87 = 1;
695 		state_data.xsave.states.legacy_sse = 1;
696 		state_data.type = HV_GET_SET_VP_STATE_XSAVE;
697 		break;
698 	}
699 	case MSHV_VP_STATE_SIMP:
700 		state_data.type = HV_GET_SET_VP_STATE_SIM_PAGE;
701 		data_sz = HV_HYP_PAGE_SIZE;
702 		break;
703 	case MSHV_VP_STATE_SIEFP:
704 		state_data.type = HV_GET_SET_VP_STATE_SIEF_PAGE;
705 		data_sz = HV_HYP_PAGE_SIZE;
706 		break;
707 	case MSHV_VP_STATE_SYNTHETIC_TIMERS:
708 		state_data.type = HV_GET_SET_VP_STATE_SYNTHETIC_TIMERS;
709 		data_sz = sizeof(vp_state.synthetic_timers_state);
710 		break;
711 	default:
712 		return -EINVAL;
713 	}
714 
715 	if (copy_to_user(&user_args->buf_sz, &data_sz, sizeof(user_args->buf_sz)))
716 		return -EFAULT;
717 
718 	if (data_sz > args.buf_sz)
719 		return -EINVAL;
720 
721 	/* If the data is transmitted via pfns, delegate to helper */
722 	if (state_data.type & HV_GET_SET_VP_STATE_TYPE_PFN) {
723 		unsigned long user_pfn = PFN_DOWN(args.buf_ptr);
724 		size_t page_count = PFN_DOWN(args.buf_sz);
725 
726 		return mshv_vp_ioctl_get_set_state_pfn(vp, state_data, user_pfn,
727 						       page_count, is_set);
728 	}
729 
730 	/* Paranoia check - this shouldn't happen! */
731 	if (data_sz > sizeof(vp_state)) {
732 		vp_err(vp, "Invalid vp state data size!\n");
733 		return -EINVAL;
734 	}
735 
736 	if (is_set) {
737 		if (copy_from_user(&vp_state, (void __user *)args.buf_ptr, data_sz))
738 			return -EFAULT;
739 
740 		return hv_call_set_vp_state(vp->vp_index,
741 					    vp->vp_partition->pt_id,
742 					    state_data, 0, NULL,
743 					    sizeof(vp_state), (u8 *)&vp_state);
744 	}
745 
746 	ret = hv_call_get_vp_state(vp->vp_index, vp->vp_partition->pt_id,
747 				   state_data, 0, NULL, &vp_state);
748 	if (ret)
749 		return ret;
750 
751 	if (copy_to_user((void __user *)args.buf_ptr, &vp_state, data_sz))
752 		return -EFAULT;
753 
754 	return 0;
755 }
756 
757 static long
758 mshv_vp_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg)
759 {
760 	struct mshv_vp *vp = filp->private_data;
761 	long r = -ENOTTY;
762 
763 	if (mutex_lock_killable(&vp->vp_mutex))
764 		return -EINTR;
765 
766 	switch (ioctl) {
767 	case MSHV_RUN_VP:
768 		r = mshv_vp_ioctl_run_vp(vp, (void __user *)arg);
769 		break;
770 	case MSHV_GET_VP_STATE:
771 		r = mshv_vp_ioctl_get_set_state(vp, (void __user *)arg, false);
772 		break;
773 	case MSHV_SET_VP_STATE:
774 		r = mshv_vp_ioctl_get_set_state(vp, (void __user *)arg, true);
775 		break;
776 	case MSHV_ROOT_HVCALL:
777 		r = mshv_ioctl_passthru_hvcall(vp->vp_partition, false,
778 					       (void __user *)arg);
779 		break;
780 	default:
781 		vp_warn(vp, "Invalid ioctl: %#x\n", ioctl);
782 		break;
783 	}
784 	mutex_unlock(&vp->vp_mutex);
785 
786 	return r;
787 }
788 
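/*
 * mmap() on a VP fd exposes the shared VP pages to the VMM; vm_pgoff selects
 * which one (registers, intercept message or GHCB) and the fault handler
 * hands back the corresponding kernel page.
 */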
789 static vm_fault_t mshv_vp_fault(struct vm_fault *vmf)
790 {
791 	struct mshv_vp *vp = vmf->vma->vm_file->private_data;
792 
793 	switch (vmf->vma->vm_pgoff) {
794 	case MSHV_VP_MMAP_OFFSET_REGISTERS:
795 		vmf->page = virt_to_page(vp->vp_register_page);
796 		break;
797 	case MSHV_VP_MMAP_OFFSET_INTERCEPT_MESSAGE:
798 		vmf->page = virt_to_page(vp->vp_intercept_msg_page);
799 		break;
800 	case MSHV_VP_MMAP_OFFSET_GHCB:
801 		vmf->page = virt_to_page(vp->vp_ghcb_page);
802 		break;
803 	default:
804 		return VM_FAULT_SIGBUS;
805 	}
806 
807 	get_page(vmf->page);
808 
809 	return 0;
810 }
811 
812 static int mshv_vp_mmap(struct file *file, struct vm_area_struct *vma)
813 {
814 	struct mshv_vp *vp = file->private_data;
815 
816 	switch (vma->vm_pgoff) {
817 	case MSHV_VP_MMAP_OFFSET_REGISTERS:
818 		if (!vp->vp_register_page)
819 			return -ENODEV;
820 		break;
821 	case MSHV_VP_MMAP_OFFSET_INTERCEPT_MESSAGE:
822 		if (!vp->vp_intercept_msg_page)
823 			return -ENODEV;
824 		break;
825 	case MSHV_VP_MMAP_OFFSET_GHCB:
826 		if (!vp->vp_ghcb_page)
827 			return -ENODEV;
828 		break;
829 	default:
830 		return -EINVAL;
831 	}
832 
833 	vma->vm_ops = &mshv_vp_vm_ops;
834 	return 0;
835 }
836 
837 static int
838 mshv_vp_release(struct inode *inode, struct file *filp)
839 {
840 	struct mshv_vp *vp = filp->private_data;
841 
842 	/* Rest of VP cleanup happens in destroy_partition() */
843 	mshv_partition_put(vp->vp_partition);
844 	return 0;
845 }
846 
847 static void mshv_vp_stats_unmap(u64 partition_id, u32 vp_index)
848 {
849 	union hv_stats_object_identity identity = {
850 		.vp.partition_id = partition_id,
851 		.vp.vp_index = vp_index,
852 	};
853 
854 	identity.vp.stats_area_type = HV_STATS_AREA_SELF;
855 	hv_call_unmap_stat_page(HV_STATS_OBJECT_VP, &identity);
856 
857 	identity.vp.stats_area_type = HV_STATS_AREA_PARENT;
858 	hv_call_unmap_stat_page(HV_STATS_OBJECT_VP, &identity);
859 }
860 
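/*
 * Map both the self and the parent VP stats areas. The run loop consults them
 * (see mshv_vp_dispatch_thread_blocked()) to decide whether the VP can be
 * dispatched again.
 */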
861 static int mshv_vp_stats_map(u64 partition_id, u32 vp_index,
862 			     void *stats_pages[])
863 {
864 	union hv_stats_object_identity identity = {
865 		.vp.partition_id = partition_id,
866 		.vp.vp_index = vp_index,
867 	};
868 	int err;
869 
870 	identity.vp.stats_area_type = HV_STATS_AREA_SELF;
871 	err = hv_call_map_stat_page(HV_STATS_OBJECT_VP, &identity,
872 				    &stats_pages[HV_STATS_AREA_SELF]);
873 	if (err)
874 		return err;
875 
876 	identity.vp.stats_area_type = HV_STATS_AREA_PARENT;
877 	err = hv_call_map_stat_page(HV_STATS_OBJECT_VP, &identity,
878 				    &stats_pages[HV_STATS_AREA_PARENT]);
879 	if (err)
880 		goto unmap_self;
881 
882 	return 0;
883 
884 unmap_self:
885 	identity.vp.stats_area_type = HV_STATS_AREA_SELF;
886 	hv_call_unmap_stat_page(HV_STATS_OBJECT_VP, &identity);
887 	return err;
888 }
889 
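/*
 * MSHV_CREATE_VP: create the VP in the hypervisor, map its shared state pages
 * (intercept message, registers for non-encrypted partitions, GHCB for
 * encrypted ones) and its stats pages, then expose it to the VMM as an
 * anonymous "mshv_vp" fd. On failure the steps are unwound in reverse order.
 */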
890 static long
891 mshv_partition_ioctl_create_vp(struct mshv_partition *partition,
892 			       void __user *arg)
893 {
894 	struct mshv_create_vp args;
895 	struct mshv_vp *vp;
896 	struct page *intercept_message_page, *register_page, *ghcb_page;
897 	void *stats_pages[2];
898 	long ret;
899 
900 	if (copy_from_user(&args, arg, sizeof(args)))
901 		return -EFAULT;
902 
903 	if (args.vp_index >= MSHV_MAX_VPS)
904 		return -EINVAL;
905 
906 	if (partition->pt_vp_array[args.vp_index])
907 		return -EEXIST;
908 
909 	ret = hv_call_create_vp(NUMA_NO_NODE, partition->pt_id, args.vp_index,
910 				0 /* Only valid for root partition VPs */);
911 	if (ret)
912 		return ret;
913 
914 	ret = hv_call_map_vp_state_page(partition->pt_id, args.vp_index,
915 					HV_VP_STATE_PAGE_INTERCEPT_MESSAGE,
916 					input_vtl_zero,
917 					&intercept_message_page);
918 	if (ret)
919 		goto destroy_vp;
920 
921 	if (!mshv_partition_encrypted(partition)) {
922 		ret = hv_call_map_vp_state_page(partition->pt_id, args.vp_index,
923 						HV_VP_STATE_PAGE_REGISTERS,
924 						input_vtl_zero,
925 						&register_page);
926 		if (ret)
927 			goto unmap_intercept_message_page;
928 	}
929 
930 	if (mshv_partition_encrypted(partition) &&
931 	    is_ghcb_mapping_available()) {
932 		ret = hv_call_map_vp_state_page(partition->pt_id, args.vp_index,
933 						HV_VP_STATE_PAGE_GHCB,
934 						input_vtl_normal,
935 						&ghcb_page);
936 		if (ret)
937 			goto unmap_register_page;
938 	}
939 
940 	if (hv_parent_partition()) {
941 		ret = mshv_vp_stats_map(partition->pt_id, args.vp_index,
942 					stats_pages);
943 		if (ret)
944 			goto unmap_ghcb_page;
945 	}
946 
947 	vp = kzalloc(sizeof(*vp), GFP_KERNEL);
948 	if (!vp) {
949 		ret = -ENOMEM;
		goto unmap_stats_pages;
	}
950 
951 	vp->vp_partition = mshv_partition_get(partition);
952 	if (!vp->vp_partition) {
953 		ret = -EBADF;
954 		goto free_vp;
955 	}
956 
957 	mutex_init(&vp->vp_mutex);
958 	init_waitqueue_head(&vp->run.vp_suspend_queue);
959 	atomic64_set(&vp->run.vp_signaled_count, 0);
960 
961 	vp->vp_index = args.vp_index;
962 	vp->vp_intercept_msg_page = page_to_virt(intercept_message_page);
963 	if (!mshv_partition_encrypted(partition))
964 		vp->vp_register_page = page_to_virt(register_page);
965 
966 	if (mshv_partition_encrypted(partition) && is_ghcb_mapping_available())
967 		vp->vp_ghcb_page = page_to_virt(ghcb_page);
968 
969 	if (hv_parent_partition())
970 		memcpy(vp->vp_stats_pages, stats_pages, sizeof(stats_pages));
971 
972 	/*
973 	 * Keep anon_inode_getfd last: it installs the fd in the caller's fd
974 	 * table and thus makes the vp state reachable from user space.
975 	 */
976 	ret = anon_inode_getfd("mshv_vp", &mshv_vp_fops, vp,
977 			       O_RDWR | O_CLOEXEC);
978 	if (ret < 0)
979 		goto put_partition;
980 
981 	/* already exclusive with the partition mutex for all ioctls */
982 	partition->pt_vp_count++;
983 	partition->pt_vp_array[args.vp_index] = vp;
984 
985 	return ret;
986 
987 put_partition:
988 	mshv_partition_put(partition);
989 free_vp:
990 	kfree(vp);
991 unmap_stats_pages:
992 	if (hv_parent_partition())
993 		mshv_vp_stats_unmap(partition->pt_id, args.vp_index);
994 unmap_ghcb_page:
995 	if (mshv_partition_encrypted(partition) && is_ghcb_mapping_available()) {
996 		hv_call_unmap_vp_state_page(partition->pt_id, args.vp_index,
997 					    HV_VP_STATE_PAGE_GHCB,
998 					    input_vtl_normal);
999 	}
1000 unmap_register_page:
1001 	if (!mshv_partition_encrypted(partition)) {
1002 		hv_call_unmap_vp_state_page(partition->pt_id, args.vp_index,
1003 					    HV_VP_STATE_PAGE_REGISTERS,
1004 					    input_vtl_zero);
1005 	}
1006 unmap_intercept_message_page:
1007 	hv_call_unmap_vp_state_page(partition->pt_id, args.vp_index,
1008 				    HV_VP_STATE_PAGE_INTERCEPT_MESSAGE,
1009 				    input_vtl_zero);
1010 destroy_vp:
1011 	hv_call_delete_vp(partition->pt_id, args.vp_index);
1012 	return ret;
1013 }
1014 
1015 static int mshv_init_async_handler(struct mshv_partition *partition)
1016 {
1017 	if (completion_done(&partition->async_hypercall)) {
1018 		pt_err(partition,
1019 		       "Cannot issue async hypercall while another one in progress!\n");
1020 		return -EPERM;
1021 	}
1022 
1023 	reinit_completion(&partition->async_hypercall);
1024 	return 0;
1025 }
1026 
1027 static void mshv_async_hvcall_handler(void *data, u64 *status)
1028 {
1029 	struct mshv_partition *partition = data;
1030 
1031 	wait_for_completion(&partition->async_hypercall);
1032 	pt_dbg(partition, "Async hypercall completed!\n");
1033 
1034 	*status = partition->async_hypercall_status;
1035 }
1036 
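/*
 * For encrypted (SNP) partitions the hypervisor tracks whether the host may
 * access guest memory. Sharing a region restores host access; unsharing makes
 * it exclusive to the guest. Both are an additional SLAT update on top of the
 * regular GPA mapping.
 */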
1037 static int
1038 mshv_partition_region_share(struct mshv_mem_region *region)
1039 {
1040 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
1041 
1042 	if (region->flags.large_pages)
1043 		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
1044 
1045 	return hv_call_modify_spa_host_access(region->partition->pt_id,
1046 			region->pages, region->nr_pages,
1047 			HV_MAP_GPA_READABLE | HV_MAP_GPA_WRITABLE,
1048 			flags, true);
1049 }
1050 
1051 static int
1052 mshv_partition_region_unshare(struct mshv_mem_region *region)
1053 {
1054 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
1055 
1056 	if (region->flags.large_pages)
1057 		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
1058 
1059 	return hv_call_modify_spa_host_access(region->partition->pt_id,
1060 			region->pages, region->nr_pages,
1061 			0,
1062 			flags, false);
1063 }
1064 
1065 static int
1066 mshv_region_remap_pages(struct mshv_mem_region *region, u32 map_flags,
1067 			u64 page_offset, u64 page_count)
1068 {
1069 	if (page_offset + page_count > region->nr_pages)
1070 		return -EINVAL;
1071 
1072 	if (region->flags.large_pages)
1073 		map_flags |= HV_MAP_GPA_LARGE_PAGE;
1074 
1075 	/* ask the hypervisor to map guest ram */
1076 	return hv_call_map_gpa_pages(region->partition->pt_id,
1077 				     region->start_gfn + page_offset,
1078 				     page_count, map_flags,
1079 				     region->pages + page_offset);
1080 }
1081 
1082 static int
1083 mshv_region_map(struct mshv_mem_region *region)
1084 {
1085 	u32 map_flags = region->hv_map_flags;
1086 
1087 	return mshv_region_remap_pages(region, map_flags,
1088 				       0, region->nr_pages);
1089 }
1090 
1091 static void
1092 mshv_region_evict_pages(struct mshv_mem_region *region,
1093 			u64 page_offset, u64 page_count)
1094 {
1095 	if (region->flags.range_pinned)
1096 		unpin_user_pages(region->pages + page_offset, page_count);
1097 
1098 	memset(region->pages + page_offset, 0,
1099 	       page_count * sizeof(struct page *));
1100 }
1101 
1102 static void
1103 mshv_region_evict(struct mshv_mem_region *region)
1104 {
1105 	mshv_region_evict_pages(region, 0, region->nr_pages);
1106 }
1107 
1108 static int
1109 mshv_region_populate_pages(struct mshv_mem_region *region,
1110 			   u64 page_offset, u64 page_count)
1111 {
1112 	u64 done_count, nr_pages;
1113 	struct page **pages;
1114 	__u64 userspace_addr;
1115 	int ret;
1116 
1117 	if (page_offset + page_count > region->nr_pages)
1118 		return -EINVAL;
1119 
1120 	for (done_count = 0; done_count < page_count; done_count += ret) {
1121 		pages = region->pages + page_offset + done_count;
1122 		userspace_addr = region->start_uaddr +
1123 				(page_offset + done_count) *
1124 				HV_HYP_PAGE_SIZE;
1125 		nr_pages = min(page_count - done_count,
1126 			       MSHV_PIN_PAGES_BATCH_SIZE);
1127 
1128 		/*
1129 		 * Pinning assuming 4k pages works for large pages too.
1130 		 * All page structs within the large page are returned.
1131 		 *
1132 		 * Pin requests are batched because pin_user_pages_fast
1133 		 * with the FOLL_LONGTERM flag does a large temporary
1134 		 * allocation of contiguous memory.
1135 		 */
1136 		if (region->flags.range_pinned)
1137 			ret = pin_user_pages_fast(userspace_addr,
1138 						  nr_pages,
1139 						  FOLL_WRITE | FOLL_LONGTERM,
1140 						  pages);
1141 		else
1142 			ret = -EOPNOTSUPP;
1143 
1144 		if (ret < 0)
1145 			goto release_pages;
1146 	}
1147 
1148 	if (PageHuge(region->pages[page_offset]))
1149 		region->flags.large_pages = true;
1150 
1151 	return 0;
1152 
1153 release_pages:
1154 	mshv_region_evict_pages(region, page_offset, done_count);
1155 	return ret;
1156 }
1157 
1158 static int
1159 mshv_region_populate(struct mshv_mem_region *region)
1160 {
1161 	return mshv_region_populate_pages(region, 0, region->nr_pages);
1162 }
1163 
1164 static struct mshv_mem_region *
1165 mshv_partition_region_by_gfn(struct mshv_partition *partition, u64 gfn)
1166 {
1167 	struct mshv_mem_region *region;
1168 
1169 	hlist_for_each_entry(region, &partition->pt_mem_regions, hnode) {
1170 		if (gfn >= region->start_gfn &&
1171 		    gfn < region->start_gfn + region->nr_pages)
1172 			return region;
1173 	}
1174 
1175 	return NULL;
1176 }
1177 
1178 static struct mshv_mem_region *
1179 mshv_partition_region_by_uaddr(struct mshv_partition *partition, u64 uaddr)
1180 {
1181 	struct mshv_mem_region *region;
1182 
1183 	hlist_for_each_entry(region, &partition->pt_mem_regions, hnode) {
1184 		if (uaddr >= region->start_uaddr &&
1185 		    uaddr < region->start_uaddr +
1186 			    (region->nr_pages << HV_HYP_PAGE_SHIFT))
1187 			return region;
1188 	}
1189 
1190 	return NULL;
1191 }
1192 
1193 /*
1194  * NB: caller checks and makes sure mem->size is page aligned
1195  * Returns: 0 with regionpp updated on success, or -errno
1196  */
1197 static int mshv_partition_create_region(struct mshv_partition *partition,
1198 					struct mshv_user_mem_region *mem,
1199 					struct mshv_mem_region **regionpp,
1200 					bool is_mmio)
1201 {
1202 	struct mshv_mem_region *region;
1203 	u64 nr_pages = HVPFN_DOWN(mem->size);
1204 
1205 	/* Reject overlapping regions */
1206 	if (mshv_partition_region_by_gfn(partition, mem->guest_pfn) ||
1207 	    mshv_partition_region_by_gfn(partition, mem->guest_pfn + nr_pages - 1) ||
1208 	    mshv_partition_region_by_uaddr(partition, mem->userspace_addr) ||
1209 	    mshv_partition_region_by_uaddr(partition, mem->userspace_addr + mem->size - 1))
1210 		return -EEXIST;
1211 
1212 	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
1213 	if (!region)
1214 		return -ENOMEM;
1215 
1216 	region->nr_pages = nr_pages;
1217 	region->start_gfn = mem->guest_pfn;
1218 	region->start_uaddr = mem->userspace_addr;
1219 	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
1220 	if (mem->flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
1221 		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
1222 	if (mem->flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
1223 		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;
1224 
1225 	/* Note: large_pages flag populated when we pin the pages */
1226 	if (!is_mmio)
1227 		region->flags.range_pinned = true;
1228 
1229 	region->partition = partition;
1230 
1231 	*regionpp = region;
1232 
1233 	return 0;
1234 }
1235 
1236 /*
1237  * Map guest RAM. For SNP, make sure host access to the region is released first.
1238  * Side Effects: In case of failure, pages are unpinned when feasible.
1239  */
1240 static int
1241 mshv_partition_mem_region_map(struct mshv_mem_region *region)
1242 {
1243 	struct mshv_partition *partition = region->partition;
1244 	int ret;
1245 
1246 	ret = mshv_region_populate(region);
1247 	if (ret) {
1248 		pt_err(partition, "Failed to populate memory region: %d\n",
1249 		       ret);
1250 		goto err_out;
1251 	}
1252 
1253 	/*
1254 	 * For an SNP partition it is a requirement that for every memory region
1255 	 * that we are going to map for this partition we should make sure that
1256 	 * host access to that region is released. This is ensured by doing an
1257 	 * additional hypercall which will update the SLAT to release host
1258 	 * access to guest memory regions.
1259 	 */
1260 	if (mshv_partition_encrypted(partition)) {
1261 		ret = mshv_partition_region_unshare(region);
1262 		if (ret) {
1263 			pt_err(partition,
1264 			       "Failed to unshare memory region (guest_pfn: %llu): %d\n",
1265 			       region->start_gfn, ret);
1266 			goto evict_region;
1267 		}
1268 	}
1269 
1270 	ret = mshv_region_map(region);
1271 	if (ret && mshv_partition_encrypted(partition)) {
1272 		int shrc;
1273 
1274 		shrc = mshv_partition_region_share(region);
1275 		if (!shrc)
1276 			goto evict_region;
1277 
1278 		pt_err(partition,
1279 		       "Failed to share memory region (guest_pfn: %llu): %d\n",
1280 		       region->start_gfn, shrc);
1281 		/*
1282 		 * Don't unpin if marking shared failed because pages are no
1283 		 * longer mapped in the host, ie root, anymore.
1284 		 */
1285 		goto err_out;
1286 	}
1287 
1288 	return 0;
1289 
1290 evict_region:
1291 	mshv_region_evict(region);
1292 err_out:
1293 	return ret;
1294 }
1295 
1296 /*
1297  * This maps two things: guest RAM and, for PCI passthrough, MMIO space.
1298  *
1299  * mmio:
1300  *  - vfio overloads vm_pgoff to store the mmio start pfn/spa.
1301  *  - Two things need to happen for mapping mmio range:
1302  *	1. mapped in the uaddr so VMM can access it.
1303  *	2. mapped in the hwpt (gfn <-> mmio phys addr) so guest can access it.
1304  *
1305  *   This function takes care of the second. The first one is managed by vfio,
1306  *   and hence is taken care of via vfio_pci_mmap_fault().
1307  */
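/*
 * Illustrative VMM-side sketch (not part of this file; field names follow
 * struct mshv_user_mem_region as used here, see the uapi header for the
 * exact layout):
 *
 *	struct mshv_user_mem_region mem = {
 *		.userspace_addr = (__u64)buf,		// page aligned
 *		.size           = bytes,		// page aligned
 *		.guest_pfn      = gpa >> HV_HYP_PAGE_SHIFT,
 *		.flags          = BIT(MSHV_SET_MEM_BIT_WRITABLE) |
 *				  BIT(MSHV_SET_MEM_BIT_EXECUTABLE),
 *	};
 *	ioctl(partition_fd, MSHV_SET_GUEST_MEMORY, &mem);
 */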
1308 static long
1309 mshv_map_user_memory(struct mshv_partition *partition,
1310 		     struct mshv_user_mem_region mem)
1311 {
1312 	struct mshv_mem_region *region;
1313 	struct vm_area_struct *vma;
1314 	bool is_mmio;
1315 	ulong mmio_pfn;
1316 	long ret;
1317 
1318 	if (mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP) ||
1319 	    !access_ok((const void __user *)mem.userspace_addr, mem.size))
1320 		return -EINVAL;
1321 
1322 	mmap_read_lock(current->mm);
1323 	vma = vma_lookup(current->mm, mem.userspace_addr);
1324 	is_mmio = vma ? !!(vma->vm_flags & (VM_IO | VM_PFNMAP)) : 0;
1325 	mmio_pfn = is_mmio ? vma->vm_pgoff : 0;
1326 	mmap_read_unlock(current->mm);
1327 
1328 	if (!vma)
1329 		return -EINVAL;
1330 
1331 	ret = mshv_partition_create_region(partition, &mem, &region,
1332 					   is_mmio);
1333 	if (ret)
1334 		return ret;
1335 
1336 	if (is_mmio)
1337 		ret = hv_call_map_mmio_pages(partition->pt_id, mem.guest_pfn,
1338 					     mmio_pfn, HVPFN_DOWN(mem.size));
1339 	else
1340 		ret = mshv_partition_mem_region_map(region);
1341 
1342 	if (ret)
1343 		goto errout;
1344 
1345 	/* Install the new region */
1346 	hlist_add_head(&region->hnode, &partition->pt_mem_regions);
1347 
1348 	return 0;
1349 
1350 errout:
1351 	vfree(region);
1352 	return ret;
1353 }
1354 
1355 /* Called for unmapping both the guest ram and the mmio space */
1356 static long
1357 mshv_unmap_user_memory(struct mshv_partition *partition,
1358 		       struct mshv_user_mem_region mem)
1359 {
1360 	struct mshv_mem_region *region;
1361 	u32 unmap_flags = 0;
1362 
1363 	if (!(mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP)))
1364 		return -EINVAL;
1365 
1366 	region = mshv_partition_region_by_gfn(partition, mem.guest_pfn);
1367 	if (!region)
1368 		return -EINVAL;
1369 
1370 	/* Paranoia check */
1371 	if (region->start_uaddr != mem.userspace_addr ||
1372 	    region->start_gfn != mem.guest_pfn ||
1373 	    region->nr_pages != HVPFN_DOWN(mem.size))
1374 		return -EINVAL;
1375 
1376 	hlist_del(&region->hnode);
1377 
1378 	if (region->flags.large_pages)
1379 		unmap_flags |= HV_UNMAP_GPA_LARGE_PAGE;
1380 
1381 	/* ignore unmap failures and continue as process may be exiting */
1382 	hv_call_unmap_gpa_pages(partition->pt_id, region->start_gfn,
1383 				region->nr_pages, unmap_flags);
1384 
1385 	mshv_region_evict(region);
1386 
1387 	vfree(region);
1388 	return 0;
1389 }
1390 
1391 static long
1392 mshv_partition_ioctl_set_memory(struct mshv_partition *partition,
1393 				struct mshv_user_mem_region __user *user_mem)
1394 {
1395 	struct mshv_user_mem_region mem;
1396 
1397 	if (copy_from_user(&mem, user_mem, sizeof(mem)))
1398 		return -EFAULT;
1399 
1400 	if (!mem.size ||
1401 	    !PAGE_ALIGNED(mem.size) ||
1402 	    !PAGE_ALIGNED(mem.userspace_addr) ||
1403 	    (mem.flags & ~MSHV_SET_MEM_FLAGS_MASK) ||
1404 	    mshv_field_nonzero(mem, rsvd))
1405 		return -EINVAL;
1406 
1407 	if (mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP))
1408 		return mshv_unmap_user_memory(partition, mem);
1409 
1410 	return mshv_map_user_memory(partition, mem);
1411 }
1412 
1413 static long
1414 mshv_partition_ioctl_ioeventfd(struct mshv_partition *partition,
1415 			       void __user *user_args)
1416 {
1417 	struct mshv_user_ioeventfd args;
1418 
1419 	if (copy_from_user(&args, user_args, sizeof(args)))
1420 		return -EFAULT;
1421 
1422 	return mshv_set_unset_ioeventfd(partition, &args);
1423 }
1424 
1425 static long
1426 mshv_partition_ioctl_irqfd(struct mshv_partition *partition,
1427 			   void __user *user_args)
1428 {
1429 	struct mshv_user_irqfd args;
1430 
1431 	if (copy_from_user(&args, user_args, sizeof(args)))
1432 		return -EFAULT;
1433 
1434 	return mshv_set_unset_irqfd(partition, &args);
1435 }
1436 
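/*
 * MSHV_GET_GPAP_ACCESS_BITMAP: query (and optionally set or clear) the
 * accessed or dirty state of a range of guest physical pages, returned to the
 * VMM as a bitmap with one bit per page, e.g. for dirty tracking during live
 * migration.
 */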
1437 static long
1438 mshv_partition_ioctl_get_gpap_access_bitmap(struct mshv_partition *partition,
1439 					    void __user *user_args)
1440 {
1441 	struct mshv_gpap_access_bitmap args;
1442 	union hv_gpa_page_access_state *states;
1443 	long ret, i;
1444 	union hv_gpa_page_access_state_flags hv_flags = {};
1445 	u8 hv_type_mask;
1446 	ulong bitmap_buf_sz, states_buf_sz;
1447 	int written = 0;
1448 
1449 	if (copy_from_user(&args, user_args, sizeof(args)))
1450 		return -EFAULT;
1451 
1452 	if (args.access_type >= MSHV_GPAP_ACCESS_TYPE_COUNT ||
1453 	    args.access_op >= MSHV_GPAP_ACCESS_OP_COUNT ||
1454 	    mshv_field_nonzero(args, rsvd) || !args.page_count ||
1455 	    !args.bitmap_ptr)
1456 		return -EINVAL;
1457 
1458 	if (check_mul_overflow(args.page_count, sizeof(*states), &states_buf_sz))
1459 		return -E2BIG;
1460 
1461 	/* Num bytes needed to store bitmap; one bit per page rounded up */
1462 	bitmap_buf_sz = DIV_ROUND_UP(args.page_count, 8);
1463 
1464 	/* Sanity check */
1465 	if (bitmap_buf_sz > states_buf_sz)
1466 		return -EBADFD;
1467 
1468 	switch (args.access_type) {
1469 	case MSHV_GPAP_ACCESS_TYPE_ACCESSED:
1470 		hv_type_mask = 1;
1471 		if (args.access_op == MSHV_GPAP_ACCESS_OP_CLEAR) {
1472 			hv_flags.clear_accessed = 1;
1473 			/* not accessed implies not dirty */
1474 			hv_flags.clear_dirty = 1;
1475 		} else { /* MSHV_GPAP_ACCESS_OP_SET */
1476 			hv_flags.set_accessed = 1;
1477 		}
1478 		break;
1479 	case MSHV_GPAP_ACCESS_TYPE_DIRTY:
1480 		hv_type_mask = 2;
1481 		if (args.access_op == MSHV_GPAP_ACCESS_OP_CLEAR) {
1482 			hv_flags.clear_dirty = 1;
1483 		} else { /* MSHV_GPAP_ACCESS_OP_SET */
1484 			hv_flags.set_dirty = 1;
1485 			/* dirty implies accessed */
1486 			hv_flags.set_accessed = 1;
1487 		}
1488 		break;
1489 	}
1490 
1491 	states = vzalloc(states_buf_sz);
1492 	if (!states)
1493 		return -ENOMEM;
1494 
1495 	ret = hv_call_get_gpa_access_states(partition->pt_id, args.page_count,
1496 					    args.gpap_base, hv_flags, &written,
1497 					    states);
1498 	if (ret)
1499 		goto free_return;
1500 
1501 	/*
1502 	 * Overwrite states buffer with bitmap - the bits in hv_type_mask
1503 	 * correspond to bitfields in hv_gpa_page_access_state
1504 	 */
1505 	for (i = 0; i < written; ++i)
1506 		__assign_bit(i, (ulong *)states,
1507 			     states[i].as_uint8 & hv_type_mask);
1508 
1509 	/* zero the unused bits in the last byte(s) of the returned bitmap */
1510 	for (i = written; i < bitmap_buf_sz * 8; ++i)
1511 		__clear_bit(i, (ulong *)states);
1512 
1513 	if (copy_to_user((void __user *)args.bitmap_ptr, states, bitmap_buf_sz))
1514 		ret = -EFAULT;
1515 
1516 free_return:
1517 	vfree(states);
1518 	return ret;
1519 }
1520 
1521 static long
1522 mshv_partition_ioctl_set_msi_routing(struct mshv_partition *partition,
1523 				     void __user *user_args)
1524 {
1525 	struct mshv_user_irq_entry *entries = NULL;
1526 	struct mshv_user_irq_table args;
1527 	long ret;
1528 
1529 	if (copy_from_user(&args, user_args, sizeof(args)))
1530 		return -EFAULT;
1531 
1532 	if (args.nr > MSHV_MAX_GUEST_IRQS ||
1533 	    mshv_field_nonzero(args, rsvd))
1534 		return -EINVAL;
1535 
1536 	if (args.nr) {
1537 		struct mshv_user_irq_table __user *urouting = user_args;
1538 
1539 		entries = vmemdup_user(urouting->entries,
1540 				       array_size(sizeof(*entries),
1541 						  args.nr));
1542 		if (IS_ERR(entries))
1543 			return PTR_ERR(entries);
1544 	}
1545 	ret = mshv_update_routing_table(partition, entries, args.nr);
1546 	kvfree(entries);
1547 
1548 	return ret;
1549 }
1550 
1551 static long
1552 mshv_partition_ioctl_initialize(struct mshv_partition *partition)
1553 {
1554 	long ret;
1555 
1556 	if (partition->pt_initialized)
1557 		return 0;
1558 
1559 	ret = hv_call_initialize_partition(partition->pt_id);
1560 	if (ret)
1561 		goto withdraw_mem;
1562 
1563 	partition->pt_initialized = true;
1564 
1565 	return 0;
1566 
1567 withdraw_mem:
1568 	hv_call_withdraw_memory(U64_MAX, NUMA_NO_NODE, partition->pt_id);
1569 
1570 	return ret;
1571 }
1572 
1573 static long
1574 mshv_partition_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg)
1575 {
1576 	struct mshv_partition *partition = filp->private_data;
1577 	long ret;
1578 	void __user *uarg = (void __user *)arg;
1579 
1580 	if (mutex_lock_killable(&partition->pt_mutex))
1581 		return -EINTR;
1582 
1583 	switch (ioctl) {
1584 	case MSHV_INITIALIZE_PARTITION:
1585 		ret = mshv_partition_ioctl_initialize(partition);
1586 		break;
1587 	case MSHV_SET_GUEST_MEMORY:
1588 		ret = mshv_partition_ioctl_set_memory(partition, uarg);
1589 		break;
1590 	case MSHV_CREATE_VP:
1591 		ret = mshv_partition_ioctl_create_vp(partition, uarg);
1592 		break;
1593 	case MSHV_IRQFD:
1594 		ret = mshv_partition_ioctl_irqfd(partition, uarg);
1595 		break;
1596 	case MSHV_IOEVENTFD:
1597 		ret = mshv_partition_ioctl_ioeventfd(partition, uarg);
1598 		break;
1599 	case MSHV_SET_MSI_ROUTING:
1600 		ret = mshv_partition_ioctl_set_msi_routing(partition, uarg);
1601 		break;
1602 	case MSHV_GET_GPAP_ACCESS_BITMAP:
1603 		ret = mshv_partition_ioctl_get_gpap_access_bitmap(partition,
1604 								  uarg);
1605 		break;
1606 	case MSHV_ROOT_HVCALL:
1607 		ret = mshv_ioctl_passthru_hvcall(partition, true, uarg);
1608 		break;
1609 	default:
1610 		ret = -ENOTTY;
1611 	}
1612 
1613 	mutex_unlock(&partition->pt_mutex);
1614 	return ret;
1615 }
1616 
1617 static int
1618 disable_vp_dispatch(struct mshv_vp *vp)
1619 {
1620 	int ret;
1621 	struct hv_register_assoc dispatch_suspend = {
1622 		.name = HV_REGISTER_DISPATCH_SUSPEND,
1623 		.value.dispatch_suspend.suspended = 1,
1624 	};
1625 
1626 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
1627 				    1, &dispatch_suspend);
1628 	if (ret)
1629 		vp_err(vp, "failed to suspend\n");
1630 
1631 	return ret;
1632 }
1633 
1634 static int
1635 get_vp_signaled_count(struct mshv_vp *vp, u64 *count)
1636 {
1637 	int ret;
1638 	struct hv_register_assoc root_signal_count = {
1639 		.name = HV_REGISTER_VP_ROOT_SIGNAL_COUNT,
1640 	};
1641 
1642 	ret = mshv_get_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
1643 				    1, &root_signal_count);
1644 
1645 	if (ret) {
1646 		vp_err(vp, "Failed to get root signal count");
1647 		*count = 0;
1648 		return ret;
1649 	}
1650 
1651 	*count = root_signal_count.value.reg64;
1652 
1653 	return ret;
1654 }
1655 
1656 static void
1657 drain_vp_signals(struct mshv_vp *vp)
1658 {
1659 	u64 hv_signal_count;
1660 	u64 vp_signal_count;
1661 
1662 	get_vp_signaled_count(vp, &hv_signal_count);
1663 
1664 	vp_signal_count = atomic64_read(&vp->run.vp_signaled_count);
1665 
1666 	/*
1667 	 * There should be at most 1 outstanding notification, but be extra
1668 	 * careful anyway.
1669 	 */
1670 	while (hv_signal_count != vp_signal_count) {
1671 		WARN_ON(hv_signal_count - vp_signal_count != 1);
1672 
1673 		if (wait_event_interruptible(vp->run.vp_suspend_queue,
1674 					     vp->run.kicked_by_hv == 1))
1675 			break;
1676 		vp->run.kicked_by_hv = 0;
1677 		vp_signal_count = atomic64_read(&vp->run.vp_signaled_count);
1678 	}
1679 }
1680 
1681 static void drain_all_vps(const struct mshv_partition *partition)
1682 {
1683 	int i;
1684 	struct mshv_vp *vp;
1685 
1686 	/*
1687 	 * VPs are reachable from ISR. It is safe to not take the partition
1688 	 * lock because nobody else can enter this function and drop the
1689 	 * partition from the list.
1690 	 */
1691 	for (i = 0; i < MSHV_MAX_VPS; i++) {
1692 		vp = partition->pt_vp_array[i];
1693 		if (!vp)
1694 			continue;
1695 		/*
1696 		 * Disable dispatching of the VP in the hypervisor. After this
1697 		 * the hypervisor guarantees it won't generate any signals for
1698 		 * the VP and the hypervisor's VP signal count won't change.
1699 		 */
1700 		disable_vp_dispatch(vp);
1701 		drain_vp_signals(vp);
1702 	}
1703 }
1704 
1705 static void
1706 remove_partition(struct mshv_partition *partition)
1707 {
1708 	spin_lock(&mshv_root.pt_ht_lock);
1709 	hlist_del_rcu(&partition->pt_hnode);
1710 	spin_unlock(&mshv_root.pt_ht_lock);
1711 
1712 	synchronize_rcu();
1713 }
1714 
1715 /*
1716  * Tear down a partition and remove it from the list.
1717  * Partition's refcount must be 0
1718  */
1719 static void destroy_partition(struct mshv_partition *partition)
1720 {
1721 	struct mshv_vp *vp;
1722 	struct mshv_mem_region *region;
1723 	int i, ret;
1724 	struct hlist_node *n;
1725 
1726 	if (refcount_read(&partition->pt_ref_count)) {
1727 		pt_err(partition,
1728 		       "Attempt to destroy partition but refcount > 0\n");
1729 		return;
1730 	}
1731 
1732 	if (partition->pt_initialized) {
1733 		/*
1734 		 * We only need to drain signals for root scheduler. This should be
1735 		 * done before removing the partition from the partition list.
1736 		 */
1737 		if (hv_scheduler_type == HV_SCHEDULER_TYPE_ROOT)
1738 			drain_all_vps(partition);
1739 
1740 		/* Remove vps */
1741 		for (i = 0; i < MSHV_MAX_VPS; ++i) {
1742 			vp = partition->pt_vp_array[i];
1743 			if (!vp)
1744 				continue;
1745 
1746 			if (hv_parent_partition())
1747 				mshv_vp_stats_unmap(partition->pt_id, vp->vp_index);
1748 
1749 			if (vp->vp_register_page) {
1750 				(void)hv_call_unmap_vp_state_page(partition->pt_id,
1751 								  vp->vp_index,
1752 								  HV_VP_STATE_PAGE_REGISTERS,
1753 								  input_vtl_zero);
1754 				vp->vp_register_page = NULL;
1755 			}
1756 
1757 			(void)hv_call_unmap_vp_state_page(partition->pt_id,
1758 							  vp->vp_index,
1759 							  HV_VP_STATE_PAGE_INTERCEPT_MESSAGE,
1760 							  input_vtl_zero);
1761 			vp->vp_intercept_msg_page = NULL;
1762 
1763 			if (vp->vp_ghcb_page) {
1764 				(void)hv_call_unmap_vp_state_page(partition->pt_id,
1765 								  vp->vp_index,
1766 								  HV_VP_STATE_PAGE_GHCB,
1767 								  input_vtl_normal);
1768 				vp->vp_ghcb_page = NULL;
1769 			}
1770 
1771 			kfree(vp);
1772 
1773 			partition->pt_vp_array[i] = NULL;
1774 		}
1775 
1776 		/* Deallocates and unmaps everything including vcpus, GPA mappings etc */
1777 		hv_call_finalize_partition(partition->pt_id);
1778 
1779 		partition->pt_initialized = false;
1780 	}
1781 
1782 	remove_partition(partition);
1783 
1784 	/* Remove regions, regain access to the memory and unpin the pages */
1785 	hlist_for_each_entry_safe(region, n, &partition->pt_mem_regions,
1786 				  hnode) {
1787 		hlist_del(&region->hnode);
1788 
1789 		if (mshv_partition_encrypted(partition)) {
1790 			ret = mshv_partition_region_share(region);
1791 			if (ret) {
1792 				pt_err(partition,
1793 				       "Failed to regain access to memory, unpinning user pages will fail and crash the host, error: %d\n",
1794 				      ret);
1795 				return;
1796 			}
1797 		}
1798 
1799 		mshv_region_evict(region);
1800 
1801 		vfree(region);
1802 	}
1803 
1804 	/* Withdraw and free all pages we deposited */
1805 	hv_call_withdraw_memory(U64_MAX, NUMA_NO_NODE, partition->pt_id);
1806 	hv_call_delete_partition(partition->pt_id);
1807 
1808 	mshv_free_routing_table(partition);
1809 	kfree(partition);
1810 }
1811 
1812 struct
1813 mshv_partition *mshv_partition_get(struct mshv_partition *partition)
1814 {
1815 	if (refcount_inc_not_zero(&partition->pt_ref_count))
1816 		return partition;
1817 	return NULL;
1818 }
1819 
1820 struct
1821 mshv_partition *mshv_partition_find(u64 partition_id)
1822 	__must_hold(RCU)
1823 {
1824 	struct mshv_partition *p;
1825 
1826 	hash_for_each_possible_rcu(mshv_root.pt_htable, p, pt_hnode,
1827 				   partition_id)
1828 		if (p->pt_id == partition_id)
1829 			return p;
1830 
1831 	return NULL;
1832 }
1833 
1834 void
1835 mshv_partition_put(struct mshv_partition *partition)
1836 {
1837 	if (refcount_dec_and_test(&partition->pt_ref_count))
1838 		destroy_partition(partition);
1839 }
1840 
1841 static int
1842 mshv_partition_release(struct inode *inode, struct file *filp)
1843 {
1844 	struct mshv_partition *partition = filp->private_data;
1845 
1846 	mshv_eventfd_release(partition);
1847 
1848 	cleanup_srcu_struct(&partition->pt_irq_srcu);
1849 
1850 	mshv_partition_put(partition);
1851 
1852 	return 0;
1853 }
1854 
1855 static int
1856 add_partition(struct mshv_partition *partition)
1857 {
1858 	spin_lock(&mshv_root.pt_ht_lock);
1859 
1860 	hash_add_rcu(mshv_root.pt_htable, &partition->pt_hnode,
1861 		     partition->pt_id);
1862 
1863 	spin_unlock(&mshv_root.pt_ht_lock);
1864 
1865 	return 0;
1866 }
1867 
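/*
 * MSHV_CREATE_PARTITION: create an EXO child partition in the hypervisor, add
 * it to the root's partition hash table, and hand the VMM an anonymous
 * "mshv_partition" fd through which all further partition ioctls are issued.
 */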
1868 static long
1869 mshv_ioctl_create_partition(void __user *user_arg, struct device *module_dev)
1870 {
1871 	struct mshv_create_partition args;
1872 	u64 creation_flags;
1873 	struct hv_partition_creation_properties creation_properties = {};
1874 	union hv_partition_isolation_properties isolation_properties = {};
1875 	struct mshv_partition *partition;
1876 	long ret;
1877 
1878 	if (copy_from_user(&args, user_arg, sizeof(args)))
1879 		return -EFAULT;
1880 
1881 	if ((args.pt_flags & ~MSHV_PT_FLAGS_MASK) ||
1882 	    args.pt_isolation >= MSHV_PT_ISOLATION_COUNT)
1883 		return -EINVAL;
1884 
1885 	/* Only support EXO partitions */
1886 	creation_flags = HV_PARTITION_CREATION_FLAG_EXO_PARTITION |
1887 			 HV_PARTITION_CREATION_FLAG_INTERCEPT_MESSAGE_PAGE_ENABLED;
1888 
1889 	if (args.pt_flags & BIT(MSHV_PT_BIT_LAPIC))
1890 		creation_flags |= HV_PARTITION_CREATION_FLAG_LAPIC_ENABLED;
1891 	if (args.pt_flags & BIT(MSHV_PT_BIT_X2APIC))
1892 		creation_flags |= HV_PARTITION_CREATION_FLAG_X2APIC_CAPABLE;
1893 	if (args.pt_flags & BIT(MSHV_PT_BIT_GPA_SUPER_PAGES))
1894 		creation_flags |= HV_PARTITION_CREATION_FLAG_GPA_SUPER_PAGES_ENABLED;
1895 
1896 	switch (args.pt_isolation) {
1897 	case MSHV_PT_ISOLATION_NONE:
1898 		isolation_properties.isolation_type =
1899 			HV_PARTITION_ISOLATION_TYPE_NONE;
1900 		break;
1901 	}
1902 
1903 	partition = kzalloc(sizeof(*partition), GFP_KERNEL);
1904 	if (!partition)
1905 		return -ENOMEM;
1906 
1907 	partition->pt_module_dev = module_dev;
1908 	partition->isolation_type = isolation_properties.isolation_type;
1909 
1910 	refcount_set(&partition->pt_ref_count, 1);
1911 
1912 	mutex_init(&partition->pt_mutex);
1913 
1914 	mutex_init(&partition->pt_irq_lock);
1915 
1916 	init_completion(&partition->async_hypercall);
1917 
1918 	INIT_HLIST_HEAD(&partition->irq_ack_notifier_list);
1919 
1920 	INIT_HLIST_HEAD(&partition->pt_devices);
1921 
1922 	INIT_HLIST_HEAD(&partition->pt_mem_regions);
1923 
1924 	mshv_eventfd_init(partition);
1925 
1926 	ret = init_srcu_struct(&partition->pt_irq_srcu);
1927 	if (ret)
1928 		goto free_partition;
1929 
1930 	ret = hv_call_create_partition(creation_flags,
1931 				       creation_properties,
1932 				       isolation_properties,
1933 				       &partition->pt_id);
1934 	if (ret)
1935 		goto cleanup_irq_srcu;
1936 
1937 	ret = add_partition(partition);
1938 	if (ret)
1939 		goto delete_partition;
1940 
1941 	ret = mshv_init_async_handler(partition);
1942 	if (!ret) {
1943 		ret = FD_ADD(O_CLOEXEC, anon_inode_getfile("mshv_partition",
1944 							   &mshv_partition_fops,
1945 							   partition, O_RDWR));
1946 		if (ret >= 0)
1947 			return ret;
1948 	}
1949 	remove_partition(partition);
1950 delete_partition:
1951 	hv_call_delete_partition(partition->pt_id);
1952 cleanup_irq_srcu:
1953 	cleanup_srcu_struct(&partition->pt_irq_srcu);
1954 free_partition:
1955 	kfree(partition);
1956 
1957 	return ret;
1958 }
1959 
1960 static long mshv_dev_ioctl(struct file *filp, unsigned int ioctl,
1961 			   unsigned long arg)
1962 {
1963 	struct miscdevice *misc = filp->private_data;
1964 
1965 	switch (ioctl) {
1966 	case MSHV_CREATE_PARTITION:
1967 		return mshv_ioctl_create_partition((void __user *)arg,
1968 						   misc->this_device);
1969 	}
1970 
1971 	return -ENOTTY;
1972 }
1973 
1974 static int
1975 mshv_dev_open(struct inode *inode, struct file *filp)
1976 {
1977 	return 0;
1978 }
1979 
1980 static int
1981 mshv_dev_release(struct inode *inode, struct file *filp)
1982 {
1983 	return 0;
1984 }
1985 
1986 static int mshv_cpuhp_online;
1987 static int mshv_root_sched_online;
1988 
1989 static const char *scheduler_type_to_string(enum hv_scheduler_type type)
1990 {
1991 	switch (type) {
1992 	case HV_SCHEDULER_TYPE_LP:
1993 		return "classic scheduler without SMT";
1994 	case HV_SCHEDULER_TYPE_LP_SMT:
1995 		return "classic scheduler with SMT";
1996 	case HV_SCHEDULER_TYPE_CORE_SMT:
1997 		return "core scheduler";
1998 	case HV_SCHEDULER_TYPE_ROOT:
1999 		return "root scheduler";
2000 	default:
2001 		return "unknown scheduler";
2002 	}
2003 }
2004 
2005 /* TODO move this to hv_common.c when needed outside */
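/*
 * Query the hypervisor for the scheduler type in use.  The per-cpu hypercall
 * input/output pages are used with interrupts disabled so another hypercall
 * on this CPU cannot clobber them while the call is in flight.
 */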
2006 static int __init hv_retrieve_scheduler_type(enum hv_scheduler_type *out)
2007 {
2008 	struct hv_input_get_system_property *input;
2009 	struct hv_output_get_system_property *output;
2010 	unsigned long flags;
2011 	u64 status;
2012 
2013 	local_irq_save(flags);
2014 	input = *this_cpu_ptr(hyperv_pcpu_input_arg);
2015 	output = *this_cpu_ptr(hyperv_pcpu_output_arg);
2016 
2017 	memset(input, 0, sizeof(*input));
2018 	memset(output, 0, sizeof(*output));
2019 	input->property_id = HV_SYSTEM_PROPERTY_SCHEDULER_TYPE;
2020 
2021 	status = hv_do_hypercall(HVCALL_GET_SYSTEM_PROPERTY, input, output);
2022 	if (!hv_result_success(status)) {
2023 		local_irq_restore(flags);
2024 		pr_err("%s: %s\n", __func__, hv_result_to_string(status));
2025 		return hv_result_to_errno(status);
2026 	}
2027 
2028 	*out = output->scheduler_type;
2029 	local_irq_restore(flags);
2030 
2031 	return 0;
2032 }
2033 
2034 /* Retrieve and stash the scheduler type in use; reject unsupported ones */
2035 static int __init mshv_retrieve_scheduler_type(struct device *dev)
2036 {
2037 	int ret = 0;
2038 
2039 	if (hv_l1vh_partition())
2040 		hv_scheduler_type = HV_SCHEDULER_TYPE_CORE_SMT;
2041 	else
2042 		ret = hv_retrieve_scheduler_type(&hv_scheduler_type);
2043 
2044 	if (ret)
2045 		return ret;
2046 
2047 	dev_info(dev, "Hypervisor using %s\n",
2048 		 scheduler_type_to_string(hv_scheduler_type));
2049 
2050 	switch (hv_scheduler_type) {
2051 	case HV_SCHEDULER_TYPE_CORE_SMT:
2052 	case HV_SCHEDULER_TYPE_LP_SMT:
2053 	case HV_SCHEDULER_TYPE_ROOT:
2054 	case HV_SCHEDULER_TYPE_LP:
2055 		/* Supported scheduler, nothing to do */
2056 		break;
2057 	default:
2058 		dev_err(dev, "unsupported scheduler 0x%x, bailing.\n",
2059 			hv_scheduler_type);
2060 		return -EOPNOTSUPP;
2061 	}
2062 
2063 	return 0;
2064 }
2065 
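/*
 * CPU hotplug "online" callback: allocate a per-cpu pair of pages used as
 * hypercall input/output buffers while the hypervisor runs the root
 * scheduler.
 */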
2066 static int mshv_root_scheduler_init(unsigned int cpu)
2067 {
2068 	void **inputarg, **outputarg, *p;
2069 
2070 	inputarg = (void **)this_cpu_ptr(root_scheduler_input);
2071 	outputarg = (void **)this_cpu_ptr(root_scheduler_output);
2072 
2073 	/* Allocate two consecutive pages. One for input, one for output. */
2074 	p = kmalloc(2 * HV_HYP_PAGE_SIZE, GFP_KERNEL);
2075 	if (!p)
2076 		return -ENOMEM;
2077 
2078 	*inputarg = p;
2079 	*outputarg = (char *)p + HV_HYP_PAGE_SIZE;
2080 
2081 	return 0;
2082 }
2083 
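/*
 * CPU hotplug teardown: free the buffer pair allocated by
 * mshv_root_scheduler_init().
 */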
2084 static int mshv_root_scheduler_cleanup(unsigned int cpu)
2085 {
2086 	void *p, **inputarg, **outputarg;
2087 
2088 	inputarg = (void **)this_cpu_ptr(root_scheduler_input);
2089 	outputarg = (void **)this_cpu_ptr(root_scheduler_output);
2090 
2091 	p = *inputarg;
2092 
2093 	*inputarg = NULL;
2094 	*outputarg = NULL;
2095 
2096 	kfree(p);
2097 
2098 	return 0;
2099 }
2100 
2101 /* Must be called after mshv_retrieve_scheduler_type(); root scheduler only */
2102 static int
2103 root_scheduler_init(struct device *dev)
2104 {
2105 	int ret;
2106 
2107 	if (hv_scheduler_type != HV_SCHEDULER_TYPE_ROOT)
2108 		return 0;
2109 
2110 	root_scheduler_input = alloc_percpu(void *);
2111 	root_scheduler_output = alloc_percpu(void *);
2112 
2113 	if (!root_scheduler_input || !root_scheduler_output) {
2114 		dev_err(dev, "Failed to allocate root scheduler buffers\n");
2115 		ret = -ENOMEM;
2116 		goto out;
2117 	}
2118 
2119 	ret = cpuhp_setup_state(CPUHP_AP_ONLINE_DYN, "mshv_root_sched",
2120 				mshv_root_scheduler_init,
2121 				mshv_root_scheduler_cleanup);
2122 
2123 	if (ret < 0) {
2124 		dev_err(dev, "Failed to setup root scheduler state: %i\n", ret);
2125 		goto out;
2126 	}
2127 
2128 	mshv_root_sched_online = ret;
2129 
2130 	return 0;
2131 
2132 out:
2133 	free_percpu(root_scheduler_input);
2134 	free_percpu(root_scheduler_output);
2135 	return ret;
2136 }
2137 
2138 static void
2139 root_scheduler_deinit(void)
2140 {
2141 	if (hv_scheduler_type != HV_SCHEDULER_TYPE_ROOT)
2142 		return;
2143 
2144 	cpuhp_remove_state(mshv_root_sched_online);
2145 	free_percpu(root_scheduler_input);
2146 	free_percpu(root_scheduler_output);
2147 }
2148 
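/*
 * Reboot notifier: remove the synic cpuhp state so mshv_synic_cleanup() runs
 * on every online CPU before the machine reboots.
 */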
2149 static int mshv_reboot_notify(struct notifier_block *nb,
2150 			      unsigned long code, void *unused)
2151 {
2152 	cpuhp_remove_state(mshv_cpuhp_online);
2153 	return 0;
2154 }
2155 
2156 struct notifier_block mshv_reboot_nb = {
2157 	.notifier_call = mshv_reboot_notify,
2158 };
2159 
2160 static void mshv_root_partition_exit(void)
2161 {
2162 	unregister_reboot_notifier(&mshv_reboot_nb);
2163 	root_scheduler_deinit();
2164 }
2165 
2166 static int __init mshv_root_partition_init(struct device *dev)
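/*
 * Root-partition-only initialization: per-cpu root scheduler buffers (when
 * the root scheduler is in use) and the reboot notifier.
 */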
2167 {
2168 	int err;
2169 
2170 	err = root_scheduler_init(dev);
2171 	if (err)
2172 		return err;
2173 
2174 	err = register_reboot_notifier(&mshv_reboot_nb);
2175 	if (err)
2176 		goto root_sched_deinit;
2177 
2178 	return 0;
2179 
2180 root_sched_deinit:
2181 	root_scheduler_deinit();
2182 	return err;
2183 }
2184 
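/*
 * Module init: register the mshv misc device, allocate per-cpu synic pages
 * and hook them up via cpu hotplug callbacks, detect the hypervisor
 * scheduler type, perform root-partition-specific setup where applicable,
 * and finally install the mshv interrupt handler.  Errors unwind in reverse
 * order through the labels below.
 */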
2185 static int __init mshv_parent_partition_init(void)
2186 {
2187 	int ret;
2188 	struct device *dev;
2189 	union hv_hypervisor_version_info version_info;
2190 
2191 	if (!hv_parent_partition() || is_kdump_kernel())
2192 		return -ENODEV;
2193 
2194 	if (hv_get_hypervisor_version(&version_info))
2195 		return -ENODEV;
2196 
2197 	ret = misc_register(&mshv_dev);
2198 	if (ret)
2199 		return ret;
2200 
2201 	dev = mshv_dev.this_device;
2202 
2203 	if (version_info.build_number < MSHV_HV_MIN_VERSION ||
2204 	    version_info.build_number > MSHV_HV_MAX_VERSION) {
2205 		dev_err(dev, "Running on unvalidated Hyper-V version\n");
2206 		dev_err(dev, "Versions: current: %u  min: %u  max: %u\n",
2207 			version_info.build_number, MSHV_HV_MIN_VERSION,
2208 			MSHV_HV_MAX_VERSION);
2209 	}
2210 
2211 	mshv_root.synic_pages = alloc_percpu(struct hv_synic_pages);
2212 	if (!mshv_root.synic_pages) {
2213 		dev_err(dev, "Failed to allocate percpu synic page\n");
2214 		ret = -ENOMEM;
2215 		goto device_deregister;
2216 	}
2217 
2218 	ret = cpuhp_setup_state(CPUHP_AP_ONLINE_DYN, "mshv_synic",
2219 				mshv_synic_init,
2220 				mshv_synic_cleanup);
2221 	if (ret < 0) {
2222 		dev_err(dev, "Failed to setup cpu hotplug state: %i\n", ret);
2223 		goto free_synic_pages;
2224 	}
2225 
2226 	mshv_cpuhp_online = ret;
2227 
2228 	ret = mshv_retrieve_scheduler_type(dev);
2229 	if (ret)
2230 		goto remove_cpu_state;
2231 
2232 	if (hv_root_partition())
2233 		ret = mshv_root_partition_init(dev);
2234 	if (ret)
2235 		goto remove_cpu_state;
2236 
2237 	ret = mshv_irqfd_wq_init();
2238 	if (ret)
2239 		goto exit_partition;
2240 
2241 	spin_lock_init(&mshv_root.pt_ht_lock);
2242 	hash_init(mshv_root.pt_htable);
2243 
2244 	hv_setup_mshv_handler(mshv_isr);
2245 
2246 	return 0;
2247 
2248 exit_partition:
2249 	if (hv_root_partition())
2250 		mshv_root_partition_exit();
2251 remove_cpu_state:
2252 	cpuhp_remove_state(mshv_cpuhp_online);
2253 free_synic_pages:
2254 	free_percpu(mshv_root.synic_pages);
2255 device_deregister:
2256 	misc_deregister(&mshv_dev);
2257 	return ret;
2258 }
2259 
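/* Module exit: undo mshv_parent_partition_init(), largely in reverse order. */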
2260 static void __exit mshv_parent_partition_exit(void)
2261 {
2262 	hv_setup_mshv_handler(NULL);
2263 	mshv_port_table_fini();
2264 	misc_deregister(&mshv_dev);
2265 	mshv_irqfd_wq_cleanup();
2266 	if (hv_root_partition())
2267 		mshv_root_partition_exit();
2268 	cpuhp_remove_state(mshv_cpuhp_online);
2269 	free_percpu(mshv_root.synic_pages);
2270 }
2271 
2272 module_init(mshv_parent_partition_init);
2273 module_exit(mshv_parent_partition_exit);
2274