xref: /linux/drivers/vfio/pci/virtio/migrate.c (revision 2eff01ee2881becc9daaa0d53477ec202136b1f4)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved
4  */
5 
6 #include <linux/device.h>
7 #include <linux/module.h>
8 #include <linux/mutex.h>
9 #include <linux/pci.h>
10 #include <linux/pm_runtime.h>
11 #include <linux/types.h>
12 #include <linux/uaccess.h>
13 #include <linux/vfio.h>
14 #include <linux/vfio_pci_core.h>
15 #include <linux/virtio_pci.h>
16 #include <linux/virtio_net.h>
17 #include <linux/virtio_pci_admin.h>
18 #include <linux/anon_inodes.h>
19 
20 #include "common.h"
21 
22 /* Device specification max parts size */
23 #define MAX_LOAD_SIZE (BIT_ULL(BITS_PER_TYPE \
24 	(((struct virtio_admin_cmd_dev_parts_metadata_result *)0)->parts_size.size)) - 1)
25 
26 /* Initial target buffer size */
27 #define VIRTIOVF_TARGET_INITIAL_BUF_SIZE SZ_1M
28 
29 static int
30 virtiovf_read_device_context_chunk(struct virtiovf_migration_file *migf,
31 				   u32 ctx_size);
32 
33 static struct page *
34 virtiovf_get_migration_page(struct virtiovf_data_buffer *buf,
35 			    unsigned long offset)
36 {
37 	unsigned long cur_offset = 0;
38 	struct scatterlist *sg;
39 	unsigned int i;
40 
41 	/* All accesses are sequential */
42 	if (offset < buf->last_offset || !buf->last_offset_sg) {
43 		buf->last_offset = 0;
44 		buf->last_offset_sg = buf->table.sgt.sgl;
45 		buf->sg_last_entry = 0;
46 	}
47 
48 	cur_offset = buf->last_offset;
49 
50 	for_each_sg(buf->last_offset_sg, sg,
51 		    buf->table.sgt.orig_nents - buf->sg_last_entry, i) {
52 		if (offset < sg->length + cur_offset) {
53 			buf->last_offset_sg = sg;
54 			buf->sg_last_entry += i;
55 			buf->last_offset = cur_offset;
56 			return nth_page(sg_page(sg),
57 					(offset - cur_offset) / PAGE_SIZE);
58 		}
59 		cur_offset += sg->length;
60 	}
61 	return NULL;
62 }
63 
64 static int virtiovf_add_migration_pages(struct virtiovf_data_buffer *buf,
65 					unsigned int npages)
66 {
67 	unsigned int to_alloc = npages;
68 	struct page **page_list;
69 	unsigned long filled;
70 	unsigned int to_fill;
71 	int ret;
72 	int i;
73 
74 	to_fill = min_t(unsigned int, npages, PAGE_SIZE / sizeof(*page_list));
75 	page_list = kvcalloc(to_fill, sizeof(*page_list), GFP_KERNEL_ACCOUNT);
76 	if (!page_list)
77 		return -ENOMEM;
78 
79 	do {
80 		filled = alloc_pages_bulk_array(GFP_KERNEL_ACCOUNT, to_fill,
81 						page_list);
82 		if (!filled) {
83 			ret = -ENOMEM;
84 			goto err;
85 		}
86 		to_alloc -= filled;
87 		ret = sg_alloc_append_table_from_pages(&buf->table, page_list,
88 			filled, 0, filled << PAGE_SHIFT, UINT_MAX,
89 			SG_MAX_SINGLE_ALLOC, GFP_KERNEL_ACCOUNT);
90 
91 		if (ret)
92 			goto err_append;
93 		buf->allocated_length += filled * PAGE_SIZE;
94 		/* clean input for another bulk allocation */
95 		memset(page_list, 0, filled * sizeof(*page_list));
96 		to_fill = min_t(unsigned int, to_alloc,
97 				PAGE_SIZE / sizeof(*page_list));
98 	} while (to_alloc > 0);
99 
100 	kvfree(page_list);
101 	return 0;
102 
103 err_append:
104 	for (i = filled - 1; i >= 0; i--)
105 		__free_page(page_list[i]);
106 err:
107 	kvfree(page_list);
108 	return ret;
109 }
110 
111 static void virtiovf_free_data_buffer(struct virtiovf_data_buffer *buf)
112 {
113 	struct sg_page_iter sg_iter;
114 
115 	/* Undo alloc_pages_bulk_array() */
116 	for_each_sgtable_page(&buf->table.sgt, &sg_iter, 0)
117 		__free_page(sg_page_iter_page(&sg_iter));
118 	sg_free_append_table(&buf->table);
119 	kfree(buf);
120 }
121 
122 static struct virtiovf_data_buffer *
123 virtiovf_alloc_data_buffer(struct virtiovf_migration_file *migf, size_t length)
124 {
125 	struct virtiovf_data_buffer *buf;
126 	int ret;
127 
128 	buf = kzalloc(sizeof(*buf), GFP_KERNEL_ACCOUNT);
129 	if (!buf)
130 		return ERR_PTR(-ENOMEM);
131 
132 	ret = virtiovf_add_migration_pages(buf,
133 				DIV_ROUND_UP_ULL(length, PAGE_SIZE));
134 	if (ret)
135 		goto end;
136 
137 	buf->migf = migf;
138 	return buf;
139 end:
140 	virtiovf_free_data_buffer(buf);
141 	return ERR_PTR(ret);
142 }
143 
144 static void virtiovf_put_data_buffer(struct virtiovf_data_buffer *buf)
145 {
146 	spin_lock_irq(&buf->migf->list_lock);
147 	list_add_tail(&buf->buf_elm, &buf->migf->avail_list);
148 	spin_unlock_irq(&buf->migf->list_lock);
149 }
150 
151 static int
152 virtiovf_pci_alloc_obj_id(struct virtiovf_pci_core_device *virtvdev, u8 type,
153 			  u32 *obj_id)
154 {
155 	return virtio_pci_admin_obj_create(virtvdev->core_device.pdev,
156 					   VIRTIO_RESOURCE_OBJ_DEV_PARTS, type, obj_id);
157 }
158 
159 static void
160 virtiovf_pci_free_obj_id(struct virtiovf_pci_core_device *virtvdev, u32 obj_id)
161 {
162 	virtio_pci_admin_obj_destroy(virtvdev->core_device.pdev,
163 			VIRTIO_RESOURCE_OBJ_DEV_PARTS, obj_id);
164 }
165 
166 static struct virtiovf_data_buffer *
167 virtiovf_get_data_buffer(struct virtiovf_migration_file *migf, size_t length)
168 {
169 	struct virtiovf_data_buffer *buf, *temp_buf;
170 	struct list_head free_list;
171 
172 	INIT_LIST_HEAD(&free_list);
173 
174 	spin_lock_irq(&migf->list_lock);
175 	list_for_each_entry_safe(buf, temp_buf, &migf->avail_list, buf_elm) {
176 		list_del_init(&buf->buf_elm);
177 		if (buf->allocated_length >= length) {
178 			spin_unlock_irq(&migf->list_lock);
179 			goto found;
180 		}
181 		/*
182 		 * Prevent holding redundant buffers. Put in a free
183 		 * list and call at the end not under the spin lock
184 		 * (&migf->list_lock) to minimize its scope usage.
185 		 */
186 		list_add(&buf->buf_elm, &free_list);
187 	}
188 	spin_unlock_irq(&migf->list_lock);
189 	buf = virtiovf_alloc_data_buffer(migf, length);
190 
191 found:
192 	while ((temp_buf = list_first_entry_or_null(&free_list,
193 				struct virtiovf_data_buffer, buf_elm))) {
194 		list_del(&temp_buf->buf_elm);
195 		virtiovf_free_data_buffer(temp_buf);
196 	}
197 
198 	return buf;
199 }
200 
201 static void virtiovf_clean_migf_resources(struct virtiovf_migration_file *migf)
202 {
203 	struct virtiovf_data_buffer *entry;
204 
205 	if (migf->buf) {
206 		virtiovf_free_data_buffer(migf->buf);
207 		migf->buf = NULL;
208 	}
209 
210 	if (migf->buf_header) {
211 		virtiovf_free_data_buffer(migf->buf_header);
212 		migf->buf_header = NULL;
213 	}
214 
215 	list_splice(&migf->avail_list, &migf->buf_list);
216 
217 	while ((entry = list_first_entry_or_null(&migf->buf_list,
218 				struct virtiovf_data_buffer, buf_elm))) {
219 		list_del(&entry->buf_elm);
220 		virtiovf_free_data_buffer(entry);
221 	}
222 
223 	if (migf->has_obj_id)
224 		virtiovf_pci_free_obj_id(migf->virtvdev, migf->obj_id);
225 }
226 
227 static void virtiovf_disable_fd(struct virtiovf_migration_file *migf)
228 {
229 	mutex_lock(&migf->lock);
230 	migf->state = VIRTIOVF_MIGF_STATE_ERROR;
231 	migf->filp->f_pos = 0;
232 	mutex_unlock(&migf->lock);
233 }
234 
235 static void virtiovf_disable_fds(struct virtiovf_pci_core_device *virtvdev)
236 {
237 	if (virtvdev->resuming_migf) {
238 		virtiovf_disable_fd(virtvdev->resuming_migf);
239 		virtiovf_clean_migf_resources(virtvdev->resuming_migf);
240 		fput(virtvdev->resuming_migf->filp);
241 		virtvdev->resuming_migf = NULL;
242 	}
243 	if (virtvdev->saving_migf) {
244 		virtiovf_disable_fd(virtvdev->saving_migf);
245 		virtiovf_clean_migf_resources(virtvdev->saving_migf);
246 		fput(virtvdev->saving_migf->filp);
247 		virtvdev->saving_migf = NULL;
248 	}
249 }
250 
251 /*
252  * This function is called in all state_mutex unlock cases to
253  * handle a 'deferred_reset' if exists.
254  */
255 static void virtiovf_state_mutex_unlock(struct virtiovf_pci_core_device *virtvdev)
256 {
257 again:
258 	spin_lock(&virtvdev->reset_lock);
259 	if (virtvdev->deferred_reset) {
260 		virtvdev->deferred_reset = false;
261 		spin_unlock(&virtvdev->reset_lock);
262 		virtvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
263 		virtiovf_disable_fds(virtvdev);
264 		goto again;
265 	}
266 	mutex_unlock(&virtvdev->state_mutex);
267 	spin_unlock(&virtvdev->reset_lock);
268 }
269 
270 void virtiovf_migration_reset_done(struct pci_dev *pdev)
271 {
272 	struct virtiovf_pci_core_device *virtvdev = dev_get_drvdata(&pdev->dev);
273 
274 	if (!virtvdev->migrate_cap)
275 		return;
276 
277 	/*
278 	 * As the higher VFIO layers are holding locks across reset and using
279 	 * those same locks with the mm_lock we need to prevent ABBA deadlock
280 	 * with the state_mutex and mm_lock.
281 	 * In case the state_mutex was taken already we defer the cleanup work
282 	 * to the unlock flow of the other running context.
283 	 */
284 	spin_lock(&virtvdev->reset_lock);
285 	virtvdev->deferred_reset = true;
286 	if (!mutex_trylock(&virtvdev->state_mutex)) {
287 		spin_unlock(&virtvdev->reset_lock);
288 		return;
289 	}
290 	spin_unlock(&virtvdev->reset_lock);
291 	virtiovf_state_mutex_unlock(virtvdev);
292 }
293 
294 static int virtiovf_release_file(struct inode *inode, struct file *filp)
295 {
296 	struct virtiovf_migration_file *migf = filp->private_data;
297 
298 	virtiovf_disable_fd(migf);
299 	mutex_destroy(&migf->lock);
300 	kfree(migf);
301 	return 0;
302 }
303 
304 static struct virtiovf_data_buffer *
305 virtiovf_get_data_buff_from_pos(struct virtiovf_migration_file *migf,
306 				loff_t pos, bool *end_of_data)
307 {
308 	struct virtiovf_data_buffer *buf;
309 	bool found = false;
310 
311 	*end_of_data = false;
312 	spin_lock_irq(&migf->list_lock);
313 	if (list_empty(&migf->buf_list)) {
314 		*end_of_data = true;
315 		goto end;
316 	}
317 
318 	buf = list_first_entry(&migf->buf_list, struct virtiovf_data_buffer,
319 			       buf_elm);
320 	if (pos >= buf->start_pos &&
321 	    pos < buf->start_pos + buf->length) {
322 		found = true;
323 		goto end;
324 	}
325 
326 	/*
327 	 * As we use a stream based FD we may expect having the data always
328 	 * on first chunk
329 	 */
330 	migf->state = VIRTIOVF_MIGF_STATE_ERROR;
331 
332 end:
333 	spin_unlock_irq(&migf->list_lock);
334 	return found ? buf : NULL;
335 }
336 
337 static ssize_t virtiovf_buf_read(struct virtiovf_data_buffer *vhca_buf,
338 				 char __user **buf, size_t *len, loff_t *pos)
339 {
340 	unsigned long offset;
341 	ssize_t done = 0;
342 	size_t copy_len;
343 
344 	copy_len = min_t(size_t,
345 			 vhca_buf->start_pos + vhca_buf->length - *pos, *len);
346 	while (copy_len) {
347 		size_t page_offset;
348 		struct page *page;
349 		size_t page_len;
350 		u8 *from_buff;
351 		int ret;
352 
353 		offset = *pos - vhca_buf->start_pos;
354 		page_offset = offset % PAGE_SIZE;
355 		offset -= page_offset;
356 		page = virtiovf_get_migration_page(vhca_buf, offset);
357 		if (!page)
358 			return -EINVAL;
359 		page_len = min_t(size_t, copy_len, PAGE_SIZE - page_offset);
360 		from_buff = kmap_local_page(page);
361 		ret = copy_to_user(*buf, from_buff + page_offset, page_len);
362 		kunmap_local(from_buff);
363 		if (ret)
364 			return -EFAULT;
365 		*pos += page_len;
366 		*len -= page_len;
367 		*buf += page_len;
368 		done += page_len;
369 		copy_len -= page_len;
370 	}
371 
372 	if (*pos >= vhca_buf->start_pos + vhca_buf->length) {
373 		spin_lock_irq(&vhca_buf->migf->list_lock);
374 		list_del_init(&vhca_buf->buf_elm);
375 		list_add_tail(&vhca_buf->buf_elm, &vhca_buf->migf->avail_list);
376 		spin_unlock_irq(&vhca_buf->migf->list_lock);
377 	}
378 
379 	return done;
380 }
381 
382 static ssize_t virtiovf_save_read(struct file *filp, char __user *buf, size_t len,
383 				  loff_t *pos)
384 {
385 	struct virtiovf_migration_file *migf = filp->private_data;
386 	struct virtiovf_data_buffer *vhca_buf;
387 	bool first_loop_call = true;
388 	bool end_of_data;
389 	ssize_t done = 0;
390 
391 	if (pos)
392 		return -ESPIPE;
393 	pos = &filp->f_pos;
394 
395 	mutex_lock(&migf->lock);
396 	if (migf->state == VIRTIOVF_MIGF_STATE_ERROR) {
397 		done = -ENODEV;
398 		goto out_unlock;
399 	}
400 
401 	while (len) {
402 		ssize_t count;
403 
404 		vhca_buf = virtiovf_get_data_buff_from_pos(migf, *pos, &end_of_data);
405 		if (first_loop_call) {
406 			first_loop_call = false;
407 			/* Temporary end of file as part of PRE_COPY */
408 			if (end_of_data && migf->state == VIRTIOVF_MIGF_STATE_PRECOPY) {
409 				done = -ENOMSG;
410 				goto out_unlock;
411 			}
412 			if (end_of_data && migf->state != VIRTIOVF_MIGF_STATE_COMPLETE) {
413 				done = -EINVAL;
414 				goto out_unlock;
415 			}
416 		}
417 
418 		if (end_of_data)
419 			goto out_unlock;
420 
421 		if (!vhca_buf) {
422 			done = -EINVAL;
423 			goto out_unlock;
424 		}
425 
426 		count = virtiovf_buf_read(vhca_buf, &buf, &len, pos);
427 		if (count < 0) {
428 			done = count;
429 			goto out_unlock;
430 		}
431 		done += count;
432 	}
433 
434 out_unlock:
435 	mutex_unlock(&migf->lock);
436 	return done;
437 }
438 
439 static long virtiovf_precopy_ioctl(struct file *filp, unsigned int cmd,
440 				   unsigned long arg)
441 {
442 	struct virtiovf_migration_file *migf = filp->private_data;
443 	struct virtiovf_pci_core_device *virtvdev = migf->virtvdev;
444 	struct vfio_precopy_info info = {};
445 	loff_t *pos = &filp->f_pos;
446 	bool end_of_data = false;
447 	unsigned long minsz;
448 	u32 ctx_size = 0;
449 	int ret;
450 
451 	if (cmd != VFIO_MIG_GET_PRECOPY_INFO)
452 		return -ENOTTY;
453 
454 	minsz = offsetofend(struct vfio_precopy_info, dirty_bytes);
455 	if (copy_from_user(&info, (void __user *)arg, minsz))
456 		return -EFAULT;
457 
458 	if (info.argsz < minsz)
459 		return -EINVAL;
460 
461 	mutex_lock(&virtvdev->state_mutex);
462 	if (virtvdev->mig_state != VFIO_DEVICE_STATE_PRE_COPY &&
463 	    virtvdev->mig_state != VFIO_DEVICE_STATE_PRE_COPY_P2P) {
464 		ret = -EINVAL;
465 		goto err_state_unlock;
466 	}
467 
468 	/*
469 	 * The virtio specification does not include a PRE_COPY concept.
470 	 * Since we can expect the data to remain the same for a certain period,
471 	 * we use a rate limiter mechanism before making a call to the device.
472 	 */
473 	if (__ratelimit(&migf->pre_copy_rl_state)) {
474 
475 		ret = virtio_pci_admin_dev_parts_metadata_get(virtvdev->core_device.pdev,
476 					VIRTIO_RESOURCE_OBJ_DEV_PARTS, migf->obj_id,
477 					VIRTIO_ADMIN_CMD_DEV_PARTS_METADATA_TYPE_SIZE,
478 					&ctx_size);
479 		if (ret)
480 			goto err_state_unlock;
481 	}
482 
483 	mutex_lock(&migf->lock);
484 	if (migf->state == VIRTIOVF_MIGF_STATE_ERROR) {
485 		ret = -ENODEV;
486 		goto err_migf_unlock;
487 	}
488 
489 	if (migf->pre_copy_initial_bytes > *pos) {
490 		info.initial_bytes = migf->pre_copy_initial_bytes - *pos;
491 	} else {
492 		info.dirty_bytes = migf->max_pos - *pos;
493 		if (!info.dirty_bytes)
494 			end_of_data = true;
495 		info.dirty_bytes += ctx_size;
496 	}
497 
498 	if (!end_of_data || !ctx_size) {
499 		mutex_unlock(&migf->lock);
500 		goto done;
501 	}
502 
503 	mutex_unlock(&migf->lock);
504 	/*
505 	 * We finished transferring the current state and the device has a
506 	 * dirty state, read a new state.
507 	 */
508 	ret = virtiovf_read_device_context_chunk(migf, ctx_size);
509 	if (ret)
510 		/*
511 		 * The machine is running, and context size could be grow, so no reason to mark
512 		 * the device state as VIRTIOVF_MIGF_STATE_ERROR.
513 		 */
514 		goto err_state_unlock;
515 
516 done:
517 	virtiovf_state_mutex_unlock(virtvdev);
518 	if (copy_to_user((void __user *)arg, &info, minsz))
519 		return -EFAULT;
520 	return 0;
521 
522 err_migf_unlock:
523 	mutex_unlock(&migf->lock);
524 err_state_unlock:
525 	virtiovf_state_mutex_unlock(virtvdev);
526 	return ret;
527 }
528 
529 static const struct file_operations virtiovf_save_fops = {
530 	.owner = THIS_MODULE,
531 	.read = virtiovf_save_read,
532 	.unlocked_ioctl = virtiovf_precopy_ioctl,
533 	.compat_ioctl = compat_ptr_ioctl,
534 	.release = virtiovf_release_file,
535 };
536 
537 static int
538 virtiovf_add_buf_header(struct virtiovf_data_buffer *header_buf,
539 			u32 data_size)
540 {
541 	struct virtiovf_migration_file *migf = header_buf->migf;
542 	struct virtiovf_migration_header header = {};
543 	struct page *page;
544 	u8 *to_buff;
545 
546 	header.record_size = cpu_to_le64(data_size);
547 	header.flags = cpu_to_le32(VIRTIOVF_MIGF_HEADER_FLAGS_TAG_MANDATORY);
548 	header.tag = cpu_to_le32(VIRTIOVF_MIGF_HEADER_TAG_DEVICE_DATA);
549 	page = virtiovf_get_migration_page(header_buf, 0);
550 	if (!page)
551 		return -EINVAL;
552 	to_buff = kmap_local_page(page);
553 	memcpy(to_buff, &header, sizeof(header));
554 	kunmap_local(to_buff);
555 	header_buf->length = sizeof(header);
556 	header_buf->start_pos = header_buf->migf->max_pos;
557 	migf->max_pos += header_buf->length;
558 	spin_lock_irq(&migf->list_lock);
559 	list_add_tail(&header_buf->buf_elm, &migf->buf_list);
560 	spin_unlock_irq(&migf->list_lock);
561 	return 0;
562 }
563 
564 static int
565 virtiovf_read_device_context_chunk(struct virtiovf_migration_file *migf,
566 				   u32 ctx_size)
567 {
568 	struct virtiovf_data_buffer *header_buf;
569 	struct virtiovf_data_buffer *buf;
570 	bool unmark_end = false;
571 	struct scatterlist *sg;
572 	unsigned int i;
573 	u32 res_size;
574 	int nent;
575 	int ret;
576 
577 	buf = virtiovf_get_data_buffer(migf, ctx_size);
578 	if (IS_ERR(buf))
579 		return PTR_ERR(buf);
580 
581 	/* Find the total count of SG entries which satisfies the size */
582 	nent = sg_nents_for_len(buf->table.sgt.sgl, ctx_size);
583 	if (nent <= 0) {
584 		ret = -EINVAL;
585 		goto out;
586 	}
587 
588 	/*
589 	 * Iterate to that SG entry and mark it as last (if it's not already)
590 	 * to let underlay layers iterate only till that entry.
591 	 */
592 	for_each_sg(buf->table.sgt.sgl, sg, nent - 1, i)
593 		;
594 
595 	if (!sg_is_last(sg)) {
596 		unmark_end = true;
597 		sg_mark_end(sg);
598 	}
599 
600 	ret = virtio_pci_admin_dev_parts_get(migf->virtvdev->core_device.pdev,
601 					     VIRTIO_RESOURCE_OBJ_DEV_PARTS,
602 					     migf->obj_id,
603 					     VIRTIO_ADMIN_CMD_DEV_PARTS_GET_TYPE_ALL,
604 					     buf->table.sgt.sgl, &res_size);
605 	/* Restore the original SG mark end */
606 	if (unmark_end)
607 		sg_unmark_end(sg);
608 	if (ret)
609 		goto out;
610 
611 	buf->length = res_size;
612 	header_buf = virtiovf_get_data_buffer(migf,
613 				sizeof(struct virtiovf_migration_header));
614 	if (IS_ERR(header_buf)) {
615 		ret = PTR_ERR(header_buf);
616 		goto out;
617 	}
618 
619 	ret = virtiovf_add_buf_header(header_buf, res_size);
620 	if (ret)
621 		goto out_header;
622 
623 	buf->start_pos = buf->migf->max_pos;
624 	migf->max_pos += buf->length;
625 	spin_lock(&migf->list_lock);
626 	list_add_tail(&buf->buf_elm, &migf->buf_list);
627 	spin_unlock_irq(&migf->list_lock);
628 	return 0;
629 
630 out_header:
631 	virtiovf_put_data_buffer(header_buf);
632 out:
633 	virtiovf_put_data_buffer(buf);
634 	return ret;
635 }
636 
637 static int
638 virtiovf_pci_save_device_final_data(struct virtiovf_pci_core_device *virtvdev)
639 {
640 	struct virtiovf_migration_file *migf = virtvdev->saving_migf;
641 	u32 ctx_size;
642 	int ret;
643 
644 	if (migf->state == VIRTIOVF_MIGF_STATE_ERROR)
645 		return -ENODEV;
646 
647 	ret = virtio_pci_admin_dev_parts_metadata_get(virtvdev->core_device.pdev,
648 				VIRTIO_RESOURCE_OBJ_DEV_PARTS, migf->obj_id,
649 				VIRTIO_ADMIN_CMD_DEV_PARTS_METADATA_TYPE_SIZE,
650 				&ctx_size);
651 	if (ret)
652 		goto err;
653 
654 	if (!ctx_size) {
655 		ret = -EINVAL;
656 		goto err;
657 	}
658 
659 	ret = virtiovf_read_device_context_chunk(migf, ctx_size);
660 	if (ret)
661 		goto err;
662 
663 	migf->state = VIRTIOVF_MIGF_STATE_COMPLETE;
664 	return 0;
665 
666 err:
667 	migf->state = VIRTIOVF_MIGF_STATE_ERROR;
668 	return ret;
669 }
670 
671 static struct virtiovf_migration_file *
672 virtiovf_pci_save_device_data(struct virtiovf_pci_core_device *virtvdev,
673 			      bool pre_copy)
674 {
675 	struct virtiovf_migration_file *migf;
676 	u32 ctx_size;
677 	u32 obj_id;
678 	int ret;
679 
680 	migf = kzalloc(sizeof(*migf), GFP_KERNEL_ACCOUNT);
681 	if (!migf)
682 		return ERR_PTR(-ENOMEM);
683 
684 	migf->filp = anon_inode_getfile("virtiovf_mig", &virtiovf_save_fops, migf,
685 					O_RDONLY);
686 	if (IS_ERR(migf->filp)) {
687 		ret = PTR_ERR(migf->filp);
688 		kfree(migf);
689 		return ERR_PTR(ret);
690 	}
691 
692 	stream_open(migf->filp->f_inode, migf->filp);
693 	mutex_init(&migf->lock);
694 	INIT_LIST_HEAD(&migf->buf_list);
695 	INIT_LIST_HEAD(&migf->avail_list);
696 	spin_lock_init(&migf->list_lock);
697 	migf->virtvdev = virtvdev;
698 
699 	lockdep_assert_held(&virtvdev->state_mutex);
700 	ret = virtiovf_pci_alloc_obj_id(virtvdev, VIRTIO_RESOURCE_OBJ_DEV_PARTS_TYPE_GET,
701 					&obj_id);
702 	if (ret)
703 		goto out;
704 
705 	migf->obj_id = obj_id;
706 	/* Mark as having a valid obj id which can be even 0 */
707 	migf->has_obj_id = true;
708 	ret = virtio_pci_admin_dev_parts_metadata_get(virtvdev->core_device.pdev,
709 				VIRTIO_RESOURCE_OBJ_DEV_PARTS, obj_id,
710 				VIRTIO_ADMIN_CMD_DEV_PARTS_METADATA_TYPE_SIZE,
711 				&ctx_size);
712 	if (ret)
713 		goto out_clean;
714 
715 	if (!ctx_size) {
716 		ret = -EINVAL;
717 		goto out_clean;
718 	}
719 
720 	ret = virtiovf_read_device_context_chunk(migf, ctx_size);
721 	if (ret)
722 		goto out_clean;
723 
724 	if (pre_copy) {
725 		migf->pre_copy_initial_bytes = migf->max_pos;
726 		/* Arbitrarily set the pre-copy rate limit to 1-second intervals */
727 		ratelimit_state_init(&migf->pre_copy_rl_state, 1 * HZ, 1);
728 		/* Prevent any rate messages upon its usage */
729 		ratelimit_set_flags(&migf->pre_copy_rl_state,
730 				    RATELIMIT_MSG_ON_RELEASE);
731 		migf->state = VIRTIOVF_MIGF_STATE_PRECOPY;
732 	} else {
733 		migf->state = VIRTIOVF_MIGF_STATE_COMPLETE;
734 	}
735 
736 	return migf;
737 
738 out_clean:
739 	virtiovf_clean_migf_resources(migf);
740 out:
741 	fput(migf->filp);
742 	return ERR_PTR(ret);
743 }
744 
745 /*
746  * Set the required object header at the beginning of the buffer.
747  * The actual device parts data will be written post of the header offset.
748  */
749 static int virtiovf_set_obj_cmd_header(struct virtiovf_data_buffer *vhca_buf)
750 {
751 	struct virtio_admin_cmd_resource_obj_cmd_hdr obj_hdr = {};
752 	struct page *page;
753 	u8 *to_buff;
754 
755 	obj_hdr.type = cpu_to_le16(VIRTIO_RESOURCE_OBJ_DEV_PARTS);
756 	obj_hdr.id = cpu_to_le32(vhca_buf->migf->obj_id);
757 	page = virtiovf_get_migration_page(vhca_buf, 0);
758 	if (!page)
759 		return -EINVAL;
760 	to_buff = kmap_local_page(page);
761 	memcpy(to_buff, &obj_hdr, sizeof(obj_hdr));
762 	kunmap_local(to_buff);
763 
764 	/* Mark the buffer as including the header object data */
765 	vhca_buf->include_header_object = 1;
766 	return 0;
767 }
768 
769 static int
770 virtiovf_append_page_to_mig_buf(struct virtiovf_data_buffer *vhca_buf,
771 				const char __user **buf, size_t *len,
772 				loff_t *pos, ssize_t *done)
773 {
774 	unsigned long offset;
775 	size_t page_offset;
776 	struct page *page;
777 	size_t page_len;
778 	u8 *to_buff;
779 	int ret;
780 
781 	offset = *pos - vhca_buf->start_pos;
782 
783 	if (vhca_buf->include_header_object)
784 		/* The buffer holds the object header, update the offset accordingly */
785 		offset += sizeof(struct virtio_admin_cmd_resource_obj_cmd_hdr);
786 
787 	page_offset = offset % PAGE_SIZE;
788 
789 	page = virtiovf_get_migration_page(vhca_buf, offset - page_offset);
790 	if (!page)
791 		return -EINVAL;
792 
793 	page_len = min_t(size_t, *len, PAGE_SIZE - page_offset);
794 	to_buff = kmap_local_page(page);
795 	ret = copy_from_user(to_buff + page_offset, *buf, page_len);
796 	kunmap_local(to_buff);
797 	if (ret)
798 		return -EFAULT;
799 
800 	*pos += page_len;
801 	*done += page_len;
802 	*buf += page_len;
803 	*len -= page_len;
804 	vhca_buf->length += page_len;
805 	return 0;
806 }
807 
808 static ssize_t
809 virtiovf_resume_read_chunk(struct virtiovf_migration_file *migf,
810 			   struct virtiovf_data_buffer *vhca_buf,
811 			   size_t chunk_size, const char __user **buf,
812 			   size_t *len, loff_t *pos, ssize_t *done,
813 			   bool *has_work)
814 {
815 	size_t copy_len, to_copy;
816 	int ret;
817 
818 	to_copy = min_t(size_t, *len, chunk_size - vhca_buf->length);
819 	copy_len = to_copy;
820 	while (to_copy) {
821 		ret = virtiovf_append_page_to_mig_buf(vhca_buf, buf, &to_copy,
822 						      pos, done);
823 		if (ret)
824 			return ret;
825 	}
826 
827 	*len -= copy_len;
828 	if (vhca_buf->length == chunk_size) {
829 		migf->load_state = VIRTIOVF_LOAD_STATE_LOAD_CHUNK;
830 		migf->max_pos += chunk_size;
831 		*has_work = true;
832 	}
833 
834 	return 0;
835 }
836 
837 static int
838 virtiovf_resume_read_header_data(struct virtiovf_migration_file *migf,
839 				 struct virtiovf_data_buffer *vhca_buf,
840 				 const char __user **buf, size_t *len,
841 				 loff_t *pos, ssize_t *done)
842 {
843 	size_t copy_len, to_copy;
844 	size_t required_data;
845 	int ret;
846 
847 	required_data = migf->record_size - vhca_buf->length;
848 	to_copy = min_t(size_t, *len, required_data);
849 	copy_len = to_copy;
850 	while (to_copy) {
851 		ret = virtiovf_append_page_to_mig_buf(vhca_buf, buf, &to_copy,
852 						      pos, done);
853 		if (ret)
854 			return ret;
855 	}
856 
857 	*len -= copy_len;
858 	if (vhca_buf->length == migf->record_size) {
859 		switch (migf->record_tag) {
860 		default:
861 			/* Optional tag */
862 			break;
863 		}
864 
865 		migf->load_state = VIRTIOVF_LOAD_STATE_READ_HEADER;
866 		migf->max_pos += migf->record_size;
867 		vhca_buf->length = 0;
868 	}
869 
870 	return 0;
871 }
872 
873 static int
874 virtiovf_resume_read_header(struct virtiovf_migration_file *migf,
875 			    struct virtiovf_data_buffer *vhca_buf,
876 			    const char __user **buf,
877 			    size_t *len, loff_t *pos,
878 			    ssize_t *done, bool *has_work)
879 {
880 	struct page *page;
881 	size_t copy_len;
882 	u8 *to_buff;
883 	int ret;
884 
885 	copy_len = min_t(size_t, *len,
886 		sizeof(struct virtiovf_migration_header) - vhca_buf->length);
887 	page = virtiovf_get_migration_page(vhca_buf, 0);
888 	if (!page)
889 		return -EINVAL;
890 	to_buff = kmap_local_page(page);
891 	ret = copy_from_user(to_buff + vhca_buf->length, *buf, copy_len);
892 	if (ret) {
893 		ret = -EFAULT;
894 		goto end;
895 	}
896 
897 	*buf += copy_len;
898 	*pos += copy_len;
899 	*done += copy_len;
900 	*len -= copy_len;
901 	vhca_buf->length += copy_len;
902 	if (vhca_buf->length == sizeof(struct virtiovf_migration_header)) {
903 		u64 record_size;
904 		u32 flags;
905 
906 		record_size = le64_to_cpup((__le64 *)to_buff);
907 		if (record_size > MAX_LOAD_SIZE) {
908 			ret = -ENOMEM;
909 			goto end;
910 		}
911 
912 		migf->record_size = record_size;
913 		flags = le32_to_cpup((__le32 *)(to_buff +
914 			    offsetof(struct virtiovf_migration_header, flags)));
915 		migf->record_tag = le32_to_cpup((__le32 *)(to_buff +
916 			    offsetof(struct virtiovf_migration_header, tag)));
917 		switch (migf->record_tag) {
918 		case VIRTIOVF_MIGF_HEADER_TAG_DEVICE_DATA:
919 			migf->load_state = VIRTIOVF_LOAD_STATE_PREP_CHUNK;
920 			break;
921 		default:
922 			if (!(flags & VIRTIOVF_MIGF_HEADER_FLAGS_TAG_OPTIONAL)) {
923 				ret = -EOPNOTSUPP;
924 				goto end;
925 			}
926 			/* We may read and skip this optional record data */
927 			migf->load_state = VIRTIOVF_LOAD_STATE_PREP_HEADER_DATA;
928 		}
929 
930 		migf->max_pos += vhca_buf->length;
931 		vhca_buf->length = 0;
932 		*has_work = true;
933 	}
934 end:
935 	kunmap_local(to_buff);
936 	return ret;
937 }
938 
939 static ssize_t virtiovf_resume_write(struct file *filp, const char __user *buf,
940 				     size_t len, loff_t *pos)
941 {
942 	struct virtiovf_migration_file *migf = filp->private_data;
943 	struct virtiovf_data_buffer *vhca_buf = migf->buf;
944 	struct virtiovf_data_buffer *vhca_buf_header = migf->buf_header;
945 	unsigned int orig_length;
946 	bool has_work = false;
947 	ssize_t done = 0;
948 	int ret = 0;
949 
950 	if (pos)
951 		return -ESPIPE;
952 
953 	pos = &filp->f_pos;
954 	if (*pos < vhca_buf->start_pos)
955 		return -EINVAL;
956 
957 	mutex_lock(&migf->virtvdev->state_mutex);
958 	mutex_lock(&migf->lock);
959 	if (migf->state == VIRTIOVF_MIGF_STATE_ERROR) {
960 		done = -ENODEV;
961 		goto out_unlock;
962 	}
963 
964 	while (len || has_work) {
965 		has_work = false;
966 		switch (migf->load_state) {
967 		case VIRTIOVF_LOAD_STATE_READ_HEADER:
968 			ret = virtiovf_resume_read_header(migf, vhca_buf_header, &buf,
969 							  &len, pos, &done, &has_work);
970 			if (ret)
971 				goto out_unlock;
972 			break;
973 		case VIRTIOVF_LOAD_STATE_PREP_HEADER_DATA:
974 			if (vhca_buf_header->allocated_length < migf->record_size) {
975 				virtiovf_free_data_buffer(vhca_buf_header);
976 
977 				migf->buf_header = virtiovf_alloc_data_buffer(migf,
978 						migf->record_size);
979 				if (IS_ERR(migf->buf_header)) {
980 					ret = PTR_ERR(migf->buf_header);
981 					migf->buf_header = NULL;
982 					goto out_unlock;
983 				}
984 
985 				vhca_buf_header = migf->buf_header;
986 			}
987 
988 			vhca_buf_header->start_pos = migf->max_pos;
989 			migf->load_state = VIRTIOVF_LOAD_STATE_READ_HEADER_DATA;
990 			break;
991 		case VIRTIOVF_LOAD_STATE_READ_HEADER_DATA:
992 			ret = virtiovf_resume_read_header_data(migf, vhca_buf_header,
993 							       &buf, &len, pos, &done);
994 			if (ret)
995 				goto out_unlock;
996 			break;
997 		case VIRTIOVF_LOAD_STATE_PREP_CHUNK:
998 		{
999 			u32 cmd_size = migf->record_size +
1000 				sizeof(struct virtio_admin_cmd_resource_obj_cmd_hdr);
1001 
1002 			/*
1003 			 * The DMA map/unmap is managed in virtio layer, we just need to extend
1004 			 * the SG pages to hold the extra required chunk data.
1005 			 */
1006 			if (vhca_buf->allocated_length < cmd_size) {
1007 				ret = virtiovf_add_migration_pages(vhca_buf,
1008 					DIV_ROUND_UP_ULL(cmd_size - vhca_buf->allocated_length,
1009 							 PAGE_SIZE));
1010 				if (ret)
1011 					goto out_unlock;
1012 			}
1013 
1014 			vhca_buf->start_pos = migf->max_pos;
1015 			migf->load_state = VIRTIOVF_LOAD_STATE_READ_CHUNK;
1016 			break;
1017 		}
1018 		case VIRTIOVF_LOAD_STATE_READ_CHUNK:
1019 			ret = virtiovf_resume_read_chunk(migf, vhca_buf, migf->record_size,
1020 							 &buf, &len, pos, &done, &has_work);
1021 			if (ret)
1022 				goto out_unlock;
1023 			break;
1024 		case VIRTIOVF_LOAD_STATE_LOAD_CHUNK:
1025 			/* Mark the last SG entry and set its length */
1026 			sg_mark_end(vhca_buf->last_offset_sg);
1027 			orig_length = vhca_buf->last_offset_sg->length;
1028 			/* Length should include the resource object command header */
1029 			vhca_buf->last_offset_sg->length = vhca_buf->length +
1030 					sizeof(struct virtio_admin_cmd_resource_obj_cmd_hdr) -
1031 					vhca_buf->last_offset;
1032 			ret = virtio_pci_admin_dev_parts_set(migf->virtvdev->core_device.pdev,
1033 							     vhca_buf->table.sgt.sgl);
1034 			/* Restore the original SG data */
1035 			vhca_buf->last_offset_sg->length = orig_length;
1036 			sg_unmark_end(vhca_buf->last_offset_sg);
1037 			if (ret)
1038 				goto out_unlock;
1039 			migf->load_state = VIRTIOVF_LOAD_STATE_READ_HEADER;
1040 			/* be ready for reading the next chunk */
1041 			vhca_buf->length = 0;
1042 			break;
1043 		default:
1044 			break;
1045 		}
1046 	}
1047 
1048 out_unlock:
1049 	if (ret)
1050 		migf->state = VIRTIOVF_MIGF_STATE_ERROR;
1051 	mutex_unlock(&migf->lock);
1052 	virtiovf_state_mutex_unlock(migf->virtvdev);
1053 	return ret ? ret : done;
1054 }
1055 
1056 static const struct file_operations virtiovf_resume_fops = {
1057 	.owner = THIS_MODULE,
1058 	.write = virtiovf_resume_write,
1059 	.release = virtiovf_release_file,
1060 };
1061 
1062 static struct virtiovf_migration_file *
1063 virtiovf_pci_resume_device_data(struct virtiovf_pci_core_device *virtvdev)
1064 {
1065 	struct virtiovf_migration_file *migf;
1066 	struct virtiovf_data_buffer *buf;
1067 	u32 obj_id;
1068 	int ret;
1069 
1070 	migf = kzalloc(sizeof(*migf), GFP_KERNEL_ACCOUNT);
1071 	if (!migf)
1072 		return ERR_PTR(-ENOMEM);
1073 
1074 	migf->filp = anon_inode_getfile("virtiovf_mig", &virtiovf_resume_fops, migf,
1075 					O_WRONLY);
1076 	if (IS_ERR(migf->filp)) {
1077 		ret = PTR_ERR(migf->filp);
1078 		kfree(migf);
1079 		return ERR_PTR(ret);
1080 	}
1081 
1082 	stream_open(migf->filp->f_inode, migf->filp);
1083 	mutex_init(&migf->lock);
1084 	INIT_LIST_HEAD(&migf->buf_list);
1085 	INIT_LIST_HEAD(&migf->avail_list);
1086 	spin_lock_init(&migf->list_lock);
1087 
1088 	buf = virtiovf_alloc_data_buffer(migf, VIRTIOVF_TARGET_INITIAL_BUF_SIZE);
1089 	if (IS_ERR(buf)) {
1090 		ret = PTR_ERR(buf);
1091 		goto out;
1092 	}
1093 
1094 	migf->buf = buf;
1095 
1096 	buf = virtiovf_alloc_data_buffer(migf,
1097 		sizeof(struct virtiovf_migration_header));
1098 	if (IS_ERR(buf)) {
1099 		ret = PTR_ERR(buf);
1100 		goto out_clean;
1101 	}
1102 
1103 	migf->buf_header = buf;
1104 	migf->load_state = VIRTIOVF_LOAD_STATE_READ_HEADER;
1105 
1106 	migf->virtvdev = virtvdev;
1107 	ret = virtiovf_pci_alloc_obj_id(virtvdev, VIRTIO_RESOURCE_OBJ_DEV_PARTS_TYPE_SET,
1108 					&obj_id);
1109 	if (ret)
1110 		goto out_clean;
1111 
1112 	migf->obj_id = obj_id;
1113 	/* Mark as having a valid obj id which can be even 0 */
1114 	migf->has_obj_id = true;
1115 	ret = virtiovf_set_obj_cmd_header(migf->buf);
1116 	if (ret)
1117 		goto out_clean;
1118 
1119 	return migf;
1120 
1121 out_clean:
1122 	virtiovf_clean_migf_resources(migf);
1123 out:
1124 	fput(migf->filp);
1125 	return ERR_PTR(ret);
1126 }
1127 
1128 static struct file *
1129 virtiovf_pci_step_device_state_locked(struct virtiovf_pci_core_device *virtvdev,
1130 				      u32 new)
1131 {
1132 	u32 cur = virtvdev->mig_state;
1133 	int ret;
1134 
1135 	if (cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_STOP) {
1136 		/* NOP */
1137 		return NULL;
1138 	}
1139 
1140 	if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_RUNNING_P2P) {
1141 		/* NOP */
1142 		return NULL;
1143 	}
1144 
1145 	if ((cur == VFIO_DEVICE_STATE_RUNNING && new == VFIO_DEVICE_STATE_RUNNING_P2P) ||
1146 	    (cur == VFIO_DEVICE_STATE_PRE_COPY && new == VFIO_DEVICE_STATE_PRE_COPY_P2P)) {
1147 		ret = virtio_pci_admin_mode_set(virtvdev->core_device.pdev,
1148 						BIT(VIRTIO_ADMIN_CMD_DEV_MODE_F_STOPPED));
1149 		if (ret)
1150 			return ERR_PTR(ret);
1151 		return NULL;
1152 	}
1153 
1154 	if ((cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_RUNNING) ||
1155 	    (cur == VFIO_DEVICE_STATE_PRE_COPY_P2P && new == VFIO_DEVICE_STATE_PRE_COPY)) {
1156 		ret = virtio_pci_admin_mode_set(virtvdev->core_device.pdev, 0);
1157 		if (ret)
1158 			return ERR_PTR(ret);
1159 		return NULL;
1160 	}
1161 
1162 	if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_STOP_COPY) {
1163 		struct virtiovf_migration_file *migf;
1164 
1165 		migf = virtiovf_pci_save_device_data(virtvdev, false);
1166 		if (IS_ERR(migf))
1167 			return ERR_CAST(migf);
1168 		get_file(migf->filp);
1169 		virtvdev->saving_migf = migf;
1170 		return migf->filp;
1171 	}
1172 
1173 	if ((cur == VFIO_DEVICE_STATE_STOP_COPY && new == VFIO_DEVICE_STATE_STOP) ||
1174 	    (cur == VFIO_DEVICE_STATE_PRE_COPY && new == VFIO_DEVICE_STATE_RUNNING) ||
1175 	    (cur == VFIO_DEVICE_STATE_PRE_COPY_P2P && new == VFIO_DEVICE_STATE_RUNNING_P2P)) {
1176 		virtiovf_disable_fds(virtvdev);
1177 		return NULL;
1178 	}
1179 
1180 	if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_RESUMING) {
1181 		struct virtiovf_migration_file *migf;
1182 
1183 		migf = virtiovf_pci_resume_device_data(virtvdev);
1184 		if (IS_ERR(migf))
1185 			return ERR_CAST(migf);
1186 		get_file(migf->filp);
1187 		virtvdev->resuming_migf = migf;
1188 		return migf->filp;
1189 	}
1190 
1191 	if (cur == VFIO_DEVICE_STATE_RESUMING && new == VFIO_DEVICE_STATE_STOP) {
1192 		virtiovf_disable_fds(virtvdev);
1193 		return NULL;
1194 	}
1195 
1196 	if ((cur == VFIO_DEVICE_STATE_RUNNING && new == VFIO_DEVICE_STATE_PRE_COPY) ||
1197 	    (cur == VFIO_DEVICE_STATE_RUNNING_P2P &&
1198 	     new == VFIO_DEVICE_STATE_PRE_COPY_P2P)) {
1199 		struct virtiovf_migration_file *migf;
1200 
1201 		migf = virtiovf_pci_save_device_data(virtvdev, true);
1202 		if (IS_ERR(migf))
1203 			return ERR_CAST(migf);
1204 		get_file(migf->filp);
1205 		virtvdev->saving_migf = migf;
1206 		return migf->filp;
1207 	}
1208 
1209 	if (cur == VFIO_DEVICE_STATE_PRE_COPY_P2P && new == VFIO_DEVICE_STATE_STOP_COPY) {
1210 		ret = virtiovf_pci_save_device_final_data(virtvdev);
1211 		return ret ? ERR_PTR(ret) : NULL;
1212 	}
1213 
1214 	/*
1215 	 * vfio_mig_get_next_state() does not use arcs other than the above
1216 	 */
1217 	WARN_ON(true);
1218 	return ERR_PTR(-EINVAL);
1219 }
1220 
1221 static struct file *
1222 virtiovf_pci_set_device_state(struct vfio_device *vdev,
1223 			      enum vfio_device_mig_state new_state)
1224 {
1225 	struct virtiovf_pci_core_device *virtvdev = container_of(
1226 		vdev, struct virtiovf_pci_core_device, core_device.vdev);
1227 	enum vfio_device_mig_state next_state;
1228 	struct file *res = NULL;
1229 	int ret;
1230 
1231 	mutex_lock(&virtvdev->state_mutex);
1232 	while (new_state != virtvdev->mig_state) {
1233 		ret = vfio_mig_get_next_state(vdev, virtvdev->mig_state,
1234 					      new_state, &next_state);
1235 		if (ret) {
1236 			res = ERR_PTR(ret);
1237 			break;
1238 		}
1239 		res = virtiovf_pci_step_device_state_locked(virtvdev, next_state);
1240 		if (IS_ERR(res))
1241 			break;
1242 		virtvdev->mig_state = next_state;
1243 		if (WARN_ON(res && new_state != virtvdev->mig_state)) {
1244 			fput(res);
1245 			res = ERR_PTR(-EINVAL);
1246 			break;
1247 		}
1248 	}
1249 	virtiovf_state_mutex_unlock(virtvdev);
1250 	return res;
1251 }
1252 
1253 static int virtiovf_pci_get_device_state(struct vfio_device *vdev,
1254 				       enum vfio_device_mig_state *curr_state)
1255 {
1256 	struct virtiovf_pci_core_device *virtvdev = container_of(
1257 		vdev, struct virtiovf_pci_core_device, core_device.vdev);
1258 
1259 	mutex_lock(&virtvdev->state_mutex);
1260 	*curr_state = virtvdev->mig_state;
1261 	virtiovf_state_mutex_unlock(virtvdev);
1262 	return 0;
1263 }
1264 
1265 static int virtiovf_pci_get_data_size(struct vfio_device *vdev,
1266 				      unsigned long *stop_copy_length)
1267 {
1268 	struct virtiovf_pci_core_device *virtvdev = container_of(
1269 		vdev, struct virtiovf_pci_core_device, core_device.vdev);
1270 	bool obj_id_exists;
1271 	u32 res_size;
1272 	u32 obj_id;
1273 	int ret;
1274 
1275 	mutex_lock(&virtvdev->state_mutex);
1276 	obj_id_exists = virtvdev->saving_migf && virtvdev->saving_migf->has_obj_id;
1277 	if (!obj_id_exists) {
1278 		ret = virtiovf_pci_alloc_obj_id(virtvdev,
1279 						VIRTIO_RESOURCE_OBJ_DEV_PARTS_TYPE_GET,
1280 						&obj_id);
1281 		if (ret)
1282 			goto end;
1283 	} else {
1284 		obj_id = virtvdev->saving_migf->obj_id;
1285 	}
1286 
1287 	ret = virtio_pci_admin_dev_parts_metadata_get(virtvdev->core_device.pdev,
1288 				VIRTIO_RESOURCE_OBJ_DEV_PARTS, obj_id,
1289 				VIRTIO_ADMIN_CMD_DEV_PARTS_METADATA_TYPE_SIZE,
1290 				&res_size);
1291 	if (!ret)
1292 		*stop_copy_length = res_size;
1293 
1294 	/*
1295 	 * We can't leave this obj_id alive if didn't exist before, otherwise, it might
1296 	 * stay alive, even without an active migration flow (e.g. migration was cancelled)
1297 	 */
1298 	if (!obj_id_exists)
1299 		virtiovf_pci_free_obj_id(virtvdev, obj_id);
1300 end:
1301 	virtiovf_state_mutex_unlock(virtvdev);
1302 	return ret;
1303 }
1304 
1305 static const struct vfio_migration_ops virtvdev_pci_mig_ops = {
1306 	.migration_set_state = virtiovf_pci_set_device_state,
1307 	.migration_get_state = virtiovf_pci_get_device_state,
1308 	.migration_get_data_size = virtiovf_pci_get_data_size,
1309 };
1310 
1311 void virtiovf_set_migratable(struct virtiovf_pci_core_device *virtvdev)
1312 {
1313 	virtvdev->migrate_cap = 1;
1314 	mutex_init(&virtvdev->state_mutex);
1315 	spin_lock_init(&virtvdev->reset_lock);
1316 	virtvdev->core_device.vdev.migration_flags =
1317 		VFIO_MIGRATION_STOP_COPY |
1318 		VFIO_MIGRATION_P2P |
1319 		VFIO_MIGRATION_PRE_COPY;
1320 	virtvdev->core_device.vdev.mig_ops = &virtvdev_pci_mig_ops;
1321 }
1322 
1323 void virtiovf_open_migration(struct virtiovf_pci_core_device *virtvdev)
1324 {
1325 	if (!virtvdev->migrate_cap)
1326 		return;
1327 
1328 	virtvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
1329 }
1330 
1331 void virtiovf_close_migration(struct virtiovf_pci_core_device *virtvdev)
1332 {
1333 	if (!virtvdev->migrate_cap)
1334 		return;
1335 
1336 	virtiovf_disable_fds(virtvdev);
1337 }
1338