xref: /linux/drivers/vhost/vdpa.c (revision 97733180fafbeb7cc3fd1c8be60d05980615f5d6)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks Michael S. Tsirkin for the valuable comments and
10  * suggestions.  And thanks to Cunming Liang and Zhihong Wang for all
11  * their support.
12  */
13 
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/slab.h>
20 #include <linux/iommu.h>
21 #include <linux/uuid.h>
22 #include <linux/vdpa.h>
23 #include <linux/nospec.h>
24 #include <linux/vhost.h>
25 
26 #include "vhost.h"
27 
28 enum {
29 	VHOST_VDPA_BACKEND_FEATURES =
30 	(1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
31 	(1ULL << VHOST_BACKEND_F_IOTLB_BATCH),
32 };
33 
/* Number of minors reserved in the chrdev region and upper bound for the IDA. */
#define VHOST_VDPA_DEV_MAX	(1U << MINORBITS)
35 
36 struct vhost_vdpa {
37 	struct vhost_dev vdev;
38 	struct iommu_domain *domain;
39 	struct vhost_virtqueue *vqs;
40 	struct completion completion;
41 	struct vdpa_device *vdpa;
42 	struct device dev;
43 	struct cdev cdev;
44 	atomic_t opened;
45 	int nvqs;
46 	int virtio_id;
47 	int minor;
48 	struct eventfd_ctx *config_ctx;
49 	int in_batch;
50 	struct vdpa_iova_range range;
51 };
52 
53 static DEFINE_IDA(vhost_vdpa_ida);
54 
55 static dev_t vhost_vdpa_major;
56 
57 static void handle_vq_kick(struct vhost_work *work)
58 {
59 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
60 						  poll.work);
61 	struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
62 	const struct vdpa_config_ops *ops = v->vdpa->config;
63 
64 	ops->kick_vq(v->vdpa, vq - v->vqs);
65 }
66 
67 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
68 {
69 	struct vhost_virtqueue *vq = private;
70 	struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
71 
72 	if (call_ctx)
73 		eventfd_signal(call_ctx, 1);
74 
75 	return IRQ_HANDLED;
76 }
77 
78 static irqreturn_t vhost_vdpa_config_cb(void *private)
79 {
80 	struct vhost_vdpa *v = private;
81 	struct eventfd_ctx *config_ctx = v->config_ctx;
82 
83 	if (config_ctx)
84 		eventfd_signal(config_ctx, 1);
85 
86 	return IRQ_HANDLED;
87 }
88 
89 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
90 {
91 	struct vhost_virtqueue *vq = &v->vqs[qid];
92 	const struct vdpa_config_ops *ops = v->vdpa->config;
93 	struct vdpa_device *vdpa = v->vdpa;
94 	int ret, irq;
95 
96 	if (!ops->get_vq_irq)
97 		return;
98 
99 	irq = ops->get_vq_irq(vdpa, qid);
100 	irq_bypass_unregister_producer(&vq->call_ctx.producer);
101 	if (!vq->call_ctx.ctx || irq < 0)
102 		return;
103 
104 	vq->call_ctx.producer.token = vq->call_ctx.ctx;
105 	vq->call_ctx.producer.irq = irq;
106 	ret = irq_bypass_register_producer(&vq->call_ctx.producer);
107 	if (unlikely(ret))
108 		dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret =  %d\n",
109 			 qid, vq->call_ctx.producer.token, ret);
110 }
111 
112 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
113 {
114 	struct vhost_virtqueue *vq = &v->vqs[qid];
115 
116 	irq_bypass_unregister_producer(&vq->call_ctx.producer);
117 }
118 
119 static int vhost_vdpa_reset(struct vhost_vdpa *v)
120 {
121 	struct vdpa_device *vdpa = v->vdpa;
122 
123 	v->in_batch = 0;
124 
125 	return vdpa_reset(vdpa);
126 }
127 
128 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
129 {
130 	struct vdpa_device *vdpa = v->vdpa;
131 	const struct vdpa_config_ops *ops = vdpa->config;
132 	u32 device_id;
133 
134 	device_id = ops->get_device_id(vdpa);
135 
136 	if (copy_to_user(argp, &device_id, sizeof(device_id)))
137 		return -EFAULT;
138 
139 	return 0;
140 }
141 
142 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
143 {
144 	struct vdpa_device *vdpa = v->vdpa;
145 	const struct vdpa_config_ops *ops = vdpa->config;
146 	u8 status;
147 
148 	status = ops->get_status(vdpa);
149 
150 	if (copy_to_user(statusp, &status, sizeof(status)))
151 		return -EFAULT;
152 
153 	return 0;
154 }
155 
156 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
157 {
158 	struct vdpa_device *vdpa = v->vdpa;
159 	const struct vdpa_config_ops *ops = vdpa->config;
160 	u8 status, status_old;
161 	int ret, nvqs = v->nvqs;
162 	u16 i;
163 
164 	if (copy_from_user(&status, statusp, sizeof(status)))
165 		return -EFAULT;
166 
167 	status_old = ops->get_status(vdpa);
168 
169 	/*
170 	 * Userspace shouldn't remove status bits unless reset the
171 	 * status to 0.
172 	 */
173 	if (status != 0 && (status_old & ~status) != 0)
174 		return -EINVAL;
175 
176 	if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
177 		for (i = 0; i < nvqs; i++)
178 			vhost_vdpa_unsetup_vq_irq(v, i);
179 
180 	if (status == 0) {
181 		ret = vdpa_reset(vdpa);
182 		if (ret)
183 			return ret;
184 	} else
185 		vdpa_set_status(vdpa, status);
186 
187 	if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
188 		for (i = 0; i < nvqs; i++)
189 			vhost_vdpa_setup_vq_irq(v, i);
190 
191 	return 0;
192 }
193 
194 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
195 				      struct vhost_vdpa_config *c)
196 {
197 	struct vdpa_device *vdpa = v->vdpa;
198 	size_t size = vdpa->config->get_config_size(vdpa);
199 
200 	if (c->len == 0 || c->off > size)
201 		return -EINVAL;
202 
203 	if (c->len > size - c->off)
204 		return -E2BIG;
205 
206 	return 0;
207 }
208 
209 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
210 				  struct vhost_vdpa_config __user *c)
211 {
212 	struct vdpa_device *vdpa = v->vdpa;
213 	struct vhost_vdpa_config config;
214 	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
215 	u8 *buf;
216 
217 	if (copy_from_user(&config, c, size))
218 		return -EFAULT;
219 	if (vhost_vdpa_config_validate(v, &config))
220 		return -EINVAL;
221 	buf = kvzalloc(config.len, GFP_KERNEL);
222 	if (!buf)
223 		return -ENOMEM;
224 
225 	vdpa_get_config(vdpa, config.off, buf, config.len);
226 
227 	if (copy_to_user(c->buf, buf, config.len)) {
228 		kvfree(buf);
229 		return -EFAULT;
230 	}
231 
232 	kvfree(buf);
233 	return 0;
234 }
235 
236 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
237 				  struct vhost_vdpa_config __user *c)
238 {
239 	struct vdpa_device *vdpa = v->vdpa;
240 	struct vhost_vdpa_config config;
241 	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
242 	u8 *buf;
243 
244 	if (copy_from_user(&config, c, size))
245 		return -EFAULT;
246 	if (vhost_vdpa_config_validate(v, &config))
247 		return -EINVAL;
248 
249 	buf = vmemdup_user(c->buf, config.len);
250 	if (IS_ERR(buf))
251 		return PTR_ERR(buf);
252 
253 	vdpa_set_config(vdpa, config.off, buf, config.len);
254 
255 	kvfree(buf);
256 	return 0;
257 }
258 
259 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
260 {
261 	struct vdpa_device *vdpa = v->vdpa;
262 	const struct vdpa_config_ops *ops = vdpa->config;
263 	u64 features;
264 
265 	features = ops->get_device_features(vdpa);
266 
267 	if (copy_to_user(featurep, &features, sizeof(features)))
268 		return -EFAULT;
269 
270 	return 0;
271 }
272 
273 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
274 {
275 	struct vdpa_device *vdpa = v->vdpa;
276 	const struct vdpa_config_ops *ops = vdpa->config;
277 	u64 features;
278 
279 	/*
280 	 * It's not allowed to change the features after they have
281 	 * been negotiated.
282 	 */
283 	if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
284 		return -EBUSY;
285 
286 	if (copy_from_user(&features, featurep, sizeof(features)))
287 		return -EFAULT;
288 
289 	if (vdpa_set_features(vdpa, features, false))
290 		return -EINVAL;
291 
292 	return 0;
293 }
294 
295 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
296 {
297 	struct vdpa_device *vdpa = v->vdpa;
298 	const struct vdpa_config_ops *ops = vdpa->config;
299 	u16 num;
300 
301 	num = ops->get_vq_num_max(vdpa);
302 
303 	if (copy_to_user(argp, &num, sizeof(num)))
304 		return -EFAULT;
305 
306 	return 0;
307 }
308 
309 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
310 {
311 	if (v->config_ctx) {
312 		eventfd_ctx_put(v->config_ctx);
313 		v->config_ctx = NULL;
314 	}
315 }
316 
317 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
318 {
319 	struct vdpa_callback cb;
320 	int fd;
321 	struct eventfd_ctx *ctx;
322 
323 	cb.callback = vhost_vdpa_config_cb;
324 	cb.private = v;
325 	if (copy_from_user(&fd, argp, sizeof(fd)))
326 		return  -EFAULT;
327 
328 	ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
329 	swap(ctx, v->config_ctx);
330 
331 	if (!IS_ERR_OR_NULL(ctx))
332 		eventfd_ctx_put(ctx);
333 
334 	if (IS_ERR(v->config_ctx)) {
335 		long ret = PTR_ERR(v->config_ctx);
336 
337 		v->config_ctx = NULL;
338 		return ret;
339 	}
340 
341 	v->vdpa->config->set_config_cb(v->vdpa, &cb);
342 
343 	return 0;
344 }
345 
346 static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
347 {
348 	struct vhost_vdpa_iova_range range = {
349 		.first = v->range.first,
350 		.last = v->range.last,
351 	};
352 
353 	if (copy_to_user(argp, &range, sizeof(range)))
354 		return -EFAULT;
355 	return 0;
356 }
357 
358 static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
359 				   void __user *argp)
360 {
361 	struct vdpa_device *vdpa = v->vdpa;
362 	const struct vdpa_config_ops *ops = vdpa->config;
363 	struct vdpa_vq_state vq_state;
364 	struct vdpa_callback cb;
365 	struct vhost_virtqueue *vq;
366 	struct vhost_vring_state s;
367 	u32 idx;
368 	long r;
369 
370 	r = get_user(idx, (u32 __user *)argp);
371 	if (r < 0)
372 		return r;
373 
374 	if (idx >= v->nvqs)
375 		return -ENOBUFS;
376 
377 	idx = array_index_nospec(idx, v->nvqs);
378 	vq = &v->vqs[idx];
379 
380 	switch (cmd) {
381 	case VHOST_VDPA_SET_VRING_ENABLE:
382 		if (copy_from_user(&s, argp, sizeof(s)))
383 			return -EFAULT;
384 		ops->set_vq_ready(vdpa, idx, s.num);
385 		return 0;
386 	case VHOST_GET_VRING_BASE:
387 		r = ops->get_vq_state(v->vdpa, idx, &vq_state);
388 		if (r)
389 			return r;
390 
391 		vq->last_avail_idx = vq_state.split.avail_index;
392 		break;
393 	}
394 
395 	r = vhost_vring_ioctl(&v->vdev, cmd, argp);
396 	if (r)
397 		return r;
398 
399 	switch (cmd) {
400 	case VHOST_SET_VRING_ADDR:
401 		if (ops->set_vq_address(vdpa, idx,
402 					(u64)(uintptr_t)vq->desc,
403 					(u64)(uintptr_t)vq->avail,
404 					(u64)(uintptr_t)vq->used))
405 			r = -EINVAL;
406 		break;
407 
408 	case VHOST_SET_VRING_BASE:
409 		vq_state.split.avail_index = vq->last_avail_idx;
410 		if (ops->set_vq_state(vdpa, idx, &vq_state))
411 			r = -EINVAL;
412 		break;
413 
414 	case VHOST_SET_VRING_CALL:
415 		if (vq->call_ctx.ctx) {
416 			cb.callback = vhost_vdpa_virtqueue_cb;
417 			cb.private = vq;
418 		} else {
419 			cb.callback = NULL;
420 			cb.private = NULL;
421 		}
422 		ops->set_vq_cb(vdpa, idx, &cb);
423 		vhost_vdpa_setup_vq_irq(v, idx);
424 		break;
425 
426 	case VHOST_SET_VRING_NUM:
427 		ops->set_vq_num(vdpa, idx, vq->num);
428 		break;
429 	}
430 
431 	return r;
432 }
433 
434 static long vhost_vdpa_unlocked_ioctl(struct file *filep,
435 				      unsigned int cmd, unsigned long arg)
436 {
437 	struct vhost_vdpa *v = filep->private_data;
438 	struct vhost_dev *d = &v->vdev;
439 	void __user *argp = (void __user *)arg;
440 	u64 __user *featurep = argp;
441 	u64 features;
442 	long r = 0;
443 
444 	if (cmd == VHOST_SET_BACKEND_FEATURES) {
445 		if (copy_from_user(&features, featurep, sizeof(features)))
446 			return -EFAULT;
447 		if (features & ~VHOST_VDPA_BACKEND_FEATURES)
448 			return -EOPNOTSUPP;
449 		vhost_set_backend_features(&v->vdev, features);
450 		return 0;
451 	}
452 
453 	mutex_lock(&d->mutex);
454 
455 	switch (cmd) {
456 	case VHOST_VDPA_GET_DEVICE_ID:
457 		r = vhost_vdpa_get_device_id(v, argp);
458 		break;
459 	case VHOST_VDPA_GET_STATUS:
460 		r = vhost_vdpa_get_status(v, argp);
461 		break;
462 	case VHOST_VDPA_SET_STATUS:
463 		r = vhost_vdpa_set_status(v, argp);
464 		break;
465 	case VHOST_VDPA_GET_CONFIG:
466 		r = vhost_vdpa_get_config(v, argp);
467 		break;
468 	case VHOST_VDPA_SET_CONFIG:
469 		r = vhost_vdpa_set_config(v, argp);
470 		break;
471 	case VHOST_GET_FEATURES:
472 		r = vhost_vdpa_get_features(v, argp);
473 		break;
474 	case VHOST_SET_FEATURES:
475 		r = vhost_vdpa_set_features(v, argp);
476 		break;
477 	case VHOST_VDPA_GET_VRING_NUM:
478 		r = vhost_vdpa_get_vring_num(v, argp);
479 		break;
480 	case VHOST_SET_LOG_BASE:
481 	case VHOST_SET_LOG_FD:
482 		r = -ENOIOCTLCMD;
483 		break;
484 	case VHOST_VDPA_SET_CONFIG_CALL:
485 		r = vhost_vdpa_set_config_call(v, argp);
486 		break;
487 	case VHOST_GET_BACKEND_FEATURES:
488 		features = VHOST_VDPA_BACKEND_FEATURES;
489 		if (copy_to_user(featurep, &features, sizeof(features)))
490 			r = -EFAULT;
491 		break;
492 	case VHOST_VDPA_GET_IOVA_RANGE:
493 		r = vhost_vdpa_get_iova_range(v, argp);
494 		break;
495 	default:
496 		r = vhost_dev_ioctl(&v->vdev, cmd, argp);
497 		if (r == -ENOIOCTLCMD)
498 			r = vhost_vdpa_vring_ioctl(v, cmd, argp);
499 		break;
500 	}
501 
502 	mutex_unlock(&d->mutex);
503 	return r;
504 }
505 
506 static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, u64 start, u64 last)
507 {
508 	struct vhost_dev *dev = &v->vdev;
509 	struct vhost_iotlb *iotlb = dev->iotlb;
510 	struct vhost_iotlb_map *map;
511 	struct page *page;
512 	unsigned long pfn, pinned;
513 
514 	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
515 		pinned = PFN_DOWN(map->size);
516 		for (pfn = PFN_DOWN(map->addr);
517 		     pinned > 0; pfn++, pinned--) {
518 			page = pfn_to_page(pfn);
519 			if (map->perm & VHOST_ACCESS_WO)
520 				set_page_dirty_lock(page);
521 			unpin_user_page(page);
522 		}
523 		atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
524 		vhost_iotlb_map_free(iotlb, map);
525 	}
526 }
527 
528 static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, u64 start, u64 last)
529 {
530 	struct vhost_dev *dev = &v->vdev;
531 	struct vhost_iotlb *iotlb = dev->iotlb;
532 	struct vhost_iotlb_map *map;
533 	struct vdpa_map_file *map_file;
534 
535 	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
536 		map_file = (struct vdpa_map_file *)map->opaque;
537 		fput(map_file->file);
538 		kfree(map_file);
539 		vhost_iotlb_map_free(iotlb, map);
540 	}
541 }
542 
543 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
544 {
545 	struct vdpa_device *vdpa = v->vdpa;
546 
547 	if (vdpa->use_va)
548 		return vhost_vdpa_va_unmap(v, start, last);
549 
550 	return vhost_vdpa_pa_unmap(v, start, last);
551 }
552 
553 static void vhost_vdpa_iotlb_free(struct vhost_vdpa *v)
554 {
555 	struct vhost_dev *dev = &v->vdev;
556 
557 	vhost_vdpa_iotlb_unmap(v, 0ULL, 0ULL - 1);
558 	kfree(dev->iotlb);
559 	dev->iotlb = NULL;
560 }
561 
562 static int perm_to_iommu_flags(u32 perm)
563 {
564 	int flags = 0;
565 
566 	switch (perm) {
567 	case VHOST_ACCESS_WO:
568 		flags |= IOMMU_WRITE;
569 		break;
570 	case VHOST_ACCESS_RO:
571 		flags |= IOMMU_READ;
572 		break;
573 	case VHOST_ACCESS_RW:
574 		flags |= (IOMMU_WRITE | IOMMU_READ);
575 		break;
576 	default:
577 		WARN(1, "invalidate vhost IOTLB permission\n");
578 		break;
579 	}
580 
581 	return flags | IOMMU_CACHE;
582 }
583 
584 static int vhost_vdpa_map(struct vhost_vdpa *v, u64 iova,
585 			  u64 size, u64 pa, u32 perm, void *opaque)
586 {
587 	struct vhost_dev *dev = &v->vdev;
588 	struct vdpa_device *vdpa = v->vdpa;
589 	const struct vdpa_config_ops *ops = vdpa->config;
590 	int r = 0;
591 
592 	r = vhost_iotlb_add_range_ctx(dev->iotlb, iova, iova + size - 1,
593 				      pa, perm, opaque);
594 	if (r)
595 		return r;
596 
597 	if (ops->dma_map) {
598 		r = ops->dma_map(vdpa, iova, size, pa, perm, opaque);
599 	} else if (ops->set_map) {
600 		if (!v->in_batch)
601 			r = ops->set_map(vdpa, dev->iotlb);
602 	} else {
603 		r = iommu_map(v->domain, iova, pa, size,
604 			      perm_to_iommu_flags(perm));
605 	}
606 	if (r) {
607 		vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
608 		return r;
609 	}
610 
611 	if (!vdpa->use_va)
612 		atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
613 
614 	return 0;
615 }
616 
617 static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
618 {
619 	struct vhost_dev *dev = &v->vdev;
620 	struct vdpa_device *vdpa = v->vdpa;
621 	const struct vdpa_config_ops *ops = vdpa->config;
622 
623 	vhost_vdpa_iotlb_unmap(v, iova, iova + size - 1);
624 
625 	if (ops->dma_map) {
626 		ops->dma_unmap(vdpa, iova, size);
627 	} else if (ops->set_map) {
628 		if (!v->in_batch)
629 			ops->set_map(vdpa, dev->iotlb);
630 	} else {
631 		iommu_unmap(v->domain, iova, size);
632 	}
633 }
634 
635 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
636 			     u64 iova, u64 size, u64 uaddr, u32 perm)
637 {
638 	struct vhost_dev *dev = &v->vdev;
639 	u64 offset, map_size, map_iova = iova;
640 	struct vdpa_map_file *map_file;
641 	struct vm_area_struct *vma;
642 	int ret = 0;
643 
644 	mmap_read_lock(dev->mm);
645 
646 	while (size) {
647 		vma = find_vma(dev->mm, uaddr);
648 		if (!vma) {
649 			ret = -EINVAL;
650 			break;
651 		}
652 		map_size = min(size, vma->vm_end - uaddr);
653 		if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
654 			!(vma->vm_flags & (VM_IO | VM_PFNMAP))))
655 			goto next;
656 
657 		map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
658 		if (!map_file) {
659 			ret = -ENOMEM;
660 			break;
661 		}
662 		offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
663 		map_file->offset = offset;
664 		map_file->file = get_file(vma->vm_file);
665 		ret = vhost_vdpa_map(v, map_iova, map_size, uaddr,
666 				     perm, map_file);
667 		if (ret) {
668 			fput(map_file->file);
669 			kfree(map_file);
670 			break;
671 		}
672 next:
673 		size -= map_size;
674 		uaddr += map_size;
675 		map_iova += map_size;
676 	}
677 	if (ret)
678 		vhost_vdpa_unmap(v, iova, map_iova - iova);
679 
680 	mmap_read_unlock(dev->mm);
681 
682 	return ret;
683 }
684 
685 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
686 			     u64 iova, u64 size, u64 uaddr, u32 perm)
687 {
688 	struct vhost_dev *dev = &v->vdev;
689 	struct page **page_list;
690 	unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
691 	unsigned int gup_flags = FOLL_LONGTERM;
692 	unsigned long npages, cur_base, map_pfn, last_pfn = 0;
693 	unsigned long lock_limit, sz2pin, nchunks, i;
694 	u64 start = iova;
695 	long pinned;
696 	int ret = 0;
697 
698 	/* Limit the use of memory for bookkeeping */
699 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
700 	if (!page_list)
701 		return -ENOMEM;
702 
703 	if (perm & VHOST_ACCESS_WO)
704 		gup_flags |= FOLL_WRITE;
705 
706 	npages = PFN_UP(size + (iova & ~PAGE_MASK));
707 	if (!npages) {
708 		ret = -EINVAL;
709 		goto free;
710 	}
711 
712 	mmap_read_lock(dev->mm);
713 
714 	lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
715 	if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
716 		ret = -ENOMEM;
717 		goto unlock;
718 	}
719 
720 	cur_base = uaddr & PAGE_MASK;
721 	iova &= PAGE_MASK;
722 	nchunks = 0;
723 
724 	while (npages) {
725 		sz2pin = min_t(unsigned long, npages, list_size);
726 		pinned = pin_user_pages(cur_base, sz2pin,
727 					gup_flags, page_list, NULL);
728 		if (sz2pin != pinned) {
729 			if (pinned < 0) {
730 				ret = pinned;
731 			} else {
732 				unpin_user_pages(page_list, pinned);
733 				ret = -ENOMEM;
734 			}
735 			goto out;
736 		}
737 		nchunks++;
738 
739 		if (!last_pfn)
740 			map_pfn = page_to_pfn(page_list[0]);
741 
742 		for (i = 0; i < pinned; i++) {
743 			unsigned long this_pfn = page_to_pfn(page_list[i]);
744 			u64 csize;
745 
746 			if (last_pfn && (this_pfn != last_pfn + 1)) {
747 				/* Pin a contiguous chunk of memory */
748 				csize = PFN_PHYS(last_pfn - map_pfn + 1);
749 				ret = vhost_vdpa_map(v, iova, csize,
750 						     PFN_PHYS(map_pfn),
751 						     perm, NULL);
752 				if (ret) {
753 					/*
754 					 * Unpin the pages that are left unmapped
755 					 * from this point on in the current
756 					 * page_list. The remaining outstanding
757 					 * ones which may stride across several
758 					 * chunks will be covered in the common
759 					 * error path subsequently.
760 					 */
761 					unpin_user_pages(&page_list[i],
762 							 pinned - i);
763 					goto out;
764 				}
765 
766 				map_pfn = this_pfn;
767 				iova += csize;
768 				nchunks = 0;
769 			}
770 
771 			last_pfn = this_pfn;
772 		}
773 
774 		cur_base += PFN_PHYS(pinned);
775 		npages -= pinned;
776 	}
777 
778 	/* Pin the rest chunk */
779 	ret = vhost_vdpa_map(v, iova, PFN_PHYS(last_pfn - map_pfn + 1),
780 			     PFN_PHYS(map_pfn), perm, NULL);
781 out:
782 	if (ret) {
783 		if (nchunks) {
784 			unsigned long pfn;
785 
786 			/*
787 			 * Unpin the outstanding pages which are yet to be
788 			 * mapped but haven't due to vdpa_map() or
789 			 * pin_user_pages() failure.
790 			 *
791 			 * Mapped pages are accounted in vdpa_map(), hence
792 			 * the corresponding unpinning will be handled by
793 			 * vdpa_unmap().
794 			 */
795 			WARN_ON(!last_pfn);
796 			for (pfn = map_pfn; pfn <= last_pfn; pfn++)
797 				unpin_user_page(pfn_to_page(pfn));
798 		}
799 		vhost_vdpa_unmap(v, start, size);
800 	}
801 unlock:
802 	mmap_read_unlock(dev->mm);
803 free:
804 	free_page((unsigned long)page_list);
805 	return ret;
806 
807 }
808 
809 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
810 					   struct vhost_iotlb_msg *msg)
811 {
812 	struct vhost_dev *dev = &v->vdev;
813 	struct vdpa_device *vdpa = v->vdpa;
814 	struct vhost_iotlb *iotlb = dev->iotlb;
815 
816 	if (msg->iova < v->range.first || !msg->size ||
817 	    msg->iova > U64_MAX - msg->size + 1 ||
818 	    msg->iova + msg->size - 1 > v->range.last)
819 		return -EINVAL;
820 
821 	if (vhost_iotlb_itree_first(iotlb, msg->iova,
822 				    msg->iova + msg->size - 1))
823 		return -EEXIST;
824 
825 	if (vdpa->use_va)
826 		return vhost_vdpa_va_map(v, msg->iova, msg->size,
827 					 msg->uaddr, msg->perm);
828 
829 	return vhost_vdpa_pa_map(v, msg->iova, msg->size, msg->uaddr,
830 				 msg->perm);
831 }
832 
833 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev,
834 					struct vhost_iotlb_msg *msg)
835 {
836 	struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
837 	struct vdpa_device *vdpa = v->vdpa;
838 	const struct vdpa_config_ops *ops = vdpa->config;
839 	int r = 0;
840 
841 	mutex_lock(&dev->mutex);
842 
843 	r = vhost_dev_check_owner(dev);
844 	if (r)
845 		goto unlock;
846 
847 	switch (msg->type) {
848 	case VHOST_IOTLB_UPDATE:
849 		r = vhost_vdpa_process_iotlb_update(v, msg);
850 		break;
851 	case VHOST_IOTLB_INVALIDATE:
852 		vhost_vdpa_unmap(v, msg->iova, msg->size);
853 		break;
854 	case VHOST_IOTLB_BATCH_BEGIN:
855 		v->in_batch = true;
856 		break;
857 	case VHOST_IOTLB_BATCH_END:
858 		if (v->in_batch && ops->set_map)
859 			ops->set_map(vdpa, dev->iotlb);
860 		v->in_batch = false;
861 		break;
862 	default:
863 		r = -EINVAL;
864 		break;
865 	}
866 unlock:
867 	mutex_unlock(&dev->mutex);
868 
869 	return r;
870 }
871 
872 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
873 					 struct iov_iter *from)
874 {
875 	struct file *file = iocb->ki_filp;
876 	struct vhost_vdpa *v = file->private_data;
877 	struct vhost_dev *dev = &v->vdev;
878 
879 	return vhost_chr_write_iter(dev, from);
880 }
881 
882 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
883 {
884 	struct vdpa_device *vdpa = v->vdpa;
885 	const struct vdpa_config_ops *ops = vdpa->config;
886 	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
887 	struct bus_type *bus;
888 	int ret;
889 
890 	/* Device want to do DMA by itself */
891 	if (ops->set_map || ops->dma_map)
892 		return 0;
893 
894 	bus = dma_dev->bus;
895 	if (!bus)
896 		return -EFAULT;
897 
898 	if (!iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
899 		return -ENOTSUPP;
900 
901 	v->domain = iommu_domain_alloc(bus);
902 	if (!v->domain)
903 		return -EIO;
904 
905 	ret = iommu_attach_device(v->domain, dma_dev);
906 	if (ret)
907 		goto err_attach;
908 
909 	return 0;
910 
911 err_attach:
912 	iommu_domain_free(v->domain);
913 	return ret;
914 }
915 
916 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
917 {
918 	struct vdpa_device *vdpa = v->vdpa;
919 	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
920 
921 	if (v->domain) {
922 		iommu_detach_device(v->domain, dma_dev);
923 		iommu_domain_free(v->domain);
924 	}
925 
926 	v->domain = NULL;
927 }
928 
929 static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
930 {
931 	struct vdpa_iova_range *range = &v->range;
932 	struct vdpa_device *vdpa = v->vdpa;
933 	const struct vdpa_config_ops *ops = vdpa->config;
934 
935 	if (ops->get_iova_range) {
936 		*range = ops->get_iova_range(vdpa);
937 	} else if (v->domain && v->domain->geometry.force_aperture) {
938 		range->first = v->domain->geometry.aperture_start;
939 		range->last = v->domain->geometry.aperture_end;
940 	} else {
941 		range->first = 0;
942 		range->last = ULLONG_MAX;
943 	}
944 }
945 
946 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
947 {
948 	struct vhost_vdpa *v;
949 	struct vhost_dev *dev;
950 	struct vhost_virtqueue **vqs;
951 	int nvqs, i, r, opened;
952 
953 	v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
954 
955 	opened = atomic_cmpxchg(&v->opened, 0, 1);
956 	if (opened)
957 		return -EBUSY;
958 
959 	nvqs = v->nvqs;
960 	r = vhost_vdpa_reset(v);
961 	if (r)
962 		goto err;
963 
964 	vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
965 	if (!vqs) {
966 		r = -ENOMEM;
967 		goto err;
968 	}
969 
970 	dev = &v->vdev;
971 	for (i = 0; i < nvqs; i++) {
972 		vqs[i] = &v->vqs[i];
973 		vqs[i]->handle_kick = handle_vq_kick;
974 	}
975 	vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
976 		       vhost_vdpa_process_iotlb_msg);
977 
978 	dev->iotlb = vhost_iotlb_alloc(0, 0);
979 	if (!dev->iotlb) {
980 		r = -ENOMEM;
981 		goto err_init_iotlb;
982 	}
983 
984 	r = vhost_vdpa_alloc_domain(v);
985 	if (r)
986 		goto err_init_iotlb;
987 
988 	vhost_vdpa_set_iova_range(v);
989 
990 	filep->private_data = v;
991 
992 	return 0;
993 
994 err_init_iotlb:
995 	vhost_dev_cleanup(&v->vdev);
996 	kfree(vqs);
997 err:
998 	atomic_dec(&v->opened);
999 	return r;
1000 }
1001 
1002 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1003 {
1004 	int i;
1005 
1006 	for (i = 0; i < v->nvqs; i++)
1007 		vhost_vdpa_unsetup_vq_irq(v, i);
1008 }
1009 
1010 static int vhost_vdpa_release(struct inode *inode, struct file *filep)
1011 {
1012 	struct vhost_vdpa *v = filep->private_data;
1013 	struct vhost_dev *d = &v->vdev;
1014 
1015 	mutex_lock(&d->mutex);
1016 	filep->private_data = NULL;
1017 	vhost_vdpa_clean_irq(v);
1018 	vhost_vdpa_reset(v);
1019 	vhost_dev_stop(&v->vdev);
1020 	vhost_vdpa_iotlb_free(v);
1021 	vhost_vdpa_free_domain(v);
1022 	vhost_vdpa_config_put(v);
1023 	vhost_dev_cleanup(&v->vdev);
1024 	kfree(v->vdev.vqs);
1025 	mutex_unlock(&d->mutex);
1026 
1027 	atomic_dec(&v->opened);
1028 	complete(&v->completion);
1029 
1030 	return 0;
1031 }
1032 
1033 #ifdef CONFIG_MMU
1034 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1035 {
1036 	struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1037 	struct vdpa_device *vdpa = v->vdpa;
1038 	const struct vdpa_config_ops *ops = vdpa->config;
1039 	struct vdpa_notification_area notify;
1040 	struct vm_area_struct *vma = vmf->vma;
1041 	u16 index = vma->vm_pgoff;
1042 
1043 	notify = ops->get_vq_notification(vdpa, index);
1044 
1045 	vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
1046 	if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
1047 			    PFN_DOWN(notify.addr), PAGE_SIZE,
1048 			    vma->vm_page_prot))
1049 		return VM_FAULT_SIGBUS;
1050 
1051 	return VM_FAULT_NOPAGE;
1052 }
1053 
1054 static const struct vm_operations_struct vhost_vdpa_vm_ops = {
1055 	.fault = vhost_vdpa_fault,
1056 };
1057 
1058 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1059 {
1060 	struct vhost_vdpa *v = vma->vm_file->private_data;
1061 	struct vdpa_device *vdpa = v->vdpa;
1062 	const struct vdpa_config_ops *ops = vdpa->config;
1063 	struct vdpa_notification_area notify;
1064 	unsigned long index = vma->vm_pgoff;
1065 
1066 	if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1067 		return -EINVAL;
1068 	if ((vma->vm_flags & VM_SHARED) == 0)
1069 		return -EINVAL;
1070 	if (vma->vm_flags & VM_READ)
1071 		return -EINVAL;
1072 	if (index > 65535)
1073 		return -EINVAL;
1074 	if (!ops->get_vq_notification)
1075 		return -ENOTSUPP;
1076 
1077 	/* To be safe and easily modelled by userspace, We only
1078 	 * support the doorbell which sits on the page boundary and
1079 	 * does not share the page with other registers.
1080 	 */
1081 	notify = ops->get_vq_notification(vdpa, index);
1082 	if (notify.addr & (PAGE_SIZE - 1))
1083 		return -EINVAL;
1084 	if (vma->vm_end - vma->vm_start != notify.size)
1085 		return -ENOTSUPP;
1086 
1087 	vma->vm_flags |= VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP;
1088 	vma->vm_ops = &vhost_vdpa_vm_ops;
1089 	return 0;
1090 }
1091 #endif /* CONFIG_MMU */
1092 
1093 static const struct file_operations vhost_vdpa_fops = {
1094 	.owner		= THIS_MODULE,
1095 	.open		= vhost_vdpa_open,
1096 	.release	= vhost_vdpa_release,
1097 	.write_iter	= vhost_vdpa_chr_write_iter,
1098 	.unlocked_ioctl	= vhost_vdpa_unlocked_ioctl,
1099 #ifdef CONFIG_MMU
1100 	.mmap		= vhost_vdpa_mmap,
1101 #endif /* CONFIG_MMU */
1102 	.compat_ioctl	= compat_ptr_ioctl,
1103 };
1104 
1105 static void vhost_vdpa_release_dev(struct device *device)
1106 {
1107 	struct vhost_vdpa *v =
1108 	       container_of(device, struct vhost_vdpa, dev);
1109 
1110 	ida_simple_remove(&vhost_vdpa_ida, v->minor);
1111 	kfree(v->vqs);
1112 	kfree(v);
1113 }
1114 
1115 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1116 {
1117 	const struct vdpa_config_ops *ops = vdpa->config;
1118 	struct vhost_vdpa *v;
1119 	int minor;
1120 	int r;
1121 
1122 	v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1123 	if (!v)
1124 		return -ENOMEM;
1125 
1126 	minor = ida_simple_get(&vhost_vdpa_ida, 0,
1127 			       VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1128 	if (minor < 0) {
1129 		kfree(v);
1130 		return minor;
1131 	}
1132 
1133 	atomic_set(&v->opened, 0);
1134 	v->minor = minor;
1135 	v->vdpa = vdpa;
1136 	v->nvqs = vdpa->nvqs;
1137 	v->virtio_id = ops->get_device_id(vdpa);
1138 
1139 	device_initialize(&v->dev);
1140 	v->dev.release = vhost_vdpa_release_dev;
1141 	v->dev.parent = &vdpa->dev;
1142 	v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1143 	v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1144 			       GFP_KERNEL);
1145 	if (!v->vqs) {
1146 		r = -ENOMEM;
1147 		goto err;
1148 	}
1149 
1150 	r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1151 	if (r)
1152 		goto err;
1153 
1154 	cdev_init(&v->cdev, &vhost_vdpa_fops);
1155 	v->cdev.owner = THIS_MODULE;
1156 
1157 	r = cdev_device_add(&v->cdev, &v->dev);
1158 	if (r)
1159 		goto err;
1160 
1161 	init_completion(&v->completion);
1162 	vdpa_set_drvdata(vdpa, v);
1163 
1164 	return 0;
1165 
1166 err:
1167 	put_device(&v->dev);
1168 	return r;
1169 }
1170 
1171 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1172 {
1173 	struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1174 	int opened;
1175 
1176 	cdev_device_del(&v->cdev, &v->dev);
1177 
1178 	do {
1179 		opened = atomic_cmpxchg(&v->opened, 0, 1);
1180 		if (!opened)
1181 			break;
1182 		wait_for_completion(&v->completion);
1183 	} while (1);
1184 
1185 	put_device(&v->dev);
1186 }
1187 
1188 static struct vdpa_driver vhost_vdpa_driver = {
1189 	.driver = {
1190 		.name	= "vhost_vdpa",
1191 	},
1192 	.probe	= vhost_vdpa_probe,
1193 	.remove	= vhost_vdpa_remove,
1194 };
1195 
1196 static int __init vhost_vdpa_init(void)
1197 {
1198 	int r;
1199 
1200 	r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1201 				"vhost-vdpa");
1202 	if (r)
1203 		goto err_alloc_chrdev;
1204 
1205 	r = vdpa_register_driver(&vhost_vdpa_driver);
1206 	if (r)
1207 		goto err_vdpa_register_driver;
1208 
1209 	return 0;
1210 
1211 err_vdpa_register_driver:
1212 	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1213 err_alloc_chrdev:
1214 	return r;
1215 }
1216 module_init(vhost_vdpa_init);
1217 
1218 static void __exit vhost_vdpa_exit(void)
1219 {
1220 	vdpa_unregister_driver(&vhost_vdpa_driver);
1221 	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1222 }
1223 module_exit(vhost_vdpa_exit);
1224 
1225 MODULE_VERSION("0.0.1");
1226 MODULE_LICENSE("GPL v2");
1227 MODULE_AUTHOR("Intel Corporation");
1228 MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");
1229