xref: /linux/drivers/hv/mshv_root_main.c (revision 4f9786035f9e519db41375818e1d0b5f20da2f10)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2024, Microsoft Corporation.
4  *
5  * The main part of the mshv_root module, providing APIs to create
6  * and manage guest partitions.
7  *
8  * Authors: Microsoft Linux virtualization team
9  */
10 
11 #include <linux/kernel.h>
12 #include <linux/module.h>
13 #include <linux/fs.h>
14 #include <linux/miscdevice.h>
15 #include <linux/slab.h>
16 #include <linux/file.h>
17 #include <linux/anon_inodes.h>
18 #include <linux/mm.h>
19 #include <linux/io.h>
20 #include <linux/cpuhotplug.h>
21 #include <linux/random.h>
22 #include <asm/mshyperv.h>
23 #include <linux/hyperv.h>
24 #include <linux/notifier.h>
25 #include <linux/reboot.h>
26 #include <linux/kexec.h>
27 #include <linux/page-flags.h>
28 #include <linux/crash_dump.h>
29 #include <linux/panic_notifier.h>
30 #include <linux/vmalloc.h>
31 
32 #include "mshv_eventfd.h"
33 #include "mshv.h"
34 #include "mshv_root.h"
35 
36 MODULE_AUTHOR("Microsoft");
37 MODULE_LICENSE("GPL");
38 MODULE_DESCRIPTION("Microsoft Hyper-V root partition VMM interface /dev/mshv");
39 
40 /* TODO move this to mshyperv.h when needed outside driver */
41 static inline bool hv_parent_partition(void)
42 {
43 	return hv_root_partition();
44 }
45 
46 /* TODO move this to another file when debugfs code is added */
47 enum hv_stats_vp_counters {			/* HV_THREAD_COUNTER */
48 #if defined(CONFIG_X86)
49 	VpRootDispatchThreadBlocked			= 201,
50 #elif defined(CONFIG_ARM64)
51 	VpRootDispatchThreadBlocked			= 94,
52 #endif
53 	VpStatsMaxCounter
54 };
55 
56 struct hv_stats_page {
57 	union {
58 		u64 vp_cntrs[VpStatsMaxCounter];		/* VP counters */
59 		u8 data[HV_HYP_PAGE_SIZE];
60 	};
61 } __packed;
62 
63 struct mshv_root mshv_root;
64 
65 enum hv_scheduler_type hv_scheduler_type;
66 
67 /* These buffers can go away once the fast extended hypercall ABI is implemented. */
68 static void * __percpu *root_scheduler_input;
69 static void * __percpu *root_scheduler_output;
70 
71 static long mshv_dev_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg);
72 static int mshv_dev_open(struct inode *inode, struct file *filp);
73 static int mshv_dev_release(struct inode *inode, struct file *filp);
74 static int mshv_vp_release(struct inode *inode, struct file *filp);
75 static long mshv_vp_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg);
76 static int mshv_partition_release(struct inode *inode, struct file *filp);
77 static long mshv_partition_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg);
78 static int mshv_vp_mmap(struct file *file, struct vm_area_struct *vma);
79 static vm_fault_t mshv_vp_fault(struct vm_fault *vmf);
80 static int mshv_init_async_handler(struct mshv_partition *partition);
81 static void mshv_async_hvcall_handler(void *data, u64 *status);
82 
83 static const union hv_input_vtl input_vtl_zero;
84 static const union hv_input_vtl input_vtl_normal = {
85 	.target_vtl = HV_NORMAL_VTL,
86 	.use_target_vtl = 1,
87 };
88 
89 static const struct vm_operations_struct mshv_vp_vm_ops = {
90 	.fault = mshv_vp_fault,
91 };
92 
93 static const struct file_operations mshv_vp_fops = {
94 	.owner = THIS_MODULE,
95 	.release = mshv_vp_release,
96 	.unlocked_ioctl = mshv_vp_ioctl,
97 	.llseek = noop_llseek,
98 	.mmap = mshv_vp_mmap,
99 };
100 
101 static const struct file_operations mshv_partition_fops = {
102 	.owner = THIS_MODULE,
103 	.release = mshv_partition_release,
104 	.unlocked_ioctl = mshv_partition_ioctl,
105 	.llseek = noop_llseek,
106 };
107 
108 static const struct file_operations mshv_dev_fops = {
109 	.owner = THIS_MODULE,
110 	.open = mshv_dev_open,
111 	.release = mshv_dev_release,
112 	.unlocked_ioctl = mshv_dev_ioctl,
113 	.llseek = noop_llseek,
114 };
115 
116 static struct miscdevice mshv_dev = {
117 	.minor = MISC_DYNAMIC_MINOR,
118 	.name = "mshv",
119 	.fops = &mshv_dev_fops,
120 	.mode = 0600,
121 };
122 
123 /*
124  * Only allow hypercalls that have a u64 partition id as the first member of
125  * the input structure.
126  * These are sorted by value.
127  */
128 static u16 mshv_passthru_hvcalls[] = {
129 	HVCALL_GET_PARTITION_PROPERTY,
130 	HVCALL_SET_PARTITION_PROPERTY,
131 	HVCALL_INSTALL_INTERCEPT,
132 	HVCALL_GET_VP_REGISTERS,
133 	HVCALL_SET_VP_REGISTERS,
134 	HVCALL_TRANSLATE_VIRTUAL_ADDRESS,
135 	HVCALL_CLEAR_VIRTUAL_INTERRUPT,
136 	HVCALL_REGISTER_INTERCEPT_RESULT,
137 	HVCALL_ASSERT_VIRTUAL_INTERRUPT,
138 	HVCALL_GET_GPA_PAGES_ACCESS_STATES,
139 	HVCALL_SIGNAL_EVENT_DIRECT,
140 	HVCALL_POST_MESSAGE_DIRECT,
141 	HVCALL_GET_VP_CPUID_VALUES,
142 };
143 
144 static bool mshv_hvcall_is_async(u16 code)
145 {
146 	switch (code) {
147 	case HVCALL_SET_PARTITION_PROPERTY:
148 		return true;
149 	default:
150 		break;
151 	}
152 	return false;
153 }
154 
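/*
 * Pass a hypercall from the VMM straight through to the hypervisor on behalf
 * of a partition (or one of its VPs). The request is checked against the
 * allowlist above, staged in freshly allocated page(s), and the leading u64
 * of the input is overwritten with the partition id so userspace cannot
 * target a different partition. On HV_STATUS_INSUFFICIENT_MEMORY, pages are
 * deposited and -EAGAIN is returned so the VMM can retry. The status, rep
 * count and output buffer are copied back even on failure, since the VMM may
 * need them (e.g. the number of reps that completed).
 *
 * Rough userspace usage (a sketch only; see the MSHV uapi header for the
 * exact struct mshv_root_hvcall layout):
 *
 *	struct mshv_root_hvcall hvcall = {
 *		.code    = HVCALL_GET_PARTITION_PROPERTY,
 *		.in_sz   = sizeof(input),  /* first u64 is the partition id,
 *					    * overwritten by the driver */
 *		.in_ptr  = (__u64)&input,
 *		.out_sz  = sizeof(output),
 *		.out_ptr = (__u64)&output,
 *	};
 *	ioctl(partition_fd, MSHV_ROOT_HVCALL, &hvcall);
 */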
155 static int mshv_ioctl_passthru_hvcall(struct mshv_partition *partition,
156 				      bool partition_locked,
157 				      void __user *user_args)
158 {
159 	u64 status;
160 	int ret = 0, i;
161 	bool is_async;
162 	struct mshv_root_hvcall args;
163 	struct page *page;
164 	unsigned int pages_order;
165 	void *input_pg = NULL;
166 	void *output_pg = NULL;
167 
168 	if (copy_from_user(&args, user_args, sizeof(args)))
169 		return -EFAULT;
170 
171 	if (args.status || !args.in_ptr || args.in_sz < sizeof(u64) ||
172 	    mshv_field_nonzero(args, rsvd) || args.in_sz > HV_HYP_PAGE_SIZE)
173 		return -EINVAL;
174 
175 	if (args.out_ptr && (!args.out_sz || args.out_sz > HV_HYP_PAGE_SIZE))
176 		return -EINVAL;
177 
178 	for (i = 0; i < ARRAY_SIZE(mshv_passthru_hvcalls); ++i)
179 		if (args.code == mshv_passthru_hvcalls[i])
180 			break;
181 
182 	if (i >= ARRAY_SIZE(mshv_passthru_hvcalls))
183 		return -EINVAL;
184 
185 	is_async = mshv_hvcall_is_async(args.code);
186 	if (is_async) {
187 		/* async hypercalls can only be called from partition fd */
188 		if (!partition_locked)
189 			return -EINVAL;
190 		ret = mshv_init_async_handler(partition);
191 		if (ret)
192 			return ret;
193 	}
194 
195 	pages_order = args.out_ptr ? 1 : 0;
196 	page = alloc_pages(GFP_KERNEL, pages_order);
197 	if (!page)
198 		return -ENOMEM;
199 	input_pg = page_address(page);
200 
201 	if (args.out_ptr)
202 		output_pg = (char *)input_pg + PAGE_SIZE;
203 	else
204 		output_pg = NULL;
205 
206 	if (copy_from_user(input_pg, (void __user *)args.in_ptr,
207 			   args.in_sz)) {
208 		ret = -EFAULT;
209 		goto free_pages_out;
210 	}
211 
212 	/*
213 	 * NOTE: This only works because all the allowed hypercalls' input
214 	 * structs begin with a u64 partition_id field.
215 	 */
216 	*(u64 *)input_pg = partition->pt_id;
217 
218 	if (args.reps)
219 		status = hv_do_rep_hypercall(args.code, args.reps, 0,
220 					     input_pg, output_pg);
221 	else
222 		status = hv_do_hypercall(args.code, input_pg, output_pg);
223 
224 	if (hv_result(status) == HV_STATUS_CALL_PENDING) {
225 		if (is_async) {
226 			mshv_async_hvcall_handler(partition, &status);
227 		} else { /* Paranoia check. This shouldn't happen! */
228 			ret = -EBADFD;
229 			goto free_pages_out;
230 		}
231 	}
232 
233 	if (hv_result(status) == HV_STATUS_INSUFFICIENT_MEMORY) {
234 		ret = hv_call_deposit_pages(NUMA_NO_NODE, partition->pt_id, 1);
235 		if (!ret)
236 			ret = -EAGAIN;
237 	} else if (!hv_result_success(status)) {
238 		ret = hv_result_to_errno(status);
239 	}
240 
241 	/*
242 	 * Always return the status and output data regardless of result.
243 	 * The VMM may need it to determine how to proceed. E.g. the status may
244 	 * contain the number of reps completed if a rep hypercall partially
245 	 * succeeded.
246 	 */
247 	args.status = hv_result(status);
248 	args.reps = args.reps ? hv_repcomp(status) : 0;
249 	if (copy_to_user(user_args, &args, sizeof(args)))
250 		ret = -EFAULT;
251 
252 	if (output_pg &&
253 	    copy_to_user((void __user *)args.out_ptr, output_pg, args.out_sz))
254 		ret = -EFAULT;
255 
256 free_pages_out:
257 	free_pages((unsigned long)input_pg, pages_order);
258 
259 	return ret;
260 }
261 
262 static inline bool is_ghcb_mapping_available(void)
263 {
264 #if IS_ENABLED(CONFIG_X86_64)
265 	return ms_hyperv.ext_features & HV_VP_GHCB_ROOT_MAPPING_AVAILABLE;
266 #else
267 	return false;
268 #endif
269 }
270 
271 static int mshv_get_vp_registers(u32 vp_index, u64 partition_id, u16 count,
272 				 struct hv_register_assoc *registers)
273 {
274 	return hv_call_get_vp_registers(vp_index, partition_id,
275 					count, input_vtl_zero, registers);
276 }
277 
278 static int mshv_set_vp_registers(u32 vp_index, u64 partition_id, u16 count,
279 				 struct hv_register_assoc *registers)
280 {
281 	return hv_call_set_vp_registers(vp_index, partition_id,
282 					count, input_vtl_zero, registers);
283 }
284 
285 /*
286  * Explicit guest vCPU suspend is asynchronous by nature (as it is requested by
287  * dom0 vCPU for guest vCPU) and thus it can race with "intercept" suspend,
288  * done by the hypervisor.
289  * "Intercept" suspend leads to asynchronous message delivery to dom0 which
290  * should be awaited to keep the VP loop consistent (i.e. no message pending
291  * upon VP resume).
292  * VP intercept suspend can't be done when the VP is already explicitly
293  * suspended, so only two race scenarios are possible:
294  *   1. implicit suspend bit set -> explicit suspend bit set -> message sent
295  *   2. implicit suspend bit set -> message sent -> explicit suspend bit set
296  * In either case, checking the implicit suspend bit after the explicit
297  * suspend request has succeeded reliably tells us whether there is a
298  * message to receive and deliver to the VMM.
299  */
300 static int
301 mshv_suspend_vp(const struct mshv_vp *vp, bool *message_in_flight)
302 {
303 	struct hv_register_assoc explicit_suspend = {
304 		.name = HV_REGISTER_EXPLICIT_SUSPEND
305 	};
306 	struct hv_register_assoc intercept_suspend = {
307 		.name = HV_REGISTER_INTERCEPT_SUSPEND
308 	};
309 	union hv_explicit_suspend_register *es =
310 		&explicit_suspend.value.explicit_suspend;
311 	union hv_intercept_suspend_register *is =
312 		&intercept_suspend.value.intercept_suspend;
313 	int ret;
314 
315 	es->suspended = 1;
316 
317 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
318 				    1, &explicit_suspend);
319 	if (ret) {
320 		vp_err(vp, "Failed to explicitly suspend vCPU\n");
321 		return ret;
322 	}
323 
324 	ret = mshv_get_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
325 				    1, &intercept_suspend);
326 	if (ret) {
327 		vp_err(vp, "Failed to get intercept suspend state\n");
328 		return ret;
329 	}
330 
331 	*message_in_flight = is->suspended;
332 
333 	return 0;
334 }
335 
336 /*
337  * This function is used when VPs are scheduled by the hypervisor's
338  * scheduler.
339  *
340  * The caller has to make sure the registers array contains the cleared
341  * HV_REGISTER_INTERCEPT_SUSPEND and HV_REGISTER_EXPLICIT_SUSPEND registers
342  * exactly in this order (the hypervisor clears them sequentially). With
343  * the opposite order, a newly arrived HV_REGISTER_INTERCEPT_SUSPEND could
344  * be invalidly cleared right after the VP has been released from
345  * HV_REGISTER_EXPLICIT_SUSPEND.
346  */
347 static long mshv_run_vp_with_hyp_scheduler(struct mshv_vp *vp)
348 {
349 	long ret;
350 	struct hv_register_assoc suspend_regs[2] = {
351 			{ .name = HV_REGISTER_INTERCEPT_SUSPEND },
352 			{ .name = HV_REGISTER_EXPLICIT_SUSPEND }
353 	};
354 	size_t count = ARRAY_SIZE(suspend_regs);
355 
356 	/* Resume VP execution */
357 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
358 				    count, suspend_regs);
359 	if (ret) {
360 		vp_err(vp, "Failed to resume vp execution. %lx\n", ret);
361 		return ret;
362 	}
363 
364 	ret = wait_event_interruptible(vp->run.vp_suspend_queue,
365 				       vp->run.kicked_by_hv == 1);
366 	if (ret) {
367 		bool message_in_flight;
368 
369 		/*
370 		 * The wait was interrupted by a signal: suspend the vCPU
371 		 * explicitly and copy the message in flight (if any).
372 		 */
373 		ret = mshv_suspend_vp(vp, &message_in_flight);
374 		if (ret)
375 			return ret;
376 
377 		/* Return if no message in flight */
378 		if (!message_in_flight)
379 			return -EINTR;
380 
381 		/* Wait for the message in flight. */
382 		wait_event(vp->run.vp_suspend_queue, vp->run.kicked_by_hv == 1);
383 	}
384 
385 	/*
386 	 * Reset the flag to make the wait_event call above work
387 	 * next time.
388 	 */
389 	vp->run.kicked_by_hv = 0;
390 
391 	return 0;
392 }
393 
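/*
 * Ask the hypervisor to run the VP on the current CPU via HVCALL_DISPATCH_VP,
 * using this CPU's preallocated root scheduler input/output pages. Preemption
 * is disabled so those per-cpu pages cannot change underneath the hypercall;
 * the call returns once the VP stops running and the resulting dispatch
 * state/event are passed back through *res.
 */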
394 static int
395 mshv_vp_dispatch(struct mshv_vp *vp, u32 flags,
396 		 struct hv_output_dispatch_vp *res)
397 {
398 	struct hv_input_dispatch_vp *input;
399 	struct hv_output_dispatch_vp *output;
400 	u64 status;
401 
402 	preempt_disable();
403 	input = *this_cpu_ptr(root_scheduler_input);
404 	output = *this_cpu_ptr(root_scheduler_output);
405 
406 	memset(input, 0, sizeof(*input));
407 	memset(output, 0, sizeof(*output));
408 
409 	input->partition_id = vp->vp_partition->pt_id;
410 	input->vp_index = vp->vp_index;
411 	input->time_slice = 0; /* Run forever until something happens */
412 	input->spec_ctrl = 0; /* TODO: set sensible flags */
413 	input->flags = flags;
414 
415 	vp->run.flags.root_sched_dispatched = 1;
416 	status = hv_do_hypercall(HVCALL_DISPATCH_VP, input, output);
417 	vp->run.flags.root_sched_dispatched = 0;
418 
419 	*res = *output;
420 	preempt_enable();
421 
422 	if (!hv_result_success(status))
423 		vp_err(vp, "%s: status %s\n", __func__,
424 		       hv_result_to_string(status));
425 
426 	return hv_result_to_errno(status);
427 }
428 
429 static int
430 mshv_vp_clear_explicit_suspend(struct mshv_vp *vp)
431 {
432 	struct hv_register_assoc explicit_suspend = {
433 		.name = HV_REGISTER_EXPLICIT_SUSPEND,
434 		.value.explicit_suspend.suspended = 0,
435 	};
436 	int ret;
437 
438 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
439 				    1, &explicit_suspend);
440 
441 	if (ret)
442 		vp_err(vp, "Failed to unsuspend\n");
443 
444 	return ret;
445 }
446 
447 #if IS_ENABLED(CONFIG_X86_64)
448 static u64 mshv_vp_interrupt_pending(struct mshv_vp *vp)
449 {
450 	if (!vp->vp_register_page)
451 		return 0;
452 	return vp->vp_register_page->interrupt_vectors.as_uint64;
453 }
454 #else
455 static u64 mshv_vp_interrupt_pending(struct mshv_vp *vp)
456 {
457 	return 0;
458 }
459 #endif
460 
461 static bool mshv_vp_dispatch_thread_blocked(struct mshv_vp *vp)
462 {
463 	struct hv_stats_page **stats = vp->vp_stats_pages;
464 	u64 *self_vp_cntrs = stats[HV_STATS_AREA_SELF]->vp_cntrs;
465 	u64 *parent_vp_cntrs = stats[HV_STATS_AREA_PARENT]->vp_cntrs;
466 
467 	if (self_vp_cntrs[VpRootDispatchThreadBlocked])
468 		return self_vp_cntrs[VpRootDispatchThreadBlocked];
469 	return parent_vp_cntrs[VpRootDispatchThreadBlocked];
470 }
471 
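/*
 * Sleep until the hypervisor kicks this VP (and its dispatch thread is no
 * longer blocked) or an interrupt is pending for it. Returns -EINTR if the
 * wait is interrupted by a signal; otherwise the blocked/kicked state is
 * cleared so the VP can be dispatched again.
 */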
472 static int
473 mshv_vp_wait_for_hv_kick(struct mshv_vp *vp)
474 {
475 	int ret;
476 
477 	ret = wait_event_interruptible(vp->run.vp_suspend_queue,
478 				       (vp->run.kicked_by_hv == 1 &&
479 					!mshv_vp_dispatch_thread_blocked(vp)) ||
480 				       mshv_vp_interrupt_pending(vp));
481 	if (ret)
482 		return -EINTR;
483 
484 	vp->run.flags.root_sched_blocked = 0;
485 	vp->run.kicked_by_hv = 0;
486 
487 	return 0;
488 }
489 
490 static int mshv_pre_guest_mode_work(struct mshv_vp *vp)
491 {
492 	const ulong work_flags = _TIF_NOTIFY_SIGNAL | _TIF_SIGPENDING |
493 				 _TIF_NEED_RESCHED  | _TIF_NOTIFY_RESUME;
494 	ulong th_flags;
495 
496 	th_flags = read_thread_flags();
497 	while (th_flags & work_flags) {
498 		int ret;
499 
500 		/* nb: following will call schedule */
501 		ret = mshv_do_pre_guest_mode_work(th_flags);
502 
503 		if (ret)
504 			return ret;
505 
506 		th_flags = read_thread_flags();
507 	}
508 
509 	return 0;
510 }
511 
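/*
 * Root scheduler VP run loop: dispatch the VP repeatedly, performing any
 * pending thread work (signals, reschedule, notify-resume) before each
 * dispatch, until the VP is suspended by an intercept that must be handed
 * to the VMM. A dispatch that ends in the "blocked" state either clears an
 * unexpected explicit suspend and waits for a hypervisor kick, or simply
 * waits for the kick if the block was not an explicit suspend.
 */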
512 /* Must be called with interrupts enabled */
513 static long mshv_run_vp_with_root_scheduler(struct mshv_vp *vp)
514 {
515 	long ret;
516 
517 	if (vp->run.flags.root_sched_blocked) {
518 		/*
519 		 * Dispatch state of this VP is blocked. Need to wait
520 		 * for the hypervisor to clear the blocked state before
521 		 * dispatching it.
522 		 */
523 		ret = mshv_vp_wait_for_hv_kick(vp);
524 		if (ret)
525 			return ret;
526 	}
527 
528 	do {
529 		u32 flags = 0;
530 		struct hv_output_dispatch_vp output;
531 
532 		ret = mshv_pre_guest_mode_work(vp);
533 		if (ret)
534 			break;
535 
536 		if (vp->run.flags.intercept_suspend)
537 			flags |= HV_DISPATCH_VP_FLAG_CLEAR_INTERCEPT_SUSPEND;
538 
539 		if (mshv_vp_interrupt_pending(vp))
540 			flags |= HV_DISPATCH_VP_FLAG_SCAN_INTERRUPT_INJECTION;
541 
542 		ret = mshv_vp_dispatch(vp, flags, &output);
543 		if (ret)
544 			break;
545 
546 		vp->run.flags.intercept_suspend = 0;
547 
548 		if (output.dispatch_state == HV_VP_DISPATCH_STATE_BLOCKED) {
549 			if (output.dispatch_event ==
550 						HV_VP_DISPATCH_EVENT_SUSPEND) {
551 				/*
552 				 * TODO: remove the warning once VP canceling
553 				 *	 is supported
554 				 */
555 				WARN_ONCE(atomic64_read(&vp->run.vp_signaled_count),
556 					  "%s: vp#%d: unexpected explicit suspend\n",
557 					  __func__, vp->vp_index);
558 				/*
559 				 * Need to clear explicit suspend before
560 				 * dispatching.
561 				 * Explicit suspend is either:
562 				 * - set right after the first VP dispatch or
563 				 * - set explicitly via hypercall
564 				 * Since the latter case is not yet supported,
565 				 * simply clear it here.
566 				 */
567 				ret = mshv_vp_clear_explicit_suspend(vp);
568 				if (ret)
569 					break;
570 
571 				ret = mshv_vp_wait_for_hv_kick(vp);
572 				if (ret)
573 					break;
574 			} else {
575 				vp->run.flags.root_sched_blocked = 1;
576 				ret = mshv_vp_wait_for_hv_kick(vp);
577 				if (ret)
578 					break;
579 			}
580 		} else {
581 			/* HV_VP_DISPATCH_STATE_READY */
582 			if (output.dispatch_event ==
583 						HV_VP_DISPATCH_EVENT_INTERCEPT)
584 				vp->run.flags.intercept_suspend = 1;
585 		}
586 	} while (!vp->run.flags.intercept_suspend);
587 
588 	return ret;
589 }
590 
591 static_assert(sizeof(struct hv_message) <= MSHV_RUN_VP_BUF_SZ,
592 	      "sizeof(struct hv_message) must not exceed MSHV_RUN_VP_BUF_SZ");
593 
594 static long mshv_vp_ioctl_run_vp(struct mshv_vp *vp, void __user *ret_msg)
595 {
596 	long rc;
597 
598 	if (hv_scheduler_type == HV_SCHEDULER_TYPE_ROOT)
599 		rc = mshv_run_vp_with_root_scheduler(vp);
600 	else
601 		rc = mshv_run_vp_with_hyp_scheduler(vp);
602 
603 	if (rc)
604 		return rc;
605 
606 	if (copy_to_user(ret_msg, vp->vp_intercept_msg_page,
607 			 sizeof(struct hv_message)))
608 		rc = -EFAULT;
609 
610 	return rc;
611 }
612 
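/*
 * Handle VP state types that are transferred as a list of PFNs: pin the
 * userspace buffer page by page and hand the page list to the hypervisor so
 * it can copy the state directly to/from the VMM's buffer.
 */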
613 static int
614 mshv_vp_ioctl_get_set_state_pfn(struct mshv_vp *vp,
615 				struct hv_vp_state_data state_data,
616 				unsigned long user_pfn, size_t page_count,
617 				bool is_set)
618 {
619 	int completed, ret = 0;
620 	unsigned long check;
621 	struct page **pages;
622 
623 	if (page_count > INT_MAX)
624 		return -EINVAL;
625 	/*
626 	 * Check the arithmetic for wraparound/overflow.
627 	 * The last page address in the buffer is:
628 	 * (user_pfn + (page_count - 1)) * PAGE_SIZE
629 	 */
630 	if (check_add_overflow(user_pfn, (page_count - 1), &check))
631 		return -EOVERFLOW;
632 	if (check_mul_overflow(check, PAGE_SIZE, &check))
633 		return -EOVERFLOW;
634 
635 	/* Pin user pages so hypervisor can copy directly to them */
636 	pages = kcalloc(page_count, sizeof(struct page *), GFP_KERNEL);
637 	if (!pages)
638 		return -ENOMEM;
639 
640 	for (completed = 0; completed < page_count; completed += ret) {
641 		unsigned long user_addr = (user_pfn + completed) * PAGE_SIZE;
642 		int remaining = page_count - completed;
643 
644 		ret = pin_user_pages_fast(user_addr, remaining, FOLL_WRITE,
645 					  &pages[completed]);
646 		if (ret < 0) {
647 			vp_err(vp, "%s: Failed to pin user pages error %i\n",
648 			       __func__, ret);
649 			goto unpin_pages;
650 		}
651 	}
652 
653 	if (is_set)
654 		ret = hv_call_set_vp_state(vp->vp_index,
655 					   vp->vp_partition->pt_id,
656 					   state_data, page_count, pages,
657 					   0, NULL);
658 	else
659 		ret = hv_call_get_vp_state(vp->vp_index,
660 					   vp->vp_partition->pt_id,
661 					   state_data, page_count, pages,
662 					   NULL);
663 
664 unpin_pages:
665 	unpin_user_pages(pages, completed);
666 	kfree(pages);
667 	return ret;
668 }
669 
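/*
 * Handler for the MSHV_GET_VP_STATE/MSHV_SET_VP_STATE ioctls. Translates the
 * uapi state type into the hypervisor's, reports the required buffer size
 * back to userspace, and then either delegates to the PFN-based helper above
 * or bounces the state through a kernel copy of union hv_output_get_vp_state.
 */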
670 static long
671 mshv_vp_ioctl_get_set_state(struct mshv_vp *vp,
672 			    struct mshv_get_set_vp_state __user *user_args,
673 			    bool is_set)
674 {
675 	struct mshv_get_set_vp_state args;
676 	long ret = 0;
677 	union hv_output_get_vp_state vp_state;
678 	u32 data_sz;
679 	struct hv_vp_state_data state_data = {};
680 
681 	if (copy_from_user(&args, user_args, sizeof(args)))
682 		return -EFAULT;
683 
684 	if (args.type >= MSHV_VP_STATE_COUNT || mshv_field_nonzero(args, rsvd) ||
685 	    !args.buf_sz || !PAGE_ALIGNED(args.buf_sz) ||
686 	    !PAGE_ALIGNED(args.buf_ptr))
687 		return -EINVAL;
688 
689 	if (!access_ok((void __user *)args.buf_ptr, args.buf_sz))
690 		return -EFAULT;
691 
692 	switch (args.type) {
693 	case MSHV_VP_STATE_LAPIC:
694 		state_data.type = HV_GET_SET_VP_STATE_LAPIC_STATE;
695 		data_sz = HV_HYP_PAGE_SIZE;
696 		break;
697 	case MSHV_VP_STATE_XSAVE:
698 	{
699 		u64 data_sz_64;
700 
701 		ret = hv_call_get_partition_property(vp->vp_partition->pt_id,
702 						     HV_PARTITION_PROPERTY_XSAVE_STATES,
703 						     &state_data.xsave.states.as_uint64);
704 		if (ret)
705 			return ret;
706 
707 		ret = hv_call_get_partition_property(vp->vp_partition->pt_id,
708 						     HV_PARTITION_PROPERTY_MAX_XSAVE_DATA_SIZE,
709 						     &data_sz_64);
710 		if (ret)
711 			return ret;
712 
713 		data_sz = (u32)data_sz_64;
714 		state_data.xsave.flags = 0;
715 		/* Always request legacy states */
716 		state_data.xsave.states.legacy_x87 = 1;
717 		state_data.xsave.states.legacy_sse = 1;
718 		state_data.type = HV_GET_SET_VP_STATE_XSAVE;
719 		break;
720 	}
721 	case MSHV_VP_STATE_SIMP:
722 		state_data.type = HV_GET_SET_VP_STATE_SIM_PAGE;
723 		data_sz = HV_HYP_PAGE_SIZE;
724 		break;
725 	case MSHV_VP_STATE_SIEFP:
726 		state_data.type = HV_GET_SET_VP_STATE_SIEF_PAGE;
727 		data_sz = HV_HYP_PAGE_SIZE;
728 		break;
729 	case MSHV_VP_STATE_SYNTHETIC_TIMERS:
730 		state_data.type = HV_GET_SET_VP_STATE_SYNTHETIC_TIMERS;
731 		data_sz = sizeof(vp_state.synthetic_timers_state);
732 		break;
733 	default:
734 		return -EINVAL;
735 	}
736 
737 	if (copy_to_user(&user_args->buf_sz, &data_sz, sizeof(user_args->buf_sz)))
738 		return -EFAULT;
739 
740 	if (data_sz > args.buf_sz)
741 		return -EINVAL;
742 
743 	/* If the data is transmitted via pfns, delegate to helper */
744 	if (state_data.type & HV_GET_SET_VP_STATE_TYPE_PFN) {
745 		unsigned long user_pfn = PFN_DOWN(args.buf_ptr);
746 		size_t page_count = PFN_DOWN(args.buf_sz);
747 
748 		return mshv_vp_ioctl_get_set_state_pfn(vp, state_data, user_pfn,
749 						       page_count, is_set);
750 	}
751 
752 	/* Paranoia check - this shouldn't happen! */
753 	if (data_sz > sizeof(vp_state)) {
754 		vp_err(vp, "Invalid vp state data size!\n");
755 		return -EINVAL;
756 	}
757 
758 	if (is_set) {
759 		if (copy_from_user(&vp_state, (void __user *)args.buf_ptr, data_sz))
760 			return -EFAULT;
761 
762 		return hv_call_set_vp_state(vp->vp_index,
763 					    vp->vp_partition->pt_id,
764 					    state_data, 0, NULL,
765 					    sizeof(vp_state), (u8 *)&vp_state);
766 	}
767 
768 	ret = hv_call_get_vp_state(vp->vp_index, vp->vp_partition->pt_id,
769 				   state_data, 0, NULL, &vp_state);
770 	if (ret)
771 		return ret;
772 
773 	if (copy_to_user((void __user *)args.buf_ptr, &vp_state, data_sz))
774 		return -EFAULT;
775 
776 	return 0;
777 }
778 
779 static long
780 mshv_vp_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg)
781 {
782 	struct mshv_vp *vp = filp->private_data;
783 	long r = -ENOTTY;
784 
785 	if (mutex_lock_killable(&vp->vp_mutex))
786 		return -EINTR;
787 
788 	switch (ioctl) {
789 	case MSHV_RUN_VP:
790 		r = mshv_vp_ioctl_run_vp(vp, (void __user *)arg);
791 		break;
792 	case MSHV_GET_VP_STATE:
793 		r = mshv_vp_ioctl_get_set_state(vp, (void __user *)arg, false);
794 		break;
795 	case MSHV_SET_VP_STATE:
796 		r = mshv_vp_ioctl_get_set_state(vp, (void __user *)arg, true);
797 		break;
798 	case MSHV_ROOT_HVCALL:
799 		r = mshv_ioctl_passthru_hvcall(vp->vp_partition, false,
800 					       (void __user *)arg);
801 		break;
802 	default:
803 		vp_warn(vp, "Invalid ioctl: %#x\n", ioctl);
804 		break;
805 	}
806 	mutex_unlock(&vp->vp_mutex);
807 
808 	return r;
809 }
810 
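/*
 * Fault handler for VP mmaps: the mmap offset selects which shared page
 * (register page, intercept message page or GHCB page) backs the mapping.
 * A VMM would map the register page roughly like this (sketch only):
 *
 *	mmap(NULL, page_size, PROT_READ | PROT_WRITE, MAP_SHARED, vp_fd,
 *	     MSHV_VP_MMAP_OFFSET_REGISTERS * page_size);
 */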
811 static vm_fault_t mshv_vp_fault(struct vm_fault *vmf)
812 {
813 	struct mshv_vp *vp = vmf->vma->vm_file->private_data;
814 
815 	switch (vmf->vma->vm_pgoff) {
816 	case MSHV_VP_MMAP_OFFSET_REGISTERS:
817 		vmf->page = virt_to_page(vp->vp_register_page);
818 		break;
819 	case MSHV_VP_MMAP_OFFSET_INTERCEPT_MESSAGE:
820 		vmf->page = virt_to_page(vp->vp_intercept_msg_page);
821 		break;
822 	case MSHV_VP_MMAP_OFFSET_GHCB:
823 		vmf->page = virt_to_page(vp->vp_ghcb_page);
824 		break;
825 	default:
826 		return VM_FAULT_SIGBUS;
827 	}
828 
829 	get_page(vmf->page);
830 
831 	return 0;
832 }
833 
834 static int mshv_vp_mmap(struct file *file, struct vm_area_struct *vma)
835 {
836 	struct mshv_vp *vp = file->private_data;
837 
838 	switch (vma->vm_pgoff) {
839 	case MSHV_VP_MMAP_OFFSET_REGISTERS:
840 		if (!vp->vp_register_page)
841 			return -ENODEV;
842 		break;
843 	case MSHV_VP_MMAP_OFFSET_INTERCEPT_MESSAGE:
844 		if (!vp->vp_intercept_msg_page)
845 			return -ENODEV;
846 		break;
847 	case MSHV_VP_MMAP_OFFSET_GHCB:
848 		if (!vp->vp_ghcb_page)
849 			return -ENODEV;
850 		break;
851 	default:
852 		return -EINVAL;
853 	}
854 
855 	vma->vm_ops = &mshv_vp_vm_ops;
856 	return 0;
857 }
858 
859 static int
860 mshv_vp_release(struct inode *inode, struct file *filp)
861 {
862 	struct mshv_vp *vp = filp->private_data;
863 
864 	/* Rest of VP cleanup happens in destroy_partition() */
865 	mshv_partition_put(vp->vp_partition);
866 	return 0;
867 }
868 
869 static void mshv_vp_stats_unmap(u64 partition_id, u32 vp_index)
870 {
871 	union hv_stats_object_identity identity = {
872 		.vp.partition_id = partition_id,
873 		.vp.vp_index = vp_index,
874 	};
875 
876 	identity.vp.stats_area_type = HV_STATS_AREA_SELF;
877 	hv_call_unmap_stat_page(HV_STATS_OBJECT_VP, &identity);
878 
879 	identity.vp.stats_area_type = HV_STATS_AREA_PARENT;
880 	hv_call_unmap_stat_page(HV_STATS_OBJECT_VP, &identity);
881 }
882 
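/*
 * Map both the "self" and "parent" hypervisor statistics pages for a VP;
 * if mapping the parent page fails, the self page is unmapped again.
 */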
883 static int mshv_vp_stats_map(u64 partition_id, u32 vp_index,
884 			     void *stats_pages[])
885 {
886 	union hv_stats_object_identity identity = {
887 		.vp.partition_id = partition_id,
888 		.vp.vp_index = vp_index,
889 	};
890 	int err;
891 
892 	identity.vp.stats_area_type = HV_STATS_AREA_SELF;
893 	err = hv_call_map_stat_page(HV_STATS_OBJECT_VP, &identity,
894 				    &stats_pages[HV_STATS_AREA_SELF]);
895 	if (err)
896 		return err;
897 
898 	identity.vp.stats_area_type = HV_STATS_AREA_PARENT;
899 	err = hv_call_map_stat_page(HV_STATS_OBJECT_VP, &identity,
900 				    &stats_pages[HV_STATS_AREA_PARENT]);
901 	if (err)
902 		goto unmap_self;
903 
904 	return 0;
905 
906 unmap_self:
907 	identity.vp.stats_area_type = HV_STATS_AREA_SELF;
908 	hv_call_unmap_stat_page(HV_STATS_OBJECT_VP, &identity);
909 	return err;
910 }
911 
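/*
 * MSHV_CREATE_VP: create the VP in the hypervisor and map its shared state
 * pages - the intercept message page always, the register page only for
 * non-encrypted partitions, the GHCB page only for encrypted partitions with
 * GHCB mapping available, and the stats pages only when running as the
 * parent (root) partition. The anon fd is installed last so the VP becomes
 * visible to userspace only once it is fully constructed.
 */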
912 static long
913 mshv_partition_ioctl_create_vp(struct mshv_partition *partition,
914 			       void __user *arg)
915 {
916 	struct mshv_create_vp args;
917 	struct mshv_vp *vp;
918 	struct page *intercept_message_page, *register_page, *ghcb_page;
919 	void *stats_pages[2];
920 	long ret;
921 
922 	if (copy_from_user(&args, arg, sizeof(args)))
923 		return -EFAULT;
924 
925 	if (args.vp_index >= MSHV_MAX_VPS)
926 		return -EINVAL;
927 
928 	if (partition->pt_vp_array[args.vp_index])
929 		return -EEXIST;
930 
931 	ret = hv_call_create_vp(NUMA_NO_NODE, partition->pt_id, args.vp_index,
932 				0 /* Only valid for root partition VPs */);
933 	if (ret)
934 		return ret;
935 
936 	ret = hv_call_map_vp_state_page(partition->pt_id, args.vp_index,
937 					HV_VP_STATE_PAGE_INTERCEPT_MESSAGE,
938 					input_vtl_zero,
939 					&intercept_message_page);
940 	if (ret)
941 		goto destroy_vp;
942 
943 	if (!mshv_partition_encrypted(partition)) {
944 		ret = hv_call_map_vp_state_page(partition->pt_id, args.vp_index,
945 						HV_VP_STATE_PAGE_REGISTERS,
946 						input_vtl_zero,
947 						&register_page);
948 		if (ret)
949 			goto unmap_intercept_message_page;
950 	}
951 
952 	if (mshv_partition_encrypted(partition) &&
953 	    is_ghcb_mapping_available()) {
954 		ret = hv_call_map_vp_state_page(partition->pt_id, args.vp_index,
955 						HV_VP_STATE_PAGE_GHCB,
956 						input_vtl_normal,
957 						&ghcb_page);
958 		if (ret)
959 			goto unmap_register_page;
960 	}
961 
962 	if (hv_parent_partition()) {
963 		ret = mshv_vp_stats_map(partition->pt_id, args.vp_index,
964 					stats_pages);
965 		if (ret)
966 			goto unmap_ghcb_page;
967 	}
968 
969 	vp = kzalloc(sizeof(*vp), GFP_KERNEL);
	if (!vp) {
		ret = -ENOMEM;
		goto unmap_stats_pages;
	}
972 
973 	vp->vp_partition = mshv_partition_get(partition);
974 	if (!vp->vp_partition) {
975 		ret = -EBADF;
976 		goto free_vp;
977 	}
978 
979 	mutex_init(&vp->vp_mutex);
980 	init_waitqueue_head(&vp->run.vp_suspend_queue);
981 	atomic64_set(&vp->run.vp_signaled_count, 0);
982 
983 	vp->vp_index = args.vp_index;
984 	vp->vp_intercept_msg_page = page_to_virt(intercept_message_page);
985 	if (!mshv_partition_encrypted(partition))
986 		vp->vp_register_page = page_to_virt(register_page);
987 
988 	if (mshv_partition_encrypted(partition) && is_ghcb_mapping_available())
989 		vp->vp_ghcb_page = page_to_virt(ghcb_page);
990 
991 	if (hv_parent_partition())
992 		memcpy(vp->vp_stats_pages, stats_pages, sizeof(stats_pages));
993 
994 	/*
995 	 * Keep anon_inode_getfd last: it installs fd in the file struct and
996 	 * thus makes the state accessible in user space.
997 	 */
998 	ret = anon_inode_getfd("mshv_vp", &mshv_vp_fops, vp,
999 			       O_RDWR | O_CLOEXEC);
1000 	if (ret < 0)
1001 		goto put_partition;
1002 
1003 	/* Already serialized by the partition mutex held for all partition ioctls */
1004 	partition->pt_vp_count++;
1005 	partition->pt_vp_array[args.vp_index] = vp;
1006 
1007 	return ret;
1008 
1009 put_partition:
1010 	mshv_partition_put(partition);
1011 free_vp:
1012 	kfree(vp);
1013 unmap_stats_pages:
1014 	if (hv_parent_partition())
1015 		mshv_vp_stats_unmap(partition->pt_id, args.vp_index);
1016 unmap_ghcb_page:
1017 	if (mshv_partition_encrypted(partition) && is_ghcb_mapping_available()) {
1018 		hv_call_unmap_vp_state_page(partition->pt_id, args.vp_index,
1019 					    HV_VP_STATE_PAGE_GHCB,
1020 					    input_vtl_normal);
1021 	}
1022 unmap_register_page:
1023 	if (!mshv_partition_encrypted(partition)) {
1024 		hv_call_unmap_vp_state_page(partition->pt_id, args.vp_index,
1025 					    HV_VP_STATE_PAGE_REGISTERS,
1026 					    input_vtl_zero);
1027 	}
1028 unmap_intercept_message_page:
1029 	hv_call_unmap_vp_state_page(partition->pt_id, args.vp_index,
1030 				    HV_VP_STATE_PAGE_INTERCEPT_MESSAGE,
1031 				    input_vtl_zero);
1032 destroy_vp:
1033 	hv_call_delete_vp(partition->pt_id, args.vp_index);
1034 	return ret;
1035 }
1036 
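/*
 * Async hypercall handling: only one asynchronous hypercall may be
 * outstanding per partition at a time. The completion is re-armed before the
 * call is issued and is expected to be signalled elsewhere in the driver once
 * the hypervisor reports that the operation has finished, at which point the
 * final status is picked up from async_hypercall_status.
 */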
1037 static int mshv_init_async_handler(struct mshv_partition *partition)
1038 {
1039 	if (completion_done(&partition->async_hypercall)) {
1040 		pt_err(partition,
1041 		       "Cannot issue async hypercall while another one is in progress!\n");
1042 		return -EPERM;
1043 	}
1044 
1045 	reinit_completion(&partition->async_hypercall);
1046 	return 0;
1047 }
1048 
1049 static void mshv_async_hvcall_handler(void *data, u64 *status)
1050 {
1051 	struct mshv_partition *partition = data;
1052 
1053 	wait_for_completion(&partition->async_hypercall);
1054 	pt_dbg(partition, "Async hypercall completed!\n");
1055 
1056 	*status = partition->async_hypercall_status;
1057 }
1058 
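/*
 * For encrypted (e.g. SNP) partitions the host has to give up access to
 * guest memory before it is mapped into the guest, and regain it before the
 * pages are unpinned. These two helpers flip host access for all SPA pages
 * of a region between shared and exclusive.
 */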
1059 static int
1060 mshv_partition_region_share(struct mshv_mem_region *region)
1061 {
1062 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
1063 
1064 	if (region->flags.large_pages)
1065 		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
1066 
1067 	return hv_call_modify_spa_host_access(region->partition->pt_id,
1068 			region->pages, region->nr_pages,
1069 			HV_MAP_GPA_READABLE | HV_MAP_GPA_WRITABLE,
1070 			flags, true);
1071 }
1072 
1073 static int
1074 mshv_partition_region_unshare(struct mshv_mem_region *region)
1075 {
1076 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
1077 
1078 	if (region->flags.large_pages)
1079 		flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
1080 
1081 	return hv_call_modify_spa_host_access(region->partition->pt_id,
1082 			region->pages, region->nr_pages,
1083 			0,
1084 			flags, false);
1085 }
1086 
1087 static int
1088 mshv_region_remap_pages(struct mshv_mem_region *region, u32 map_flags,
1089 			u64 page_offset, u64 page_count)
1090 {
1091 	if (page_offset + page_count > region->nr_pages)
1092 		return -EINVAL;
1093 
1094 	if (region->flags.large_pages)
1095 		map_flags |= HV_MAP_GPA_LARGE_PAGE;
1096 
1097 	/* ask the hypervisor to map guest ram */
1098 	return hv_call_map_gpa_pages(region->partition->pt_id,
1099 				     region->start_gfn + page_offset,
1100 				     page_count, map_flags,
1101 				     region->pages + page_offset);
1102 }
1103 
1104 static int
1105 mshv_region_map(struct mshv_mem_region *region)
1106 {
1107 	u32 map_flags = region->hv_map_flags;
1108 
1109 	return mshv_region_remap_pages(region, map_flags,
1110 				       0, region->nr_pages);
1111 }
1112 
1113 static void
1114 mshv_region_evict_pages(struct mshv_mem_region *region,
1115 			u64 page_offset, u64 page_count)
1116 {
1117 	if (region->flags.range_pinned)
1118 		unpin_user_pages(region->pages + page_offset, page_count);
1119 
1120 	memset(region->pages + page_offset, 0,
1121 	       page_count * sizeof(struct page *));
1122 }
1123 
1124 static void
1125 mshv_region_evict(struct mshv_mem_region *region)
1126 {
1127 	mshv_region_evict_pages(region, 0, region->nr_pages);
1128 }
1129 
1130 static int
1131 mshv_region_populate_pages(struct mshv_mem_region *region,
1132 			   u64 page_offset, u64 page_count)
1133 {
1134 	u64 done_count, nr_pages;
1135 	struct page **pages;
1136 	__u64 userspace_addr;
1137 	int ret;
1138 
1139 	if (page_offset + page_count > region->nr_pages)
1140 		return -EINVAL;
1141 
1142 	for (done_count = 0; done_count < page_count; done_count += ret) {
1143 		pages = region->pages + page_offset + done_count;
1144 		userspace_addr = region->start_uaddr +
1145 				(page_offset + done_count) *
1146 				HV_HYP_PAGE_SIZE;
1147 		nr_pages = min(page_count - done_count,
1148 			       MSHV_PIN_PAGES_BATCH_SIZE);
1149 
1150 		/*
1151 		 * Pinning assuming 4k pages works for large pages too.
1152 		 * All page structs within the large page are returned.
1153 		 *
1154 		 * Pin requests are batched because pin_user_pages_fast
1155 		 * with the FOLL_LONGTERM flag does a large temporary
1156 		 * allocation of contiguous memory.
1157 		 */
1158 		if (region->flags.range_pinned)
1159 			ret = pin_user_pages_fast(userspace_addr,
1160 						  nr_pages,
1161 						  FOLL_WRITE | FOLL_LONGTERM,
1162 						  pages);
1163 		else
1164 			ret = -EOPNOTSUPP;
1165 
1166 		if (ret < 0)
1167 			goto release_pages;
1168 	}
1169 
1170 	if (PageHuge(region->pages[page_offset]))
1171 		region->flags.large_pages = true;
1172 
1173 	return 0;
1174 
1175 release_pages:
1176 	mshv_region_evict_pages(region, page_offset, done_count);
1177 	return ret;
1178 }
1179 
1180 static int
1181 mshv_region_populate(struct mshv_mem_region *region)
1182 {
1183 	return mshv_region_populate_pages(region, 0, region->nr_pages);
1184 }
1185 
1186 static struct mshv_mem_region *
1187 mshv_partition_region_by_gfn(struct mshv_partition *partition, u64 gfn)
1188 {
1189 	struct mshv_mem_region *region;
1190 
1191 	hlist_for_each_entry(region, &partition->pt_mem_regions, hnode) {
1192 		if (gfn >= region->start_gfn &&
1193 		    gfn < region->start_gfn + region->nr_pages)
1194 			return region;
1195 	}
1196 
1197 	return NULL;
1198 }
1199 
1200 static struct mshv_mem_region *
1201 mshv_partition_region_by_uaddr(struct mshv_partition *partition, u64 uaddr)
1202 {
1203 	struct mshv_mem_region *region;
1204 
1205 	hlist_for_each_entry(region, &partition->pt_mem_regions, hnode) {
1206 		if (uaddr >= region->start_uaddr &&
1207 		    uaddr < region->start_uaddr +
1208 			    (region->nr_pages << HV_HYP_PAGE_SHIFT))
1209 			return region;
1210 	}
1211 
1212 	return NULL;
1213 }
1214 
1215 /*
1216  * NB: caller checks and makes sure mem->size is page aligned
1217  * Returns: 0 with regionpp updated on success, or -errno
1218  */
1219 static int mshv_partition_create_region(struct mshv_partition *partition,
1220 					struct mshv_user_mem_region *mem,
1221 					struct mshv_mem_region **regionpp,
1222 					bool is_mmio)
1223 {
1224 	struct mshv_mem_region *region;
1225 	u64 nr_pages = HVPFN_DOWN(mem->size);
1226 
1227 	/* Reject overlapping regions */
1228 	if (mshv_partition_region_by_gfn(partition, mem->guest_pfn) ||
1229 	    mshv_partition_region_by_gfn(partition, mem->guest_pfn + nr_pages - 1) ||
1230 	    mshv_partition_region_by_uaddr(partition, mem->userspace_addr) ||
1231 	    mshv_partition_region_by_uaddr(partition, mem->userspace_addr + mem->size - 1))
1232 		return -EEXIST;
1233 
1234 	region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
1235 	if (!region)
1236 		return -ENOMEM;
1237 
1238 	region->nr_pages = nr_pages;
1239 	region->start_gfn = mem->guest_pfn;
1240 	region->start_uaddr = mem->userspace_addr;
1241 	region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
1242 	if (mem->flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
1243 		region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
1244 	if (mem->flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
1245 		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;
1246 
1247 	/* Note: large_pages flag populated when we pin the pages */
1248 	if (!is_mmio)
1249 		region->flags.range_pinned = true;
1250 
1251 	region->partition = partition;
1252 
1253 	*regionpp = region;
1254 
1255 	return 0;
1256 }
1257 
1258 /*
1259  * Map guest RAM. For SNP, make sure to release host access to it first.
1260  * Side effects: in case of failure, pages are unpinned when feasible.
1261  */
1262 static int
1263 mshv_partition_mem_region_map(struct mshv_mem_region *region)
1264 {
1265 	struct mshv_partition *partition = region->partition;
1266 	int ret;
1267 
1268 	ret = mshv_region_populate(region);
1269 	if (ret) {
1270 		pt_err(partition, "Failed to populate memory region: %d\n",
1271 		       ret);
1272 		goto err_out;
1273 	}
1274 
1275 	/*
1276 	 * For an SNP partition, host access to every memory region that is
1277 	 * going to be mapped into the partition must be released first. This
1278 	 * is ensured by an additional hypercall that updates the SLAT to
1279 	 * revoke host access to the guest memory region before it is mapped
1280 	 * into the guest.
1281 	 */
1282 	if (mshv_partition_encrypted(partition)) {
1283 		ret = mshv_partition_region_unshare(region);
1284 		if (ret) {
1285 			pt_err(partition,
1286 			       "Failed to unshare memory region (guest_pfn: %llu): %d\n",
1287 			       region->start_gfn, ret);
1288 			goto evict_region;
1289 		}
1290 	}
1291 
1292 	ret = mshv_region_map(region);
1293 	if (ret && mshv_partition_encrypted(partition)) {
1294 		int shrc;
1295 
1296 		shrc = mshv_partition_region_share(region);
1297 		if (!shrc)
1298 			goto evict_region;
1299 
1300 		pt_err(partition,
1301 		       "Failed to share memory region (guest_pfn: %llu): %d\n",
1302 		       region->start_gfn, shrc);
1303 		/*
1304 		 * Don't unpin if marking the pages shared failed: they are no
1305 		 * longer mapped in the host (i.e. the root partition).
1306 		 */
1307 		goto err_out;
1308 	}
1309 
1310 	return 0;
1311 
1312 evict_region:
1313 	mshv_region_evict(region);
1314 err_out:
1315 	return ret;
1316 }
1317 
1318 /*
1319  * This maps two things: guest RAM and, for PCI passthrough, MMIO space.
1320  *
1321  * mmio:
1322  *  - vfio overloads vm_pgoff to store the mmio start pfn/spa.
1323  *  - Two things need to happen for mapping mmio range:
1324  *	1. mapped in the uaddr so VMM can access it.
1325  *	2. mapped in the hwpt (gfn <-> mmio phys addr) so guest can access it.
1326  *
1327  *   This function takes care of the second. The first one is managed by vfio,
1328  *   and hence is taken care of via vfio_pci_mmap_fault().
1329  */
1330 static long
1331 mshv_map_user_memory(struct mshv_partition *partition,
1332 		     struct mshv_user_mem_region mem)
1333 {
1334 	struct mshv_mem_region *region;
1335 	struct vm_area_struct *vma;
1336 	bool is_mmio;
1337 	ulong mmio_pfn;
1338 	long ret;
1339 
1340 	if (mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP) ||
1341 	    !access_ok((const void *)mem.userspace_addr, mem.size))
1342 		return -EINVAL;
1343 
1344 	mmap_read_lock(current->mm);
1345 	vma = vma_lookup(current->mm, mem.userspace_addr);
1346 	is_mmio = vma ? !!(vma->vm_flags & (VM_IO | VM_PFNMAP)) : 0;
1347 	mmio_pfn = is_mmio ? vma->vm_pgoff : 0;
1348 	mmap_read_unlock(current->mm);
1349 
1350 	if (!vma)
1351 		return -EINVAL;
1352 
1353 	ret = mshv_partition_create_region(partition, &mem, &region,
1354 					   is_mmio);
1355 	if (ret)
1356 		return ret;
1357 
1358 	if (is_mmio)
1359 		ret = hv_call_map_mmio_pages(partition->pt_id, mem.guest_pfn,
1360 					     mmio_pfn, HVPFN_DOWN(mem.size));
1361 	else
1362 		ret = mshv_partition_mem_region_map(region);
1363 
1364 	if (ret)
1365 		goto errout;
1366 
1367 	/* Install the new region */
1368 	hlist_add_head(&region->hnode, &partition->pt_mem_regions);
1369 
1370 	return 0;
1371 
1372 errout:
1373 	vfree(region);
1374 	return ret;
1375 }
1376 
1377 /* Called for unmapping both the guest ram and the mmio space */
1378 static long
1379 mshv_unmap_user_memory(struct mshv_partition *partition,
1380 		       struct mshv_user_mem_region mem)
1381 {
1382 	struct mshv_mem_region *region;
1383 	u32 unmap_flags = 0;
1384 
1385 	if (!(mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP)))
1386 		return -EINVAL;
1387 
1388 	region = mshv_partition_region_by_gfn(partition, mem.guest_pfn);
1389 	if (!region)
1390 		return -EINVAL;
1391 
1392 	/* Paranoia check */
1393 	if (region->start_uaddr != mem.userspace_addr ||
1394 	    region->start_gfn != mem.guest_pfn ||
1395 	    region->nr_pages != HVPFN_DOWN(mem.size))
1396 		return -EINVAL;
1397 
1398 	hlist_del(&region->hnode);
1399 
1400 	if (region->flags.large_pages)
1401 		unmap_flags |= HV_UNMAP_GPA_LARGE_PAGE;
1402 
1403 	/* ignore unmap failures and continue as process may be exiting */
1404 	hv_call_unmap_gpa_pages(partition->pt_id, region->start_gfn,
1405 				region->nr_pages, unmap_flags);
1406 
1407 	mshv_region_evict(region);
1408 
1409 	vfree(region);
1410 	return 0;
1411 }
1412 
1413 static long
1414 mshv_partition_ioctl_set_memory(struct mshv_partition *partition,
1415 				struct mshv_user_mem_region __user *user_mem)
1416 {
1417 	struct mshv_user_mem_region mem;
1418 
1419 	if (copy_from_user(&mem, user_mem, sizeof(mem)))
1420 		return -EFAULT;
1421 
1422 	if (!mem.size ||
1423 	    !PAGE_ALIGNED(mem.size) ||
1424 	    !PAGE_ALIGNED(mem.userspace_addr) ||
1425 	    (mem.flags & ~MSHV_SET_MEM_FLAGS_MASK) ||
1426 	    mshv_field_nonzero(mem, rsvd))
1427 		return -EINVAL;
1428 
1429 	if (mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP))
1430 		return mshv_unmap_user_memory(partition, mem);
1431 
1432 	return mshv_map_user_memory(partition, mem);
1433 }
1434 
1435 static long
1436 mshv_partition_ioctl_ioeventfd(struct mshv_partition *partition,
1437 			       void __user *user_args)
1438 {
1439 	struct mshv_user_ioeventfd args;
1440 
1441 	if (copy_from_user(&args, user_args, sizeof(args)))
1442 		return -EFAULT;
1443 
1444 	return mshv_set_unset_ioeventfd(partition, &args);
1445 }
1446 
1447 static long
1448 mshv_partition_ioctl_irqfd(struct mshv_partition *partition,
1449 			   void __user *user_args)
1450 {
1451 	struct mshv_user_irqfd args;
1452 
1453 	if (copy_from_user(&args, user_args, sizeof(args)))
1454 		return -EFAULT;
1455 
1456 	return mshv_set_unset_irqfd(partition, &args);
1457 }
1458 
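/*
 * MSHV_GET_GPAP_ACCESS_BITMAP: query (and optionally set or clear) the
 * accessed/dirty state of a range of guest physical pages and return it to
 * userspace as a bitmap with one bit per page. The per-page states returned
 * by the hypervisor are converted into the bitmap in place to avoid a second
 * allocation.
 */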
1459 static long
1460 mshv_partition_ioctl_get_gpap_access_bitmap(struct mshv_partition *partition,
1461 					    void __user *user_args)
1462 {
1463 	struct mshv_gpap_access_bitmap args;
1464 	union hv_gpa_page_access_state *states;
1465 	long ret, i;
1466 	union hv_gpa_page_access_state_flags hv_flags = {};
1467 	u8 hv_type_mask;
1468 	ulong bitmap_buf_sz, states_buf_sz;
1469 	int written = 0;
1470 
1471 	if (copy_from_user(&args, user_args, sizeof(args)))
1472 		return -EFAULT;
1473 
1474 	if (args.access_type >= MSHV_GPAP_ACCESS_TYPE_COUNT ||
1475 	    args.access_op >= MSHV_GPAP_ACCESS_OP_COUNT ||
1476 	    mshv_field_nonzero(args, rsvd) || !args.page_count ||
1477 	    !args.bitmap_ptr)
1478 		return -EINVAL;
1479 
1480 	if (check_mul_overflow(args.page_count, sizeof(*states), &states_buf_sz))
1481 		return -E2BIG;
1482 
1483 	/* Num bytes needed to store bitmap; one bit per page rounded up */
1484 	bitmap_buf_sz = DIV_ROUND_UP(args.page_count, 8);
1485 
1486 	/* Sanity check */
1487 	if (bitmap_buf_sz > states_buf_sz)
1488 		return -EBADFD;
1489 
1490 	switch (args.access_type) {
1491 	case MSHV_GPAP_ACCESS_TYPE_ACCESSED:
1492 		hv_type_mask = 1;
1493 		if (args.access_op == MSHV_GPAP_ACCESS_OP_CLEAR) {
1494 			hv_flags.clear_accessed = 1;
1495 			/* not accessed implies not dirty */
1496 			hv_flags.clear_dirty = 1;
1497 		} else { /* MSHV_GPAP_ACCESS_OP_SET */
1498 			hv_flags.set_accessed = 1;
1499 		}
1500 		break;
1501 	case MSHV_GPAP_ACCESS_TYPE_DIRTY:
1502 		hv_type_mask = 2;
1503 		if (args.access_op == MSHV_GPAP_ACCESS_OP_CLEAR) {
1504 			hv_flags.clear_dirty = 1;
1505 		} else { /* MSHV_GPAP_ACCESS_OP_SET */
1506 			hv_flags.set_dirty = 1;
1507 			/* dirty implies accessed */
1508 			hv_flags.set_accessed = 1;
1509 		}
1510 		break;
1511 	}
1512 
1513 	states = vzalloc(states_buf_sz);
1514 	if (!states)
1515 		return -ENOMEM;
1516 
1517 	ret = hv_call_get_gpa_access_states(partition->pt_id, args.page_count,
1518 					    args.gpap_base, hv_flags, &written,
1519 					    states);
1520 	if (ret)
1521 		goto free_return;
1522 
1523 	/*
1524 	 * Overwrite states buffer with bitmap - the bits in hv_type_mask
1525 	 * correspond to bitfields in hv_gpa_page_access_state
1526 	 */
1527 	for (i = 0; i < written; ++i)
1528 		__assign_bit(i, (ulong *)states,
1529 			     states[i].as_uint8 & hv_type_mask);
1530 
1531 	/* zero the unused bits in the last byte(s) of the returned bitmap */
1532 	for (i = written; i < bitmap_buf_sz * 8; ++i)
1533 		__clear_bit(i, (ulong *)states);
1534 
1535 	if (copy_to_user((void __user *)args.bitmap_ptr, states, bitmap_buf_sz))
1536 		ret = -EFAULT;
1537 
1538 free_return:
1539 	vfree(states);
1540 	return ret;
1541 }
1542 
1543 static long
1544 mshv_partition_ioctl_set_msi_routing(struct mshv_partition *partition,
1545 				     void __user *user_args)
1546 {
1547 	struct mshv_user_irq_entry *entries = NULL;
1548 	struct mshv_user_irq_table args;
1549 	long ret;
1550 
1551 	if (copy_from_user(&args, user_args, sizeof(args)))
1552 		return -EFAULT;
1553 
1554 	if (args.nr > MSHV_MAX_GUEST_IRQS ||
1555 	    mshv_field_nonzero(args, rsvd))
1556 		return -EINVAL;
1557 
1558 	if (args.nr) {
1559 		struct mshv_user_irq_table __user *urouting = user_args;
1560 
1561 		entries = vmemdup_user(urouting->entries,
1562 				       array_size(sizeof(*entries),
1563 						  args.nr));
1564 		if (IS_ERR(entries))
1565 			return PTR_ERR(entries);
1566 	}
1567 	ret = mshv_update_routing_table(partition, entries, args.nr);
1568 	kvfree(entries);
1569 
1570 	return ret;
1571 }
1572 
1573 static long
1574 mshv_partition_ioctl_initialize(struct mshv_partition *partition)
1575 {
1576 	long ret;
1577 
1578 	if (partition->pt_initialized)
1579 		return 0;
1580 
1581 	ret = hv_call_initialize_partition(partition->pt_id);
1582 	if (ret)
1583 		goto withdraw_mem;
1584 
1585 	partition->pt_initialized = true;
1586 
1587 	return 0;
1588 
1589 withdraw_mem:
1590 	hv_call_withdraw_memory(U64_MAX, NUMA_NO_NODE, partition->pt_id);
1591 
1592 	return ret;
1593 }
1594 
1595 static long
1596 mshv_partition_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg)
1597 {
1598 	struct mshv_partition *partition = filp->private_data;
1599 	long ret;
1600 	void __user *uarg = (void __user *)arg;
1601 
1602 	if (mutex_lock_killable(&partition->pt_mutex))
1603 		return -EINTR;
1604 
1605 	switch (ioctl) {
1606 	case MSHV_INITIALIZE_PARTITION:
1607 		ret = mshv_partition_ioctl_initialize(partition);
1608 		break;
1609 	case MSHV_SET_GUEST_MEMORY:
1610 		ret = mshv_partition_ioctl_set_memory(partition, uarg);
1611 		break;
1612 	case MSHV_CREATE_VP:
1613 		ret = mshv_partition_ioctl_create_vp(partition, uarg);
1614 		break;
1615 	case MSHV_IRQFD:
1616 		ret = mshv_partition_ioctl_irqfd(partition, uarg);
1617 		break;
1618 	case MSHV_IOEVENTFD:
1619 		ret = mshv_partition_ioctl_ioeventfd(partition, uarg);
1620 		break;
1621 	case MSHV_SET_MSI_ROUTING:
1622 		ret = mshv_partition_ioctl_set_msi_routing(partition, uarg);
1623 		break;
1624 	case MSHV_GET_GPAP_ACCESS_BITMAP:
1625 		ret = mshv_partition_ioctl_get_gpap_access_bitmap(partition,
1626 								  uarg);
1627 		break;
1628 	case MSHV_ROOT_HVCALL:
1629 		ret = mshv_ioctl_passthru_hvcall(partition, true, uarg);
1630 		break;
1631 	default:
1632 		ret = -ENOTTY;
1633 	}
1634 
1635 	mutex_unlock(&partition->pt_mutex);
1636 	return ret;
1637 }
1638 
1639 static int
1640 disable_vp_dispatch(struct mshv_vp *vp)
1641 {
1642 	int ret;
1643 	struct hv_register_assoc dispatch_suspend = {
1644 		.name = HV_REGISTER_DISPATCH_SUSPEND,
1645 		.value.dispatch_suspend.suspended = 1,
1646 	};
1647 
1648 	ret = mshv_set_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
1649 				    1, &dispatch_suspend);
1650 	if (ret)
1651 		vp_err(vp, "failed to suspend\n");
1652 
1653 	return ret;
1654 }
1655 
1656 static int
1657 get_vp_signaled_count(struct mshv_vp *vp, u64 *count)
1658 {
1659 	int ret;
1660 	struct hv_register_assoc root_signal_count = {
1661 		.name = HV_REGISTER_VP_ROOT_SIGNAL_COUNT,
1662 	};
1663 
1664 	ret = mshv_get_vp_registers(vp->vp_index, vp->vp_partition->pt_id,
1665 				    1, &root_signal_count);
1666 
1667 	if (ret) {
1668 		vp_err(vp, "Failed to get root signal count");
1669 		*count = 0;
1670 		return ret;
1671 	}
1672 
1673 	*count = root_signal_count.value.reg64;
1674 
1675 	return ret;
1676 }
1677 
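/*
 * Wait until the number of signals the hypervisor has generated for this VP
 * matches the count the driver has observed (vp_signaled_count), so that no
 * suspend notification is left in flight when the partition is torn down.
 */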
1678 static void
1679 drain_vp_signals(struct mshv_vp *vp)
1680 {
1681 	u64 hv_signal_count;
1682 	u64 vp_signal_count;
1683 
1684 	get_vp_signaled_count(vp, &hv_signal_count);
1685 
1686 	vp_signal_count = atomic64_read(&vp->run.vp_signaled_count);
1687 
1688 	/*
1689 	 * There should be at most 1 outstanding notification, but be extra
1690 	 * careful anyway.
1691 	 */
1692 	while (hv_signal_count != vp_signal_count) {
1693 		WARN_ON(hv_signal_count - vp_signal_count != 1);
1694 
1695 		if (wait_event_interruptible(vp->run.vp_suspend_queue,
1696 					     vp->run.kicked_by_hv == 1))
1697 			break;
1698 		vp->run.kicked_by_hv = 0;
1699 		vp_signal_count = atomic64_read(&vp->run.vp_signaled_count);
1700 	}
1701 }
1702 
1703 static void drain_all_vps(const struct mshv_partition *partition)
1704 {
1705 	int i;
1706 	struct mshv_vp *vp;
1707 
1708 	/*
1709 	 * VPs are reachable from ISR. It is safe to not take the partition
1710 	 * lock because nobody else can enter this function and drop the
1711 	 * partition from the list.
1712 	 */
1713 	for (i = 0; i < MSHV_MAX_VPS; i++) {
1714 		vp = partition->pt_vp_array[i];
1715 		if (!vp)
1716 			continue;
1717 		/*
1718 		 * Disable dispatching of the VP in the hypervisor. After this
1719 		 * the hypervisor guarantees it won't generate any signals for
1720 		 * the VP and the hypervisor's VP signal count won't change.
1721 		 */
1722 		disable_vp_dispatch(vp);
1723 		drain_vp_signals(vp);
1724 	}
1725 }
1726 
1727 static void
1728 remove_partition(struct mshv_partition *partition)
1729 {
1730 	spin_lock(&mshv_root.pt_ht_lock);
1731 	hlist_del_rcu(&partition->pt_hnode);
1732 	spin_unlock(&mshv_root.pt_ht_lock);
1733 
1734 	synchronize_rcu();
1735 }
1736 
1737 /*
1738  * Tear down a partition and remove it from the list.
1739  * Partition's refcount must be 0
1740  */
1741 static void destroy_partition(struct mshv_partition *partition)
1742 {
1743 	struct mshv_vp *vp;
1744 	struct mshv_mem_region *region;
1745 	int i, ret;
1746 	struct hlist_node *n;
1747 
1748 	if (refcount_read(&partition->pt_ref_count)) {
1749 		pt_err(partition,
1750 		       "Attempt to destroy partition but refcount > 0\n");
1751 		return;
1752 	}
1753 
1754 	if (partition->pt_initialized) {
1755 		/*
1756 		 * We only need to drain signals for root scheduler. This should be
1757 		 * done before removing the partition from the partition list.
1758 		 */
1759 		if (hv_scheduler_type == HV_SCHEDULER_TYPE_ROOT)
1760 			drain_all_vps(partition);
1761 
1762 		/* Remove vps */
1763 		for (i = 0; i < MSHV_MAX_VPS; ++i) {
1764 			vp = partition->pt_vp_array[i];
1765 			if (!vp)
1766 				continue;
1767 
1768 			if (hv_parent_partition())
1769 				mshv_vp_stats_unmap(partition->pt_id, vp->vp_index);
1770 
1771 			if (vp->vp_register_page) {
1772 				(void)hv_call_unmap_vp_state_page(partition->pt_id,
1773 								  vp->vp_index,
1774 								  HV_VP_STATE_PAGE_REGISTERS,
1775 								  input_vtl_zero);
1776 				vp->vp_register_page = NULL;
1777 			}
1778 
1779 			(void)hv_call_unmap_vp_state_page(partition->pt_id,
1780 							  vp->vp_index,
1781 							  HV_VP_STATE_PAGE_INTERCEPT_MESSAGE,
1782 							  input_vtl_zero);
1783 			vp->vp_intercept_msg_page = NULL;
1784 
1785 			if (vp->vp_ghcb_page) {
1786 				(void)hv_call_unmap_vp_state_page(partition->pt_id,
1787 								  vp->vp_index,
1788 								  HV_VP_STATE_PAGE_GHCB,
1789 								  input_vtl_normal);
1790 				vp->vp_ghcb_page = NULL;
1791 			}
1792 
1793 			kfree(vp);
1794 
1795 			partition->pt_vp_array[i] = NULL;
1796 		}
1797 
1798 		/* Deallocates and unmaps everything including vcpus, GPA mappings etc */
1799 		hv_call_finalize_partition(partition->pt_id);
1800 
1801 		partition->pt_initialized = false;
1802 	}
1803 
1804 	remove_partition(partition);
1805 
1806 	/* Remove regions, regain access to the memory and unpin the pages */
1807 	hlist_for_each_entry_safe(region, n, &partition->pt_mem_regions,
1808 				  hnode) {
1809 		hlist_del(&region->hnode);
1810 
1811 		if (mshv_partition_encrypted(partition)) {
1812 			ret = mshv_partition_region_share(region);
1813 			if (ret) {
1814 				pt_err(partition,
1815 				       "Failed to regain access to memory; unpinning user pages will fail and crash the host, error: %d\n",
1816 				      ret);
1817 				return;
1818 			}
1819 		}
1820 
1821 		mshv_region_evict(region);
1822 
1823 		vfree(region);
1824 	}
1825 
1826 	/* Withdraw and free all pages we deposited */
1827 	hv_call_withdraw_memory(U64_MAX, NUMA_NO_NODE, partition->pt_id);
1828 	hv_call_delete_partition(partition->pt_id);
1829 
1830 	mshv_free_routing_table(partition);
1831 	kfree(partition);
1832 }
1833 
1834 struct
1835 mshv_partition *mshv_partition_get(struct mshv_partition *partition)
1836 {
1837 	if (refcount_inc_not_zero(&partition->pt_ref_count))
1838 		return partition;
1839 	return NULL;
1840 }
1841 
1842 struct
1843 mshv_partition *mshv_partition_find(u64 partition_id)
1844 	__must_hold(RCU)
1845 {
1846 	struct mshv_partition *p;
1847 
1848 	hash_for_each_possible_rcu(mshv_root.pt_htable, p, pt_hnode,
1849 				   partition_id)
1850 		if (p->pt_id == partition_id)
1851 			return p;
1852 
1853 	return NULL;
1854 }
1855 
1856 void
1857 mshv_partition_put(struct mshv_partition *partition)
1858 {
1859 	if (refcount_dec_and_test(&partition->pt_ref_count))
1860 		destroy_partition(partition);
1861 }
1862 
1863 static int
1864 mshv_partition_release(struct inode *inode, struct file *filp)
1865 {
1866 	struct mshv_partition *partition = filp->private_data;
1867 
1868 	mshv_eventfd_release(partition);
1869 
1870 	cleanup_srcu_struct(&partition->pt_irq_srcu);
1871 
1872 	mshv_partition_put(partition);
1873 
1874 	return 0;
1875 }
1876 
1877 static int
1878 add_partition(struct mshv_partition *partition)
1879 {
1880 	spin_lock(&mshv_root.pt_ht_lock);
1881 
1882 	hash_add_rcu(mshv_root.pt_htable, &partition->pt_hnode,
1883 		     partition->pt_id);
1884 
1885 	spin_unlock(&mshv_root.pt_ht_lock);
1886 
1887 	return 0;
1888 }
1889 
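/*
 * MSHV_CREATE_PARTITION: create an EXO partition in the hypervisor with the
 * requested flags and isolation type, add it to the root's partition hash
 * table, and return a new partition file descriptor. The initial reference
 * is dropped when that fd is released.
 */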
1890 static long
1891 mshv_ioctl_create_partition(void __user *user_arg, struct device *module_dev)
1892 {
1893 	struct mshv_create_partition args;
1894 	u64 creation_flags;
1895 	struct hv_partition_creation_properties creation_properties = {};
1896 	union hv_partition_isolation_properties isolation_properties = {};
1897 	struct mshv_partition *partition;
1898 	struct file *file;
1899 	int fd;
1900 	long ret;
1901 
1902 	if (copy_from_user(&args, user_arg, sizeof(args)))
1903 		return -EFAULT;
1904 
1905 	if ((args.pt_flags & ~MSHV_PT_FLAGS_MASK) ||
1906 	    args.pt_isolation >= MSHV_PT_ISOLATION_COUNT)
1907 		return -EINVAL;
1908 
1909 	/* Only support EXO partitions */
1910 	creation_flags = HV_PARTITION_CREATION_FLAG_EXO_PARTITION |
1911 			 HV_PARTITION_CREATION_FLAG_INTERCEPT_MESSAGE_PAGE_ENABLED;
1912 
1913 	if (args.pt_flags & BIT(MSHV_PT_BIT_LAPIC))
1914 		creation_flags |= HV_PARTITION_CREATION_FLAG_LAPIC_ENABLED;
1915 	if (args.pt_flags & BIT(MSHV_PT_BIT_X2APIC))
1916 		creation_flags |= HV_PARTITION_CREATION_FLAG_X2APIC_CAPABLE;
1917 	if (args.pt_flags & BIT(MSHV_PT_BIT_GPA_SUPER_PAGES))
1918 		creation_flags |= HV_PARTITION_CREATION_FLAG_GPA_SUPER_PAGES_ENABLED;
1919 
1920 	switch (args.pt_isolation) {
1921 	case MSHV_PT_ISOLATION_NONE:
1922 		isolation_properties.isolation_type =
1923 			HV_PARTITION_ISOLATION_TYPE_NONE;
1924 		break;
1925 	}
1926 
1927 	partition = kzalloc(sizeof(*partition), GFP_KERNEL);
1928 	if (!partition)
1929 		return -ENOMEM;
1930 
1931 	partition->pt_module_dev = module_dev;
1932 	partition->isolation_type = isolation_properties.isolation_type;
1933 
1934 	refcount_set(&partition->pt_ref_count, 1);
1935 
1936 	mutex_init(&partition->pt_mutex);
1937 
1938 	mutex_init(&partition->pt_irq_lock);
1939 
1940 	init_completion(&partition->async_hypercall);
1941 
1942 	INIT_HLIST_HEAD(&partition->irq_ack_notifier_list);
1943 
1944 	INIT_HLIST_HEAD(&partition->pt_devices);
1945 
1946 	INIT_HLIST_HEAD(&partition->pt_mem_regions);
1947 
1948 	mshv_eventfd_init(partition);
1949 
1950 	ret = init_srcu_struct(&partition->pt_irq_srcu);
1951 	if (ret)
1952 		goto free_partition;
1953 
1954 	ret = hv_call_create_partition(creation_flags,
1955 				       creation_properties,
1956 				       isolation_properties,
1957 				       &partition->pt_id);
1958 	if (ret)
1959 		goto cleanup_irq_srcu;
1960 
1961 	ret = add_partition(partition);
1962 	if (ret)
1963 		goto delete_partition;
1964 
1965 	ret = mshv_init_async_handler(partition);
1966 	if (ret)
1967 		goto remove_partition;
1968 
1969 	fd = get_unused_fd_flags(O_CLOEXEC);
1970 	if (fd < 0) {
1971 		ret = fd;
1972 		goto remove_partition;
1973 	}
1974 
1975 	file = anon_inode_getfile("mshv_partition", &mshv_partition_fops,
1976 				  partition, O_RDWR);
1977 	if (IS_ERR(file)) {
1978 		ret = PTR_ERR(file);
1979 		goto put_fd;
1980 	}
1981 
1982 	fd_install(fd, file);
1983 
1984 	return fd;
1985 
1986 put_fd:
1987 	put_unused_fd(fd);
1988 remove_partition:
1989 	remove_partition(partition);
1990 delete_partition:
1991 	hv_call_delete_partition(partition->pt_id);
1992 cleanup_irq_srcu:
1993 	cleanup_srcu_struct(&partition->pt_irq_srcu);
1994 free_partition:
1995 	kfree(partition);
1996 
1997 	return ret;
1998 }
1999 
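/*
 * ioctl dispatcher for the mshv device node itself; MSHV_CREATE_PARTITION is
 * the only ioctl handled here. Minimal userspace sketch (illustrative only;
 * the struct layout is assumed from the mshv uapi header):
 *
 *	int mshv = open("/dev/mshv", O_RDWR);
 *	struct mshv_create_partition args = {
 *		.pt_flags = (1 << MSHV_PT_BIT_LAPIC) | (1 << MSHV_PT_BIT_X2APIC),
 *		.pt_isolation = MSHV_PT_ISOLATION_NONE,
 *	};
 *	int ptfd = ioctl(mshv, MSHV_CREATE_PARTITION, &args);
 *
 * On success ptfd is a partition fd accepting mshv_partition_ioctl() calls.
 */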
2000 static long mshv_dev_ioctl(struct file *filp, unsigned int ioctl,
2001 			   unsigned long arg)
2002 {
2003 	struct miscdevice *misc = filp->private_data;
2004 
2005 	switch (ioctl) {
2006 	case MSHV_CREATE_PARTITION:
2007 		return mshv_ioctl_create_partition((void __user *)arg,
2008 						misc->this_device);
2009 	}
2010 
2011 	return -ENOTTY;
2012 }
2013 
2014 static int
2015 mshv_dev_open(struct inode *inode, struct file *filp)
2016 {
2017 	return 0;
2018 }
2019 
2020 static int
2021 mshv_dev_release(struct inode *inode, struct file *filp)
2022 {
2023 	return 0;
2024 }
2025 
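/* Dynamic CPU hotplug state ids returned by cpuhp_setup_state() below */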
2026 static int mshv_cpuhp_online;
2027 static int mshv_root_sched_online;
2028 
2029 static const char *scheduler_type_to_string(enum hv_scheduler_type type)
2030 {
2031 	switch (type) {
2032 	case HV_SCHEDULER_TYPE_LP:
2033 		return "classic scheduler without SMT";
2034 	case HV_SCHEDULER_TYPE_LP_SMT:
2035 		return "classic scheduler with SMT";
2036 	case HV_SCHEDULER_TYPE_CORE_SMT:
2037 		return "core scheduler";
2038 	case HV_SCHEDULER_TYPE_ROOT:
2039 		return "root scheduler";
2040 	default:
2041 		return "unknown scheduler";
2042 	}
2043 }
2044 
2045 /* TODO move this to hv_common.c when needed outside */
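/*
 * Query the hypervisor for the scheduler type via
 * HVCALL_GET_SYSTEM_PROPERTY. The per-cpu hypercall input/output pages are
 * used, so interrupts stay disabled across the call.
 */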
2046 static int __init hv_retrieve_scheduler_type(enum hv_scheduler_type *out)
2047 {
2048 	struct hv_input_get_system_property *input;
2049 	struct hv_output_get_system_property *output;
2050 	unsigned long flags;
2051 	u64 status;
2052 
2053 	local_irq_save(flags);
2054 	input = *this_cpu_ptr(hyperv_pcpu_input_arg);
2055 	output = *this_cpu_ptr(hyperv_pcpu_output_arg);
2056 
2057 	memset(input, 0, sizeof(*input));
2058 	memset(output, 0, sizeof(*output));
2059 	input->property_id = HV_SYSTEM_PROPERTY_SCHEDULER_TYPE;
2060 
2061 	status = hv_do_hypercall(HVCALL_GET_SYSTEM_PROPERTY, input, output);
2062 	if (!hv_result_success(status)) {
2063 		local_irq_restore(flags);
2064 		pr_err("%s: %s\n", __func__, hv_result_to_string(status));
2065 		return hv_result_to_errno(status);
2066 	}
2067 
2068 	*out = output->scheduler_type;
2069 	local_irq_restore(flags);
2070 
2071 	return 0;
2072 }
2073 
2074 /* Retrieve and stash the supported scheduler type */
2075 static int __init mshv_retrieve_scheduler_type(struct device *dev)
2076 {
2077 	int ret;
2078 
2079 	ret = hv_retrieve_scheduler_type(&hv_scheduler_type);
2080 	if (ret)
2081 		return ret;
2082 
2083 	dev_info(dev, "Hypervisor using %s\n",
2084 		 scheduler_type_to_string(hv_scheduler_type));
2085 
2086 	switch (hv_scheduler_type) {
2087 	case HV_SCHEDULER_TYPE_CORE_SMT:
2088 	case HV_SCHEDULER_TYPE_LP_SMT:
2089 	case HV_SCHEDULER_TYPE_ROOT:
2090 	case HV_SCHEDULER_TYPE_LP:
2091 		/* Supported scheduler, nothing to do */
2092 		break;
2093 	default:
2094 		dev_err(dev, "unsupported scheduler 0x%x, bailing.\n",
2095 			hv_scheduler_type);
2096 		return -EOPNOTSUPP;
2097 	}
2098 
2099 	return 0;
2100 }
2101 
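/*
 * CPU hotplug online callback: give this CPU a page-sized input buffer and a
 * page-sized output buffer (allocated as one two-page chunk) for use by the
 * root scheduler code; mshv_root_scheduler_cleanup() frees it on offline.
 */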
2102 static int mshv_root_scheduler_init(unsigned int cpu)
2103 {
2104 	void **inputarg, **outputarg, *p;
2105 
2106 	inputarg = (void **)this_cpu_ptr(root_scheduler_input);
2107 	outputarg = (void **)this_cpu_ptr(root_scheduler_output);
2108 
2109 	/* Allocate two consecutive pages. One for input, one for output. */
2110 	p = kmalloc(2 * HV_HYP_PAGE_SIZE, GFP_KERNEL);
2111 	if (!p)
2112 		return -ENOMEM;
2113 
2114 	*inputarg = p;
2115 	*outputarg = (char *)p + HV_HYP_PAGE_SIZE;
2116 
2117 	return 0;
2118 }
2119 
2120 static int mshv_root_scheduler_cleanup(unsigned int cpu)
2121 {
2122 	void *p, **inputarg, **outputarg;
2123 
2124 	inputarg = (void **)this_cpu_ptr(root_scheduler_input);
2125 	outputarg = (void **)this_cpu_ptr(root_scheduler_output);
2126 
2127 	p = *inputarg;
2128 
2129 	*inputarg = NULL;
2130 	*outputarg = NULL;
2131 
2132 	kfree(p);
2133 
2134 	return 0;
2135 }
2136 
2137 /* Must be called after retrieving the scheduler type */
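/*
 * If the hypervisor runs the root scheduler, allocate the per-cpu buffer
 * pointers and register a dynamic hotplug state so every online CPU gets
 * its buffers; a no-op for the other scheduler types.
 */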
2138 static int
2139 root_scheduler_init(struct device *dev)
2140 {
2141 	int ret;
2142 
2143 	if (hv_scheduler_type != HV_SCHEDULER_TYPE_ROOT)
2144 		return 0;
2145 
2146 	root_scheduler_input = alloc_percpu(void *);
2147 	root_scheduler_output = alloc_percpu(void *);
2148 
2149 	if (!root_scheduler_input || !root_scheduler_output) {
2150 		dev_err(dev, "Failed to allocate root scheduler buffers\n");
2151 		ret = -ENOMEM;
2152 		goto out;
2153 	}
2154 
2155 	ret = cpuhp_setup_state(CPUHP_AP_ONLINE_DYN, "mshv_root_sched",
2156 				mshv_root_scheduler_init,
2157 				mshv_root_scheduler_cleanup);
2158 
2159 	if (ret < 0) {
2160 		dev_err(dev, "Failed to set up root scheduler state: %i\n", ret);
2161 		goto out;
2162 	}
2163 
2164 	mshv_root_sched_online = ret;
2165 
2166 	return 0;
2167 
2168 out:
2169 	free_percpu(root_scheduler_input);
2170 	free_percpu(root_scheduler_output);
2171 	return ret;
2172 }
2173 
2174 static void
2175 root_scheduler_deinit(void)
2176 {
2177 	if (hv_scheduler_type != HV_SCHEDULER_TYPE_ROOT)
2178 		return;
2179 
2180 	cpuhp_remove_state(mshv_root_sched_online);
2181 	free_percpu(root_scheduler_input);
2182 	free_percpu(root_scheduler_output);
2183 }
2184 
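/*
 * On reboot, remove the "mshv_synic" hotplug state so the per-cpu SynIC
 * cleanup callbacks run on all online CPUs.
 */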
2185 static int mshv_reboot_notify(struct notifier_block *nb,
2186 			      unsigned long code, void *unused)
2187 {
2188 	cpuhp_remove_state(mshv_cpuhp_online);
2189 	return 0;
2190 }
2191 
2192 struct notifier_block mshv_reboot_nb = {
2193 	.notifier_call = mshv_reboot_notify,
2194 };
2195 
2196 static void mshv_root_partition_exit(void)
2197 {
2198 	unregister_reboot_notifier(&mshv_reboot_nb);
2199 	root_scheduler_deinit();
2200 }
2201 
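/*
 * Root-partition-specific init: check that the hypervisor scheduler type is
 * one we support, set up the root scheduler per-cpu buffers if needed, and
 * register the reboot notifier.
 */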
2202 static int __init mshv_root_partition_init(struct device *dev)
2203 {
2204 	int err;
2205 
2206 	if (mshv_retrieve_scheduler_type(dev))
2207 		return -ENODEV;
2208 
2209 	err = root_scheduler_init(dev);
2210 	if (err)
2211 		return err;
2212 
2213 	err = register_reboot_notifier(&mshv_reboot_nb);
2214 	if (err)
2215 		goto root_sched_deinit;
2216 
2217 	return 0;
2218 
2219 root_sched_deinit:
2220 	root_scheduler_deinit();
2221 	return err;
2222 }
2223 
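/*
 * Module init: only usable in the root partition (and not in a kdump
 * kernel). Registers the misc device, allocates per-cpu SynIC pages with a
 * matching hotplug state, performs root-partition init, brings up the irqfd
 * workqueue, initializes the partition hash table and finally installs
 * mshv_isr() via hv_setup_mshv_handler().
 */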
2224 static int __init mshv_parent_partition_init(void)
2225 {
2226 	int ret;
2227 	struct device *dev;
2228 	union hv_hypervisor_version_info version_info;
2229 
2230 	if (!hv_root_partition() || is_kdump_kernel())
2231 		return -ENODEV;
2232 
2233 	if (hv_get_hypervisor_version(&version_info))
2234 		return -ENODEV;
2235 
2236 	ret = misc_register(&mshv_dev);
2237 	if (ret)
2238 		return ret;
2239 
2240 	dev = mshv_dev.this_device;
2241 
2242 	if (version_info.build_number < MSHV_HV_MIN_VERSION ||
2243 	    version_info.build_number > MSHV_HV_MAX_VERSION) {
2244 		dev_err(dev, "Running on unvalidated Hyper-V version\n");
2245 		dev_err(dev, "Versions: current: %u  min: %u  max: %u\n",
2246 			version_info.build_number, MSHV_HV_MIN_VERSION,
2247 			MSHV_HV_MAX_VERSION);
2248 	}
2249 
2250 	mshv_root.synic_pages = alloc_percpu(struct hv_synic_pages);
2251 	if (!mshv_root.synic_pages) {
2252 		dev_err(dev, "Failed to allocate percpu synic page\n");
2253 		ret = -ENOMEM;
2254 		goto device_deregister;
2255 	}
2256 
2257 	ret = cpuhp_setup_state(CPUHP_AP_ONLINE_DYN, "mshv_synic",
2258 				mshv_synic_init,
2259 				mshv_synic_cleanup);
2260 	if (ret < 0) {
2261 		dev_err(dev, "Failed to set up cpu hotplug state: %i\n", ret);
2262 		goto free_synic_pages;
2263 	}
2264 
2265 	mshv_cpuhp_online = ret;
2266 
2267 	ret = mshv_root_partition_init(dev);
2268 	if (ret)
2269 		goto remove_cpu_state;
2270 
2271 	ret = mshv_irqfd_wq_init();
2272 	if (ret)
2273 		goto exit_partition;
2274 
2275 	spin_lock_init(&mshv_root.pt_ht_lock);
2276 	hash_init(mshv_root.pt_htable);
2277 
2278 	hv_setup_mshv_handler(mshv_isr);
2279 
2280 	return 0;
2281 
2282 exit_partition:
2283 	if (hv_root_partition())
2284 		mshv_root_partition_exit();
2285 remove_cpu_state:
2286 	cpuhp_remove_state(mshv_cpuhp_online);
2287 free_synic_pages:
2288 	free_percpu(mshv_root.synic_pages);
2289 device_deregister:
2290 	misc_deregister(&mshv_dev);
2291 	return ret;
2292 }
2293 
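/* Module exit: tear down everything set up in mshv_parent_partition_init() */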
2294 static void __exit mshv_parent_partition_exit(void)
2295 {
2296 	hv_setup_mshv_handler(NULL);
2297 	mshv_port_table_fini();
2298 	misc_deregister(&mshv_dev);
2299 	mshv_irqfd_wq_cleanup();
2300 	if (hv_root_partition())
2301 		mshv_root_partition_exit();
2302 	cpuhp_remove_state(mshv_cpuhp_online);
2303 	free_percpu(mshv_root.synic_pages);
2304 }
2305 
2306 module_init(mshv_parent_partition_init);
2307 module_exit(mshv_parent_partition_exit);
2308