xref: /linux/drivers/vdpa/vdpa_user/vduse_dev.c (revision 455a2a1af92651764e9eb42cec0d95ac142afc28)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VDUSE: vDPA Device in Userspace
4  *
5  * Copyright (C) 2020-2021 Bytedance Inc. and/or its affiliates. All rights reserved.
6  *
7  * Author: Xie Yongji <xieyongji@bytedance.com>
8  *
9  */
10 
11 #include "linux/virtio_net.h"
12 #include <linux/cleanup.h>
13 #include <linux/init.h>
14 #include <linux/module.h>
15 #include <linux/cdev.h>
16 #include <linux/device.h>
17 #include <linux/eventfd.h>
18 #include <linux/slab.h>
19 #include <linux/wait.h>
20 #include <linux/dma-map-ops.h>
21 #include <linux/poll.h>
22 #include <linux/file.h>
23 #include <linux/uio.h>
24 #include <linux/vdpa.h>
25 #include <linux/nospec.h>
26 #include <linux/virtio.h>
27 #include <linux/vmalloc.h>
28 #include <linux/sched/mm.h>
29 #include <uapi/linux/vduse.h>
30 #include <uapi/linux/vdpa.h>
31 #include <uapi/linux/virtio_config.h>
32 #include <uapi/linux/virtio_ids.h>
33 #include <uapi/linux/virtio_blk.h>
34 #include <uapi/linux/virtio_ring.h>
35 #include <linux/mod_devicetable.h>
36 
37 #include "iova_domain.h"
38 
/*
 * Driver identification strings. Their consumers are outside this excerpt
 * (presumably the MODULE_AUTHOR()/MODULE_DESCRIPTION()/MODULE_LICENSE()
 * declarations -- TODO confirm).
 */
#define DRV_AUTHOR   "Yongji Xie <xieyongji@bytedance.com>"
#define DRV_DESC     "vDPA Device in Userspace"
#define DRV_LICENSE  "GPL v2"

/* Number of distinct values representable in MINORBITS bits. */
#define VDUSE_DEV_MAX (1U << MINORBITS)
/* NOTE(review): presumably upper bounds for dev->ngroups and dev->nas. */
#define VDUSE_DEV_MAX_GROUPS 0xffff
#define VDUSE_DEV_MAX_AS 0xffff
/*
 * Bounce buffer size limits and a third value that is presumably the default
 * (1 GiB, 1 MiB and 64 MiB if the unit is bytes -- TODO confirm).
 */
#define VDUSE_MAX_BOUNCE_SIZE (1024 * 1024 * 1024)
#define VDUSE_MIN_BOUNCE_SIZE (1024 * 1024)
#define VDUSE_BOUNCE_SIZE (64 * 1024 * 1024)
/* 128 MB reserved for virtqueue creation */
#define VDUSE_IOVA_SIZE (VDUSE_MAX_BOUNCE_SIZE + 128 * 1024 * 1024)
/*
 * NOTE(review): presumably seconds -- vduse_dev_msg_sync() multiplies
 * dev->msg_timeout by HZ; where this default is assigned is not visible here.
 */
#define VDUSE_MSG_DEFAULT_TIMEOUT 30

/*
 * irq_effective_cpu value for which vduse_dev_queue_irq_work() calls
 * queue_work() on vduse_irq_wq instead of queue_work_on() for a given CPU.
 */
#define IRQ_UNBOUND -1
54 
/*
 * The VDUSE instance has not asked for the VDUSE API version, so assume 0.
 *
 * Old devices may never ask for the API version and simply assume it is 0.
 * Keep this value for those.  From the moment the VDUSE instance asks for
 * the version, convert it to the latest supported one and continue with the
 * regular flow.
 */
#define VDUSE_API_VERSION_NOT_ASKED U64_MAX
63 
/* Per-virtqueue state of a VDUSE device. */
struct vduse_virtqueue {
	/* Index sent to userspace in VDUSE_GET_VQ_STATE requests. */
	u16 index;
	/* Size reported by vduse_vdpa_get_vq_size() while num is 0. */
	u16 num_max;
	/* Size stored by vduse_vdpa_set_vq_num(); cleared on device reset. */
	u32 num;
	/* Ring addresses stored by vduse_vdpa_set_vq_address(). */
	u64 desc_addr;
	u64 driver_addr;
	u64 device_addr;
	struct vdpa_vq_state state;
	bool ready;
	/* Set when a kick arrives while no kickfd is installed. */
	bool kicked;
	/* Group index returned by vduse_get_vq_group() for API version >= 1. */
	u32 group;
	/* Held around accesses to kickfd and kicked. */
	spinlock_t kick_lock;
	/* Held around updates and invocations of cb. */
	spinlock_t irq_lock;
	struct eventfd_ctx *kickfd;
	struct vdpa_callback cb;
	/* Work item handled by vduse_vq_irq_inject(). */
	struct work_struct inject;
	/* Work item handled by vduse_vq_kick_work(). */
	struct work_struct kick;
	int irq_effective_cpu;
	struct cpumask irq_affinity;
	struct kobject kobj;
};
85 
struct vduse_dev;

/*
 * Wrapper that embeds the vdpa_device and points back to the owning
 * vduse_dev; vdpa_to_vduse() recovers the latter via container_of().
 */
struct vduse_vdpa {
	struct vdpa_device vdpa;
	struct vduse_dev *dev;
};
92 
/*
 * Userspace memory registered for an address space; torn down in
 * vduse_dev_dereg_umem().
 */
struct vduse_umem {
	/* Compared against the iova passed to vduse_dev_dereg_umem(). */
	unsigned long iova;
	/* Number of entries in pages; also subtracted from mm->pinned_vm. */
	unsigned long npages;
	/* Pinned user pages; the array itself is released with vfree(). */
	struct page **pages;
	/* mm whose pinned_vm is adjusted; reference dropped with mmdrop(). */
	struct mm_struct *mm;
};
99 
/* One address space of a VDUSE device (element of vduse_dev.as[]). */
struct vduse_as {
	/* IOVA domain used by the map ops and by vduse_vdpa_set_map(). */
	struct vduse_iova_domain *domain;
	/* Registered userspace memory, or NULL. */
	struct vduse_umem *umem;
	/* Held by vduse_dev_dereg_umem() while umem is inspected and freed. */
	struct mutex mem_lock;
};
105 
/*
 * A virtqueue group (element of vduse_dev.groups[]). A pointer to it is
 * handed out as the virtio_map token by vduse_get_vq_map().
 */
struct vduse_vq_group {
	/* Only taken by the vq_group_as_*_lock guards when dev->nas > 1. */
	rwlock_t as_lock;
	struct vduse_as *as; /* Protected by as_lock */
	struct vduse_dev *dev;
};
111 
/* State of one VDUSE device. */
struct vduse_dev {
	/* Cleared by vduse_vdpa_free(). */
	struct vduse_vdpa *vdev;
	struct device *dev;
	/* Array of vq_num virtqueue pointers. */
	struct vduse_virtqueue **vqs;
	/* Array of nas address spaces. */
	struct vduse_as *as;
	char *name;
	struct mutex lock;
	/* Held around send_list, recv_list and msg_unique accesses. */
	spinlock_t msg_lock;
	/* Next request_id handed out by vduse_dev_msg_sync(). */
	u64 msg_unique;
	/* Multiplied by HZ for the request wait; 0 waits without a timeout. */
	u32 msg_timeout;
	/* Readers sleep here until send_list is non-empty. */
	wait_queue_head_t waitq;
	/* Requests not yet read by userspace. */
	struct list_head send_list;
	/* Requests read by userspace that still await a response. */
	struct list_head recv_list;
	struct vdpa_callback config_cb;
	/* Work item handled by vduse_dev_irq_inject(). */
	struct work_struct inject;
	/* Held around updates and invocations of config_cb. */
	spinlock_t irq_lock;
	/*
	 * Write-held by vduse_dev_reset(), read-held by
	 * vduse_dev_queue_irq_work().
	 */
	struct rw_semaphore rwsem;
	int minor;
	/* Set by vduse_dev_broken(); new requests then fail with -EIO. */
	bool broken;
	bool connected;
	u64 api_version;
	u64 device_features;
	u64 driver_features;
	u32 device_id;
	u32 vendor_id;
	/* Incremented on every vduse_dev_reset(). */
	u32 generation;
	/* Size of the buffer behind config, as used by get_config(). */
	u32 config_size;
	void *config;
	u8 status;
	u32 vq_num;
	u32 vq_align;
	u32 ngroups;
	u32 nas;
	/* Indexed by virtqueue group. */
	struct vduse_vq_group *groups;
	unsigned int bounce_size;
	struct mutex domain_lock;
};
149 
/* One request/response exchange with userspace. */
struct vduse_dev_msg {
	/* Copied to userspace by vduse_dev_read_iter(). */
	struct vduse_dev_request req;
	/* Filled by vduse_dev_write_iter() or on failure paths. */
	struct vduse_dev_response resp;
	/* Links the message into dev->send_list or dev->recv_list. */
	struct list_head list;
	/* The sender sleeps here until completed is set. */
	wait_queue_head_t waitq;
	bool completed;
};
157 
/*
 * NOTE(review): presumably the per-open state of the control device; its
 * users are outside this excerpt.
 */
struct vduse_control {
	u64 api_version;
};
161 
/*
 * NOTE(review): presumably vduse_lock serializes access to vduse_idr, which
 * presumably maps minors to devices -- their users are outside this excerpt.
 */
static DEFINE_MUTEX(vduse_lock);
static DEFINE_IDR(vduse_idr);

static dev_t vduse_major;
static struct cdev vduse_ctrl_cdev;
static struct cdev vduse_cdev;
/* Used by vduse_dev_queue_irq_work() when the CPU is IRQ_UNBOUND. */
static struct workqueue_struct *vduse_irq_wq;
/* Used by vduse_dev_queue_irq_work() with queue_work_on() for a given CPU. */
static struct workqueue_struct *vduse_irq_bound_wq;
170 
/*
 * Virtio device IDs.
 * NOTE(review): presumably the set a new device's device_id is validated
 * against; the check itself is outside this excerpt.
 */
static u32 allowed_device_id[] = {
	VIRTIO_ID_BLOCK,
	VIRTIO_ID_NET,
	VIRTIO_ID_FS,
};
176 
177 static inline struct vduse_dev *vdpa_to_vduse(struct vdpa_device *vdpa)
178 {
179 	struct vduse_vdpa *vdev = container_of(vdpa, struct vduse_vdpa, vdpa);
180 
181 	return vdev->dev;
182 }
183 
/* Resolve the vduse_dev behind a struct device, going through its vdpa_device. */
static inline struct vduse_dev *dev_to_vduse(struct device *dev)
{
	return vdpa_to_vduse(dev_to_vdpa(dev));
}
190 
191 static struct vduse_dev_msg *vduse_find_msg(struct list_head *head,
192 					    uint32_t request_id)
193 {
194 	struct vduse_dev_msg *msg;
195 
196 	list_for_each_entry(msg, head, list) {
197 		if (msg->req.request_id == request_id) {
198 			list_del(&msg->list);
199 			return msg;
200 		}
201 	}
202 
203 	return NULL;
204 }
205 
206 static struct vduse_dev_msg *vduse_dequeue_msg(struct list_head *head)
207 {
208 	struct vduse_dev_msg *msg = NULL;
209 
210 	if (!list_empty(head)) {
211 		msg = list_first_entry(head, struct vduse_dev_msg, list);
212 		list_del(&msg->list);
213 	}
214 
215 	return msg;
216 }
217 
218 static void vduse_enqueue_msg(struct list_head *head,
219 			      struct vduse_dev_msg *msg)
220 {
221 	list_add_tail(&msg->list, head);
222 }
223 
224 static void vduse_enqueue_msg_head(struct list_head *head,
225 				   struct vduse_dev_msg *msg)
226 {
227 	list_add(&msg->list, head);
228 }
229 
230 static void vduse_dev_broken(struct vduse_dev *dev)
231 {
232 	struct vduse_dev_msg *msg, *tmp;
233 
234 	if (unlikely(dev->broken))
235 		return;
236 
237 	list_splice_init(&dev->recv_list, &dev->send_list);
238 	list_for_each_entry_safe(msg, tmp, &dev->send_list, list) {
239 		list_del(&msg->list);
240 		msg->completed = 1;
241 		msg->resp.result = VDUSE_REQ_RESULT_FAILED;
242 		wake_up(&msg->waitq);
243 	}
244 	dev->broken = true;
245 	wake_up(&dev->waitq);
246 }
247 
/*
 * Queue @msg for userspace and wait for its response.
 *
 * The message gets the next request id, is appended to dev->send_list and
 * sleepers on dev->waitq are woken. The caller then sleeps (killable) until
 * msg->completed is set; if dev->msg_timeout is non-zero the wait is limited
 * to msg_timeout * HZ jiffies.
 *
 * Returns 0 when the response result is VDUSE_REQ_RESULT_OK and -EIO
 * otherwise, including when the device is already marked broken.
 */
static int vduse_dev_msg_sync(struct vduse_dev *dev,
			      struct vduse_dev_msg *msg)
{
	int ret;

	/* Unlocked early check; repeated below under msg_lock. */
	if (unlikely(dev->broken))
		return -EIO;

	init_waitqueue_head(&msg->waitq);
	spin_lock(&dev->msg_lock);
	if (unlikely(dev->broken)) {
		spin_unlock(&dev->msg_lock);
		return -EIO;
	}
	msg->req.request_id = dev->msg_unique++;
	vduse_enqueue_msg(&dev->send_list, msg);
	wake_up(&dev->waitq);
	spin_unlock(&dev->msg_lock);
	if (dev->msg_timeout)
		ret = wait_event_killable_timeout(msg->waitq, msg->completed,
						  (long)dev->msg_timeout * HZ);
	else
		ret = wait_event_killable(msg->waitq, msg->completed);

	spin_lock(&dev->msg_lock);
	if (!msg->completed) {
		/*
		 * Nobody completed the message, so it is still linked on
		 * send_list or recv_list: unlink it and fail it locally.
		 */
		list_del(&msg->list);
		msg->resp.result = VDUSE_REQ_RESULT_FAILED;
		/* Mark the device as malfunction when there is a timeout */
		if (!ret)
			vduse_dev_broken(dev);
	}
	ret = (msg->resp.result == VDUSE_REQ_RESULT_OK) ? 0 : -EIO;
	spin_unlock(&dev->msg_lock);

	return ret;
}
285 
286 static int vduse_dev_get_vq_state_packed(struct vduse_dev *dev,
287 					 struct vduse_virtqueue *vq,
288 					 struct vdpa_vq_state_packed *packed)
289 {
290 	struct vduse_dev_msg msg = { 0 };
291 	int ret;
292 
293 	msg.req.type = VDUSE_GET_VQ_STATE;
294 	msg.req.vq_state.index = vq->index;
295 
296 	ret = vduse_dev_msg_sync(dev, &msg);
297 	if (ret)
298 		return ret;
299 
300 	packed->last_avail_counter =
301 			msg.resp.vq_state.packed.last_avail_counter & 0x0001;
302 	packed->last_avail_idx =
303 			msg.resp.vq_state.packed.last_avail_idx & 0x7FFF;
304 	packed->last_used_counter =
305 			msg.resp.vq_state.packed.last_used_counter & 0x0001;
306 	packed->last_used_idx =
307 			msg.resp.vq_state.packed.last_used_idx & 0x7FFF;
308 
309 	return 0;
310 }
311 
312 static int vduse_dev_get_vq_state_split(struct vduse_dev *dev,
313 					struct vduse_virtqueue *vq,
314 					struct vdpa_vq_state_split *split)
315 {
316 	struct vduse_dev_msg msg = { 0 };
317 	int ret;
318 
319 	msg.req.type = VDUSE_GET_VQ_STATE;
320 	msg.req.vq_state.index = vq->index;
321 
322 	ret = vduse_dev_msg_sync(dev, &msg);
323 	if (ret)
324 		return ret;
325 
326 	split->avail_index = msg.resp.vq_state.split.avail_index;
327 
328 	return 0;
329 }
330 
331 static int vduse_dev_set_status(struct vduse_dev *dev, u8 status)
332 {
333 	struct vduse_dev_msg msg = { 0 };
334 
335 	msg.req.type = VDUSE_SET_STATUS;
336 	msg.req.s.status = status;
337 
338 	return vduse_dev_msg_sync(dev, &msg);
339 }
340 
341 static int vduse_dev_update_iotlb(struct vduse_dev *dev, u32 asid,
342 				  u64 start, u64 last)
343 {
344 	struct vduse_dev_msg msg = { 0 };
345 
346 	if (last < start)
347 		return -EINVAL;
348 
349 	msg.req.type = VDUSE_UPDATE_IOTLB;
350 	if (dev->api_version < VDUSE_API_VERSION_1) {
351 		msg.req.iova.start = start;
352 		msg.req.iova.last = last;
353 	} else {
354 		msg.req.iova_v2.start = start;
355 		msg.req.iova_v2.last = last;
356 		msg.req.iova_v2.asid = asid;
357 	}
358 
359 	return vduse_dev_msg_sync(dev, &msg);
360 }
361 
/*
 * Hand the oldest request on dev->send_list to the reader.
 *
 * The request is moved to dev->recv_list before msg_lock is dropped and a
 * private copy is sent to userspace. Returns the number of bytes copied,
 * -EINVAL if the destination is smaller than a struct vduse_dev_request,
 * -EAGAIN if nothing is pending on an O_NONBLOCK file, the error from an
 * interrupted wait, or -EFAULT if the copy to userspace was short.
 */
static ssize_t vduse_dev_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
	struct file *file = iocb->ki_filp;
	struct vduse_dev *dev = file->private_data;
	struct vduse_dev_msg *msg;
	struct vduse_dev_request req;
	int size = sizeof(struct vduse_dev_request);
	ssize_t ret;

	if (iov_iter_count(to) < size)
		return -EINVAL;

	spin_lock(&dev->msg_lock);
	while (1) {
		msg = vduse_dequeue_msg(&dev->send_list);
		if (msg)
			break;
		/* Nothing pending: drop the lock before sleeping or bailing out. */
		spin_unlock(&dev->msg_lock);

		if (file->f_flags & O_NONBLOCK)
			return -EAGAIN;

		ret = wait_event_interruptible_exclusive(dev->waitq,
					!list_empty(&dev->send_list));
		if (ret)
			return ret;

		spin_lock(&dev->msg_lock);
	}

	/* Copy while locked; msg must not be touched once the lock is dropped. */
	memcpy(&req, &msg->req, sizeof(req));
	/*
	 * We must ensure vduse_msg is on send_list or recv_list before unlock
	 * dev->msg_lock. Because vduse_dev_msg_sync() may be timeout when we
	 * copy data to userspace, and will call list_del() for this msg.
	 */
	vduse_enqueue_msg(&dev->recv_list, msg);
	spin_unlock(&dev->msg_lock);

	ret = copy_to_iter(&req, size, to);
	if (ret != size) {
		/*
		 * Roll back: move msg back to send_list if still pending.
		 *
		 * NOTE:
		 * vduse_find_msg() must use req.request_id instead of `msg`.
		 * A malicious userspace may reply to this request, and wake up
		 * the caller, after which `msg` will have already been freed.
		 * And here vduse_find_msg() will return NULL then do nothing.
		 */
		spin_lock(&dev->msg_lock);
		msg = vduse_find_msg(&dev->recv_list, req.request_id);
		if (msg)
			vduse_enqueue_msg_head(&dev->send_list, msg);
		spin_unlock(&dev->msg_lock);
		ret = -EFAULT;
	}

	return ret;
}
422 
/*
 * Return true if the first @size bytes at @ptr are all zero. A @size of
 * zero or less is treated as an empty range and yields true.
 */
static bool is_mem_zero(const char *ptr, int size)
{
	for (; size > 0; size--, ptr++) {
		if (*ptr != 0)
			return false;
	}

	return true;
}
433 
434 static ssize_t vduse_dev_write_iter(struct kiocb *iocb, struct iov_iter *from)
435 {
436 	struct file *file = iocb->ki_filp;
437 	struct vduse_dev *dev = file->private_data;
438 	struct vduse_dev_response resp;
439 	struct vduse_dev_msg *msg;
440 	size_t ret;
441 
442 	ret = copy_from_iter(&resp, sizeof(resp), from);
443 	if (ret != sizeof(resp))
444 		return -EINVAL;
445 
446 	if (!is_mem_zero((const char *)resp.reserved, sizeof(resp.reserved)))
447 		return -EINVAL;
448 
449 	spin_lock(&dev->msg_lock);
450 	msg = vduse_find_msg(&dev->recv_list, resp.request_id);
451 	if (!msg) {
452 		ret = -ENOENT;
453 		goto unlock;
454 	}
455 
456 	memcpy(&msg->resp, &resp, sizeof(resp));
457 	msg->completed = 1;
458 	wake_up(&msg->waitq);
459 unlock:
460 	spin_unlock(&dev->msg_lock);
461 
462 	return ret;
463 }
464 
465 static __poll_t vduse_dev_poll(struct file *file, poll_table *wait)
466 {
467 	struct vduse_dev *dev = file->private_data;
468 	__poll_t mask = 0;
469 
470 	poll_wait(file, &dev->waitq, wait);
471 
472 	spin_lock(&dev->msg_lock);
473 
474 	if (unlikely(dev->broken))
475 		mask |= EPOLLERR;
476 	if (!list_empty(&dev->send_list))
477 		mask |= EPOLLIN | EPOLLRDNORM;
478 	if (!list_empty(&dev->recv_list))
479 		mask |= EPOLLOUT | EPOLLWRNORM;
480 
481 	spin_unlock(&dev->msg_lock);
482 
483 	return mask;
484 }
485 
/*
 * Return the device and all its virtqueues to their initial state.
 *
 * Bounce maps of every address space are reset first. Then, with dev->rwsem
 * write-held, status and driver features are cleared, the generation counter
 * is bumped, callbacks and kick eventfds are dropped and pending inject/kick
 * work is flushed.
 */
static void vduse_dev_reset(struct vduse_dev *dev)
{
	int i;

	/* The coherent mappings are handled in vduse_dev_free_coherent() */
	for (i = 0; i < dev->nas; i++) {
		struct vduse_iova_domain *domain = dev->as[i].domain;

		if (domain && domain->bounce_map)
			vduse_domain_reset_bounce_map(domain);
	}

	down_write(&dev->rwsem);

	dev->status = 0;
	dev->driver_features = 0;
	dev->generation++;
	spin_lock(&dev->irq_lock);
	dev->config_cb.callback = NULL;
	dev->config_cb.private = NULL;
	spin_unlock(&dev->irq_lock);
	/* Wait for config-interrupt work that may already be queued. */
	flush_work(&dev->inject);

	for (i = 0; i < dev->vq_num; i++) {
		struct vduse_virtqueue *vq = dev->vqs[i];

		vq->ready = false;
		vq->desc_addr = 0;
		vq->driver_addr = 0;
		vq->device_addr = 0;
		vq->num = 0;
		memset(&vq->state, 0, sizeof(vq->state));

		/* Forget any latched kick and release the kick eventfd. */
		spin_lock(&vq->kick_lock);
		vq->kicked = false;
		if (vq->kickfd)
			eventfd_ctx_put(vq->kickfd);
		vq->kickfd = NULL;
		spin_unlock(&vq->kick_lock);

		spin_lock(&vq->irq_lock);
		vq->cb.callback = NULL;
		vq->cb.private = NULL;
		vq->cb.trigger = NULL;
		spin_unlock(&vq->irq_lock);
		/* Wait for per-queue work that may already be queued. */
		flush_work(&vq->inject);
		flush_work(&vq->kick);
	}

	up_write(&dev->rwsem);
}
537 
538 static int vduse_vdpa_set_vq_address(struct vdpa_device *vdpa, u16 idx,
539 				u64 desc_area, u64 driver_area,
540 				u64 device_area)
541 {
542 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
543 	struct vduse_virtqueue *vq = dev->vqs[idx];
544 
545 	vq->desc_addr = desc_area;
546 	vq->driver_addr = driver_area;
547 	vq->device_addr = device_area;
548 
549 	return 0;
550 }
551 
552 static void vduse_vq_kick(struct vduse_virtqueue *vq)
553 {
554 	spin_lock(&vq->kick_lock);
555 	if (!vq->ready)
556 		goto unlock;
557 
558 	if (vq->kickfd)
559 		eventfd_signal(vq->kickfd);
560 	else
561 		vq->kicked = true;
562 unlock:
563 	spin_unlock(&vq->kick_lock);
564 }
565 
566 static void vduse_vq_kick_work(struct work_struct *work)
567 {
568 	struct vduse_virtqueue *vq = container_of(work,
569 					struct vduse_virtqueue, kick);
570 
571 	vduse_vq_kick(vq);
572 }
573 
574 static void vduse_vdpa_kick_vq(struct vdpa_device *vdpa, u16 idx)
575 {
576 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
577 	struct vduse_virtqueue *vq = dev->vqs[idx];
578 
579 	if (!eventfd_signal_allowed()) {
580 		schedule_work(&vq->kick);
581 		return;
582 	}
583 	vduse_vq_kick(vq);
584 }
585 
586 static void vduse_vdpa_set_vq_cb(struct vdpa_device *vdpa, u16 idx,
587 			      struct vdpa_callback *cb)
588 {
589 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
590 	struct vduse_virtqueue *vq = dev->vqs[idx];
591 
592 	spin_lock(&vq->irq_lock);
593 	vq->cb.callback = cb->callback;
594 	vq->cb.private = cb->private;
595 	vq->cb.trigger = cb->trigger;
596 	spin_unlock(&vq->irq_lock);
597 }
598 
599 static void vduse_vdpa_set_vq_num(struct vdpa_device *vdpa, u16 idx, u32 num)
600 {
601 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
602 	struct vduse_virtqueue *vq = dev->vqs[idx];
603 
604 	vq->num = num;
605 }
606 
607 static u16 vduse_vdpa_get_vq_size(struct vdpa_device *vdpa, u16 idx)
608 {
609 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
610 	struct vduse_virtqueue *vq = dev->vqs[idx];
611 
612 	if (vq->num)
613 		return vq->num;
614 	else
615 		return vq->num_max;
616 }
617 
618 static void vduse_vdpa_set_vq_ready(struct vdpa_device *vdpa,
619 					u16 idx, bool ready)
620 {
621 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
622 	struct vduse_virtqueue *vq = dev->vqs[idx];
623 
624 	vq->ready = ready;
625 }
626 
627 static bool vduse_vdpa_get_vq_ready(struct vdpa_device *vdpa, u16 idx)
628 {
629 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
630 	struct vduse_virtqueue *vq = dev->vqs[idx];
631 
632 	return vq->ready;
633 }
634 
635 static int vduse_vdpa_set_vq_state(struct vdpa_device *vdpa, u16 idx,
636 				const struct vdpa_vq_state *state)
637 {
638 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
639 	struct vduse_virtqueue *vq = dev->vqs[idx];
640 
641 	if (dev->driver_features & BIT_ULL(VIRTIO_F_RING_PACKED)) {
642 		vq->state.packed.last_avail_counter =
643 				state->packed.last_avail_counter;
644 		vq->state.packed.last_avail_idx = state->packed.last_avail_idx;
645 		vq->state.packed.last_used_counter =
646 				state->packed.last_used_counter;
647 		vq->state.packed.last_used_idx = state->packed.last_used_idx;
648 	} else
649 		vq->state.split.avail_index = state->split.avail_index;
650 
651 	return 0;
652 }
653 
654 static u32 vduse_get_vq_group(struct vdpa_device *vdpa, u16 idx)
655 {
656 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
657 
658 	if (dev->api_version < VDUSE_API_VERSION_1)
659 		return 0;
660 
661 	return dev->vqs[idx]->group;
662 }
663 
664 static union virtio_map vduse_get_vq_map(struct vdpa_device *vdpa, u16 idx)
665 {
666 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
667 	u32 vq_group = vduse_get_vq_group(vdpa, idx);
668 	union virtio_map ret = {
669 		.group = &dev->groups[vq_group],
670 	};
671 
672 	return ret;
673 }
674 
/*
 * Scope guard taking a vq group's as_lock for reading. The lock is only
 * taken (and released) when the owning device has more than one address
 * space; with dev->nas <= 1 the guard does nothing.
 */
DEFINE_GUARD(vq_group_as_read_lock, struct vduse_vq_group *,
	if (_T->dev->nas > 1)
		read_lock(&_T->as_lock),
	if (_T->dev->nas > 1)
		read_unlock(&_T->as_lock))
680 
/*
 * Scope guard taking a vq group's as_lock for writing. The lock is only
 * taken (and released) when the owning device has more than one address
 * space; with dev->nas <= 1 the guard does nothing.
 */
DEFINE_GUARD(vq_group_as_write_lock, struct vduse_vq_group *,
	if (_T->dev->nas > 1)
		write_lock(&_T->as_lock),
	if (_T->dev->nas > 1)
		write_unlock(&_T->as_lock))
686 
687 static int vduse_set_group_asid(struct vdpa_device *vdpa, unsigned int group,
688 				unsigned int asid)
689 {
690 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
691 	struct vduse_dev_msg msg = { 0 };
692 	int r;
693 
694 	if (dev->api_version < VDUSE_API_VERSION_1)
695 		return -EINVAL;
696 
697 	msg.req.type = VDUSE_SET_VQ_GROUP_ASID;
698 	msg.req.vq_group_asid.group = group;
699 	msg.req.vq_group_asid.asid = asid;
700 
701 	r = vduse_dev_msg_sync(dev, &msg);
702 	if (r < 0)
703 		return r;
704 
705 	guard(vq_group_as_write_lock)(&dev->groups[group]);
706 	dev->groups[group].as = &dev->as[asid];
707 
708 	return 0;
709 }
710 
711 static int vduse_vdpa_get_vq_state(struct vdpa_device *vdpa, u16 idx,
712 				struct vdpa_vq_state *state)
713 {
714 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
715 	struct vduse_virtqueue *vq = dev->vqs[idx];
716 
717 	if (dev->driver_features & BIT_ULL(VIRTIO_F_RING_PACKED))
718 		return vduse_dev_get_vq_state_packed(dev, vq, &state->packed);
719 
720 	return vduse_dev_get_vq_state_split(dev, vq, &state->split);
721 }
722 
723 static u32 vduse_vdpa_get_vq_align(struct vdpa_device *vdpa)
724 {
725 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
726 
727 	return dev->vq_align;
728 }
729 
730 static u64 vduse_vdpa_get_device_features(struct vdpa_device *vdpa)
731 {
732 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
733 
734 	return dev->device_features;
735 }
736 
737 static int vduse_vdpa_set_driver_features(struct vdpa_device *vdpa, u64 features)
738 {
739 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
740 
741 	dev->driver_features = features;
742 	return 0;
743 }
744 
745 static u64 vduse_vdpa_get_driver_features(struct vdpa_device *vdpa)
746 {
747 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
748 
749 	return dev->driver_features;
750 }
751 
752 static void vduse_vdpa_set_config_cb(struct vdpa_device *vdpa,
753 				  struct vdpa_callback *cb)
754 {
755 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
756 
757 	spin_lock(&dev->irq_lock);
758 	dev->config_cb.callback = cb->callback;
759 	dev->config_cb.private = cb->private;
760 	spin_unlock(&dev->irq_lock);
761 }
762 
763 static u16 vduse_vdpa_get_vq_num_max(struct vdpa_device *vdpa)
764 {
765 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
766 	u16 num_max = 0;
767 	int i;
768 
769 	for (i = 0; i < dev->vq_num; i++)
770 		if (num_max < dev->vqs[i]->num_max)
771 			num_max = dev->vqs[i]->num_max;
772 
773 	return num_max;
774 }
775 
776 static u32 vduse_vdpa_get_device_id(struct vdpa_device *vdpa)
777 {
778 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
779 
780 	return dev->device_id;
781 }
782 
783 static u32 vduse_vdpa_get_vendor_id(struct vdpa_device *vdpa)
784 {
785 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
786 
787 	return dev->vendor_id;
788 }
789 
790 static u8 vduse_vdpa_get_status(struct vdpa_device *vdpa)
791 {
792 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
793 
794 	return dev->status;
795 }
796 
797 static void vduse_vdpa_set_status(struct vdpa_device *vdpa, u8 status)
798 {
799 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
800 
801 	if (vduse_dev_set_status(dev, status))
802 		return;
803 
804 	dev->status = status;
805 }
806 
807 static size_t vduse_vdpa_get_config_size(struct vdpa_device *vdpa)
808 {
809 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
810 
811 	return dev->config_size;
812 }
813 
814 static void vduse_vdpa_get_config(struct vdpa_device *vdpa, unsigned int offset,
815 				  void *buf, unsigned int len)
816 {
817 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
818 
819 	/* Initialize the buffer in case of partial copy. */
820 	memset(buf, 0, len);
821 
822 	if (offset > dev->config_size)
823 		return;
824 
825 	if (len > dev->config_size - offset)
826 		len = dev->config_size - offset;
827 
828 	memcpy(buf, dev->config + offset, len);
829 }
830 
/* Config-space write hook: intentionally does nothing, see below. */
static void vduse_vdpa_set_config(struct vdpa_device *vdpa, unsigned int offset,
			const void *buf, unsigned int len)
{
	/* Now we only support read-only configuration space */
}
836 
/*
 * Reset the device. Userspace is told about status 0 first; the local reset
 * is performed regardless of whether that request succeeded, and the result
 * of the request is returned.
 */
static int vduse_vdpa_reset(struct vdpa_device *vdpa)
{
	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
	int err;

	err = vduse_dev_set_status(dev, 0);
	vduse_dev_reset(dev);

	return err;
}
846 
847 static u32 vduse_vdpa_get_generation(struct vdpa_device *vdpa)
848 {
849 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
850 
851 	return dev->generation;
852 }
853 
854 static int vduse_vdpa_set_vq_affinity(struct vdpa_device *vdpa, u16 idx,
855 				      const struct cpumask *cpu_mask)
856 {
857 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
858 
859 	if (cpu_mask)
860 		cpumask_copy(&dev->vqs[idx]->irq_affinity, cpu_mask);
861 	else
862 		cpumask_setall(&dev->vqs[idx]->irq_affinity);
863 
864 	return 0;
865 }
866 
867 static const struct cpumask *
868 vduse_vdpa_get_vq_affinity(struct vdpa_device *vdpa, u16 idx)
869 {
870 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
871 
872 	return &dev->vqs[idx]->irq_affinity;
873 }
874 
875 static int vduse_vdpa_set_map(struct vdpa_device *vdpa,
876 				unsigned int asid,
877 				struct vhost_iotlb *iotlb)
878 {
879 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
880 	int ret;
881 
882 	ret = vduse_domain_set_map(dev->as[asid].domain, iotlb);
883 	if (ret)
884 		return ret;
885 
886 	ret = vduse_dev_update_iotlb(dev, asid, 0ULL, ULLONG_MAX);
887 	if (ret) {
888 		vduse_domain_clear_map(dev->as[asid].domain, iotlb);
889 		return ret;
890 	}
891 
892 	return 0;
893 }
894 
895 static void vduse_vdpa_free(struct vdpa_device *vdpa)
896 {
897 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
898 
899 	dev->vdev = NULL;
900 }
901 
/* vDPA config operations implemented by the functions above. */
static const struct vdpa_config_ops vduse_vdpa_config_ops = {
	.set_vq_address		= vduse_vdpa_set_vq_address,
	.kick_vq		= vduse_vdpa_kick_vq,
	.set_vq_cb		= vduse_vdpa_set_vq_cb,
	.set_vq_num             = vduse_vdpa_set_vq_num,
	.get_vq_size		= vduse_vdpa_get_vq_size,
	.get_vq_group		= vduse_get_vq_group,
	.set_vq_ready		= vduse_vdpa_set_vq_ready,
	.get_vq_ready		= vduse_vdpa_get_vq_ready,
	.set_vq_state		= vduse_vdpa_set_vq_state,
	.get_vq_state		= vduse_vdpa_get_vq_state,
	.get_vq_align		= vduse_vdpa_get_vq_align,
	.get_device_features	= vduse_vdpa_get_device_features,
	.set_driver_features	= vduse_vdpa_set_driver_features,
	.get_driver_features	= vduse_vdpa_get_driver_features,
	.set_config_cb		= vduse_vdpa_set_config_cb,
	.get_vq_num_max		= vduse_vdpa_get_vq_num_max,
	.get_device_id		= vduse_vdpa_get_device_id,
	.get_vendor_id		= vduse_vdpa_get_vendor_id,
	.get_status		= vduse_vdpa_get_status,
	.set_status		= vduse_vdpa_set_status,
	.get_config_size	= vduse_vdpa_get_config_size,
	.get_config		= vduse_vdpa_get_config,
	.set_config		= vduse_vdpa_set_config,
	.get_generation		= vduse_vdpa_get_generation,
	.set_vq_affinity	= vduse_vdpa_set_vq_affinity,
	.get_vq_affinity	= vduse_vdpa_get_vq_affinity,
	.reset			= vduse_vdpa_reset,
	.set_map		= vduse_vdpa_set_map,
	.set_group_asid		= vduse_set_group_asid,
	.get_vq_map		= vduse_get_vq_map,
	.free			= vduse_vdpa_free,
};
935 
936 static void vduse_dev_sync_single_for_device(union virtio_map token,
937 					     dma_addr_t dma_addr, size_t size,
938 					     enum dma_data_direction dir)
939 {
940 	struct vduse_iova_domain *domain;
941 
942 	if (!token.group)
943 		return;
944 
945 	guard(vq_group_as_read_lock)(token.group);
946 	domain = token.group->as->domain;
947 	vduse_domain_sync_single_for_device(domain, dma_addr, size, dir);
948 }
949 
950 static void vduse_dev_sync_single_for_cpu(union virtio_map token,
951 					     dma_addr_t dma_addr, size_t size,
952 					     enum dma_data_direction dir)
953 {
954 	struct vduse_iova_domain *domain;
955 
956 	if (!token.group)
957 		return;
958 
959 	guard(vq_group_as_read_lock)(token.group);
960 	domain = token.group->as->domain;
961 	vduse_domain_sync_single_for_cpu(domain, dma_addr, size, dir);
962 }
963 
964 static dma_addr_t vduse_dev_map_page(union virtio_map token, struct page *page,
965 				     unsigned long offset, size_t size,
966 				     enum dma_data_direction dir,
967 				     unsigned long attrs)
968 {
969 	struct vduse_iova_domain *domain;
970 
971 	if (!token.group)
972 		return DMA_MAPPING_ERROR;
973 
974 	guard(vq_group_as_read_lock)(token.group);
975 	domain = token.group->as->domain;
976 	return vduse_domain_map_page(domain, page, offset, size, dir, attrs);
977 }
978 
979 static void vduse_dev_unmap_page(union virtio_map token, dma_addr_t dma_addr,
980 				 size_t size, enum dma_data_direction dir,
981 				 unsigned long attrs)
982 {
983 	struct vduse_iova_domain *domain;
984 
985 	if (!token.group)
986 		return;
987 
988 	guard(vq_group_as_read_lock)(token.group);
989 	domain = token.group->as->domain;
990 	vduse_domain_unmap_page(domain, dma_addr, size, dir, attrs);
991 }
992 
993 static void *vduse_dev_alloc_coherent(union virtio_map token, size_t size,
994 				      dma_addr_t *dma_addr, gfp_t flag)
995 {
996 	void *addr;
997 
998 	*dma_addr = DMA_MAPPING_ERROR;
999 	if (!token.group)
1000 		return NULL;
1001 
1002 	addr = alloc_pages_exact(size, flag | __GFP_ZERO);
1003 	if (!addr)
1004 		return NULL;
1005 
1006 	{
1007 		struct vduse_iova_domain *domain;
1008 
1009 		guard(vq_group_as_read_lock)(token.group);
1010 		domain = token.group->as->domain;
1011 		*dma_addr = vduse_domain_alloc_coherent(domain, size, addr);
1012 		if (*dma_addr == DMA_MAPPING_ERROR)
1013 			goto err;
1014 	}
1015 
1016 	return addr;
1017 
1018 err:
1019 	free_pages_exact(addr, size);
1020 	return NULL;
1021 }
1022 
1023 static void vduse_dev_free_coherent(union virtio_map token, size_t size,
1024 				    void *vaddr, dma_addr_t dma_addr,
1025 				    unsigned long attrs)
1026 {
1027 	if (!token.group)
1028 		return;
1029 
1030 	{
1031 		struct vduse_iova_domain *domain;
1032 
1033 		guard(vq_group_as_read_lock)(token.group);
1034 		domain = token.group->as->domain;
1035 		vduse_domain_free_coherent(domain, size, dma_addr, attrs);
1036 	}
1037 
1038 	free_pages_exact(vaddr, size);
1039 }
1040 
1041 static bool vduse_dev_need_sync(union virtio_map token, dma_addr_t dma_addr)
1042 {
1043 	if (!token.group)
1044 		return false;
1045 
1046 	guard(vq_group_as_read_lock)(token.group);
1047 	return dma_addr < token.group->as->domain->bounce_size;
1048 }
1049 
1050 static int vduse_dev_mapping_error(union virtio_map token, dma_addr_t dma_addr)
1051 {
1052 	if (unlikely(dma_addr == DMA_MAPPING_ERROR))
1053 		return -ENOMEM;
1054 	return 0;
1055 }
1056 
1057 static size_t vduse_dev_max_mapping_size(union virtio_map token)
1058 {
1059 	if (!token.group)
1060 		return 0;
1061 
1062 	guard(vq_group_as_read_lock)(token.group);
1063 	return token.group->as->domain->bounce_size;
1064 }
1065 
/*
 * Mapping operations implemented by the functions above; each takes a
 * union virtio_map token whose group member selects the vq group.
 */
static const struct virtio_map_ops vduse_map_ops = {
	.sync_single_for_device = vduse_dev_sync_single_for_device,
	.sync_single_for_cpu = vduse_dev_sync_single_for_cpu,
	.map_page = vduse_dev_map_page,
	.unmap_page = vduse_dev_unmap_page,
	.alloc = vduse_dev_alloc_coherent,
	.free = vduse_dev_free_coherent,
	.need_sync = vduse_dev_need_sync,
	.mapping_error = vduse_dev_mapping_error,
	.max_mapping_size = vduse_dev_max_mapping_size,
};
1077 
1078 static unsigned int perm_to_file_flags(u8 perm)
1079 {
1080 	unsigned int flags = 0;
1081 
1082 	switch (perm) {
1083 	case VDUSE_ACCESS_WO:
1084 		flags |= O_WRONLY;
1085 		break;
1086 	case VDUSE_ACCESS_RO:
1087 		flags |= O_RDONLY;
1088 		break;
1089 	case VDUSE_ACCESS_RW:
1090 		flags |= O_RDWR;
1091 		break;
1092 	default:
1093 		WARN(1, "invalidate vhost IOTLB permission\n");
1094 		break;
1095 	}
1096 
1097 	return flags;
1098 }
1099 
1100 static int vduse_kickfd_setup(struct vduse_dev *dev,
1101 			struct vduse_vq_eventfd *eventfd)
1102 {
1103 	struct eventfd_ctx *ctx = NULL;
1104 	struct vduse_virtqueue *vq;
1105 	u32 index;
1106 
1107 	if (eventfd->index >= dev->vq_num)
1108 		return -EINVAL;
1109 
1110 	index = array_index_nospec(eventfd->index, dev->vq_num);
1111 	vq = dev->vqs[index];
1112 	if (eventfd->fd >= 0) {
1113 		ctx = eventfd_ctx_fdget(eventfd->fd);
1114 		if (IS_ERR(ctx))
1115 			return PTR_ERR(ctx);
1116 	} else if (eventfd->fd != VDUSE_EVENTFD_DEASSIGN)
1117 		return 0;
1118 
1119 	spin_lock(&vq->kick_lock);
1120 	if (vq->kickfd)
1121 		eventfd_ctx_put(vq->kickfd);
1122 	vq->kickfd = ctx;
1123 	if (vq->ready && vq->kicked && vq->kickfd) {
1124 		eventfd_signal(vq->kickfd);
1125 		vq->kicked = false;
1126 	}
1127 	spin_unlock(&vq->kick_lock);
1128 
1129 	return 0;
1130 }
1131 
1132 static bool vduse_dev_is_ready(struct vduse_dev *dev)
1133 {
1134 	int i;
1135 
1136 	for (i = 0; i < dev->vq_num; i++)
1137 		if (!dev->vqs[i]->num_max)
1138 			return false;
1139 
1140 	return true;
1141 }
1142 
1143 static void vduse_dev_irq_inject(struct work_struct *work)
1144 {
1145 	struct vduse_dev *dev = container_of(work, struct vduse_dev, inject);
1146 
1147 	spin_lock_bh(&dev->irq_lock);
1148 	if (dev->config_cb.callback)
1149 		dev->config_cb.callback(dev->config_cb.private);
1150 	spin_unlock_bh(&dev->irq_lock);
1151 }
1152 
1153 static void vduse_vq_irq_inject(struct work_struct *work)
1154 {
1155 	struct vduse_virtqueue *vq = container_of(work,
1156 					struct vduse_virtqueue, inject);
1157 
1158 	spin_lock_bh(&vq->irq_lock);
1159 	if (vq->ready && vq->cb.callback)
1160 		vq->cb.callback(vq->cb.private);
1161 	spin_unlock_bh(&vq->irq_lock);
1162 }
1163 
1164 static bool vduse_vq_signal_irqfd(struct vduse_virtqueue *vq)
1165 {
1166 	bool signal = false;
1167 
1168 	if (!vq->cb.trigger)
1169 		return false;
1170 
1171 	spin_lock_irq(&vq->irq_lock);
1172 	if (vq->ready && vq->cb.trigger) {
1173 		eventfd_signal(vq->cb.trigger);
1174 		signal = true;
1175 	}
1176 	spin_unlock_irq(&vq->irq_lock);
1177 
1178 	return signal;
1179 }
1180 
1181 static int vduse_dev_queue_irq_work(struct vduse_dev *dev,
1182 				    struct work_struct *irq_work,
1183 				    int irq_effective_cpu)
1184 {
1185 	int ret = -EINVAL;
1186 
1187 	down_read(&dev->rwsem);
1188 	if (!(dev->status & VIRTIO_CONFIG_S_DRIVER_OK))
1189 		goto unlock;
1190 
1191 	ret = 0;
1192 	if (irq_effective_cpu == IRQ_UNBOUND)
1193 		queue_work(vduse_irq_wq, irq_work);
1194 	else
1195 		queue_work_on(irq_effective_cpu,
1196 			      vduse_irq_bound_wq, irq_work);
1197 unlock:
1198 	up_read(&dev->rwsem);
1199 
1200 	return ret;
1201 }
1202 
1203 static int vduse_dev_dereg_umem(struct vduse_dev *dev, u32 asid,
1204 				u64 iova, u64 size)
1205 {
1206 	int ret;
1207 
1208 	mutex_lock(&dev->as[asid].mem_lock);
1209 	ret = -ENOENT;
1210 	if (!dev->as[asid].umem)
1211 		goto unlock;
1212 
1213 	ret = -EINVAL;
1214 	if (!dev->as[asid].domain)
1215 		goto unlock;
1216 
1217 	if (dev->as[asid].umem->iova != iova ||
1218 	    size != dev->as[asid].domain->bounce_size)
1219 		goto unlock;
1220 
1221 	vduse_domain_remove_user_bounce_pages(dev->as[asid].domain);
1222 	unpin_user_pages_dirty_lock(dev->as[asid].umem->pages,
1223 				    dev->as[asid].umem->npages, true);
1224 	atomic64_sub(dev->as[asid].umem->npages, &dev->as[asid].umem->mm->pinned_vm);
1225 	mmdrop(dev->as[asid].umem->mm);
1226 	vfree(dev->as[asid].umem->pages);
1227 	kfree(dev->as[asid].umem);
1228 	dev->as[asid].umem = NULL;
1229 	ret = 0;
1230 unlock:
1231 	mutex_unlock(&dev->as[asid].mem_lock);
1232 	return ret;
1233 }
1234 
1235 static int vduse_dev_reg_umem(struct vduse_dev *dev,
1236 			      u32 asid, u64 iova, u64 uaddr, u64 size)
1237 {
1238 	struct page **page_list = NULL;
1239 	struct vduse_umem *umem = NULL;
1240 	long pinned = 0;
1241 	unsigned long npages, lock_limit;
1242 	int ret;
1243 
1244 	if (!dev->as[asid].domain || !dev->as[asid].domain->bounce_map ||
1245 	    size != dev->as[asid].domain->bounce_size ||
1246 	    iova != 0 || uaddr & ~PAGE_MASK)
1247 		return -EINVAL;
1248 
1249 	mutex_lock(&dev->as[asid].mem_lock);
1250 	ret = -EEXIST;
1251 	if (dev->as[asid].umem)
1252 		goto unlock;
1253 
1254 	ret = -ENOMEM;
1255 	npages = size >> PAGE_SHIFT;
1256 	page_list = __vmalloc(array_size(npages, sizeof(struct page *)),
1257 			      GFP_KERNEL_ACCOUNT);
1258 	umem = kzalloc_obj(*umem);
1259 	if (!page_list || !umem)
1260 		goto unlock;
1261 
1262 	mmap_read_lock(current->mm);
1263 
1264 	lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
1265 	if (npages + atomic64_read(&current->mm->pinned_vm) > lock_limit)
1266 		goto out;
1267 
1268 	pinned = pin_user_pages(uaddr, npages, FOLL_LONGTERM | FOLL_WRITE,
1269 				page_list);
1270 	if (pinned != npages) {
1271 		ret = pinned < 0 ? pinned : -ENOMEM;
1272 		goto out;
1273 	}
1274 
1275 	ret = vduse_domain_add_user_bounce_pages(dev->as[asid].domain,
1276 						 page_list, pinned);
1277 	if (ret)
1278 		goto out;
1279 
1280 	atomic64_add(npages, &current->mm->pinned_vm);
1281 
1282 	umem->pages = page_list;
1283 	umem->npages = pinned;
1284 	umem->iova = iova;
1285 	umem->mm = current->mm;
1286 	mmgrab(current->mm);
1287 
1288 	dev->as[asid].umem = umem;
1289 out:
1290 	if (ret && pinned > 0)
1291 		unpin_user_pages(page_list, pinned);
1292 
1293 	mmap_read_unlock(current->mm);
1294 unlock:
1295 	if (ret) {
1296 		vfree(page_list);
1297 		kfree(umem);
1298 	}
1299 	mutex_unlock(&dev->as[asid].mem_lock);
1300 	return ret;
1301 }
1302 
1303 static void vduse_vq_update_effective_cpu(struct vduse_virtqueue *vq)
1304 {
1305 	int curr_cpu = vq->irq_effective_cpu;
1306 
1307 	while (true) {
1308 		curr_cpu = cpumask_next(curr_cpu, &vq->irq_affinity);
1309 		if (cpu_online(curr_cpu))
1310 			break;
1311 
1312 		if (curr_cpu >= nr_cpu_ids)
1313 			curr_cpu = IRQ_UNBOUND;
1314 	}
1315 
1316 	vq->irq_effective_cpu = curr_cpu;
1317 }
1318 
1319 static int vduse_dev_iotlb_entry(struct vduse_dev *dev,
1320 				 struct vduse_iotlb_entry_v2 *entry,
1321 				 struct file **f, uint64_t *capability)
1322 {
1323 	u32 asid;
1324 	int r = -EINVAL;
1325 	struct vhost_iotlb_map *map;
1326 
1327 	if (entry->start > entry->last || entry->asid >= dev->nas)
1328 		return -EINVAL;
1329 
1330 	asid = array_index_nospec(entry->asid, dev->nas);
1331 	mutex_lock(&dev->domain_lock);
1332 
1333 	if (!dev->as[asid].domain)
1334 		goto out;
1335 
1336 	spin_lock(&dev->as[asid].domain->iotlb_lock);
1337 	map = vhost_iotlb_itree_first(dev->as[asid].domain->iotlb,
1338 				      entry->start, entry->last);
1339 	if (map) {
1340 		if (f) {
1341 			const struct vdpa_map_file *map_file;
1342 
1343 			map_file = (struct vdpa_map_file *)map->opaque;
1344 			entry->offset = map_file->offset;
1345 			*f = get_file(map_file->file);
1346 		}
1347 		entry->start = map->start;
1348 		entry->last = map->last;
1349 		entry->perm = map->perm;
1350 		if (capability) {
1351 			*capability = 0;
1352 
1353 			if (dev->as[asid].domain->bounce_map && map->start == 0 &&
1354 			    map->last == dev->as[asid].domain->bounce_size - 1)
1355 				*capability |= VDUSE_IOVA_CAP_UMEM;
1356 		}
1357 
1358 		r = 0;
1359 	}
1360 	spin_unlock(&dev->as[asid].domain->iotlb_lock);
1361 
1362 out:
1363 	mutex_unlock(&dev->domain_lock);
1364 	return r;
1365 }
1366 
1367 static long vduse_dev_ioctl(struct file *file, unsigned int cmd,
1368 			    unsigned long arg)
1369 {
1370 	struct vduse_dev *dev = file->private_data;
1371 	void __user *argp = (void __user *)arg;
1372 	int ret;
1373 
1374 	if (unlikely(dev->broken))
1375 		return -EPERM;
1376 
1377 	switch (cmd) {
1378 	case VDUSE_IOTLB_GET_FD:
1379 	case VDUSE_IOTLB_GET_FD2: {
1380 		struct vduse_iotlb_entry_v2 entry = {0};
1381 		struct file *f = NULL;
1382 
1383 		ret = -ENOIOCTLCMD;
1384 		if (dev->api_version < VDUSE_API_VERSION_1 &&
1385 		    cmd == VDUSE_IOTLB_GET_FD2)
1386 			break;
1387 
1388 		ret = -EFAULT;
1389 		if (copy_from_user(&entry, argp, _IOC_SIZE(cmd)))
1390 			break;
1391 
1392 		ret = -EINVAL;
1393 		if (!is_mem_zero((const char *)entry.reserved,
1394 				 sizeof(entry.reserved)))
1395 			break;
1396 
1397 		ret = vduse_dev_iotlb_entry(dev, &entry, &f, NULL);
1398 		if (ret)
1399 			break;
1400 
1401 		ret = -EINVAL;
1402 		if (!f)
1403 			break;
1404 
1405 		ret = copy_to_user(argp, &entry, _IOC_SIZE(cmd));
1406 		if (ret) {
1407 			ret = -EFAULT;
1408 			fput(f);
1409 			break;
1410 		}
1411 		ret = receive_fd(f, NULL, perm_to_file_flags(entry.perm));
1412 		fput(f);
1413 		break;
1414 	}
1415 	case VDUSE_DEV_GET_FEATURES:
1416 		/*
1417 		 * Just mirror what driver wrote here.
1418 		 * The driver is expected to check FEATURE_OK later.
1419 		 */
1420 		ret = put_user(dev->driver_features, (u64 __user *)argp);
1421 		break;
1422 	case VDUSE_DEV_SET_CONFIG: {
1423 		struct vduse_config_data config;
1424 		unsigned long size = offsetof(struct vduse_config_data,
1425 					      buffer);
1426 
1427 		ret = -EFAULT;
1428 		if (copy_from_user(&config, argp, size))
1429 			break;
1430 
1431 		ret = -EINVAL;
1432 		if (config.offset > dev->config_size ||
1433 		    config.length == 0 ||
1434 		    config.length > dev->config_size - config.offset)
1435 			break;
1436 
1437 		ret = -EFAULT;
1438 		if (copy_from_user(dev->config + config.offset, argp + size,
1439 				   config.length))
1440 			break;
1441 
1442 		ret = 0;
1443 		break;
1444 	}
1445 	case VDUSE_DEV_INJECT_CONFIG_IRQ:
1446 		ret = vduse_dev_queue_irq_work(dev, &dev->inject, IRQ_UNBOUND);
1447 		break;
1448 	case VDUSE_VQ_SETUP: {
1449 		struct vduse_vq_config config;
1450 		u32 index;
1451 
1452 		ret = -EFAULT;
1453 		if (copy_from_user(&config, argp, sizeof(config)))
1454 			break;
1455 
1456 		ret = -EINVAL;
1457 		if (config.index >= dev->vq_num)
1458 			break;
1459 
1460 		if (dev->api_version < VDUSE_API_VERSION_1) {
1461 			if (config.group)
1462 				break;
1463 		} else {
1464 			if (config.group >= dev->ngroups)
1465 				break;
1466 			if (dev->status & VIRTIO_CONFIG_S_DRIVER_OK)
1467 				break;
1468 		}
1469 
1470 		if (config.reserved1 ||
1471 		    !is_mem_zero((const char *)config.reserved2,
1472 				 sizeof(config.reserved2)))
1473 			break;
1474 
1475 		index = array_index_nospec(config.index, dev->vq_num);
1476 		dev->vqs[index]->num_max = config.max_size;
1477 		dev->vqs[index]->group = config.group;
1478 		ret = 0;
1479 		break;
1480 	}
1481 	case VDUSE_VQ_GET_INFO: {
1482 		struct vduse_vq_info vq_info;
1483 		struct vduse_virtqueue *vq;
1484 		u32 index;
1485 
1486 		ret = -EFAULT;
1487 		if (copy_from_user(&vq_info, argp, sizeof(vq_info)))
1488 			break;
1489 
1490 		ret = -EINVAL;
1491 		if (vq_info.index >= dev->vq_num)
1492 			break;
1493 
1494 		index = array_index_nospec(vq_info.index, dev->vq_num);
1495 		vq = dev->vqs[index];
1496 		vq_info.desc_addr = vq->desc_addr;
1497 		vq_info.driver_addr = vq->driver_addr;
1498 		vq_info.device_addr = vq->device_addr;
1499 		vq_info.num = vq->num;
1500 
1501 		if (dev->driver_features & BIT_ULL(VIRTIO_F_RING_PACKED)) {
1502 			vq_info.packed.last_avail_counter =
1503 				vq->state.packed.last_avail_counter;
1504 			vq_info.packed.last_avail_idx =
1505 				vq->state.packed.last_avail_idx;
1506 			vq_info.packed.last_used_counter =
1507 				vq->state.packed.last_used_counter;
1508 			vq_info.packed.last_used_idx =
1509 				vq->state.packed.last_used_idx;
1510 		} else
1511 			vq_info.split.avail_index =
1512 				vq->state.split.avail_index;
1513 
1514 		vq_info.ready = vq->ready;
1515 
1516 		ret = -EFAULT;
1517 		if (copy_to_user(argp, &vq_info, sizeof(vq_info)))
1518 			break;
1519 
1520 		ret = 0;
1521 		break;
1522 	}
1523 	case VDUSE_VQ_SETUP_KICKFD: {
1524 		struct vduse_vq_eventfd eventfd;
1525 
1526 		ret = -EFAULT;
1527 		if (copy_from_user(&eventfd, argp, sizeof(eventfd)))
1528 			break;
1529 
1530 		ret = vduse_kickfd_setup(dev, &eventfd);
1531 		break;
1532 	}
1533 	case VDUSE_VQ_INJECT_IRQ: {
1534 		u32 index;
1535 
1536 		ret = -EFAULT;
1537 		if (get_user(index, (u32 __user *)argp))
1538 			break;
1539 
1540 		ret = -EINVAL;
1541 		if (index >= dev->vq_num)
1542 			break;
1543 
1544 		ret = 0;
1545 		index = array_index_nospec(index, dev->vq_num);
1546 		if (!vduse_vq_signal_irqfd(dev->vqs[index])) {
1547 			vduse_vq_update_effective_cpu(dev->vqs[index]);
1548 			ret = vduse_dev_queue_irq_work(dev,
1549 						&dev->vqs[index]->inject,
1550 						dev->vqs[index]->irq_effective_cpu);
1551 		}
1552 		break;
1553 	}
1554 	case VDUSE_IOTLB_REG_UMEM: {
1555 		struct vduse_iova_umem umem;
1556 		u32 asid;
1557 
1558 		ret = -EFAULT;
1559 		if (copy_from_user(&umem, argp, sizeof(umem)))
1560 			break;
1561 
1562 		ret = -EINVAL;
1563 		if (!is_mem_zero((const char *)umem.reserved,
1564 				 sizeof(umem.reserved)) ||
1565 		    (dev->api_version < VDUSE_API_VERSION_1 &&
1566 		     umem.asid != 0) || umem.asid >= dev->nas)
1567 			break;
1568 
1569 		mutex_lock(&dev->domain_lock);
1570 		asid = array_index_nospec(umem.asid, dev->nas);
1571 		ret = vduse_dev_reg_umem(dev, asid, umem.iova,
1572 					 umem.uaddr, umem.size);
1573 		mutex_unlock(&dev->domain_lock);
1574 		break;
1575 	}
1576 	case VDUSE_IOTLB_DEREG_UMEM: {
1577 		struct vduse_iova_umem umem;
1578 		u32 asid;
1579 
1580 		ret = -EFAULT;
1581 		if (copy_from_user(&umem, argp, sizeof(umem)))
1582 			break;
1583 
1584 		ret = -EINVAL;
1585 		if (!is_mem_zero((const char *)umem.reserved,
1586 				 sizeof(umem.reserved)) ||
1587 		    (dev->api_version < VDUSE_API_VERSION_1 &&
1588 		     umem.asid != 0) ||
1589 		     umem.asid >= dev->nas)
1590 			break;
1591 
1592 		mutex_lock(&dev->domain_lock);
1593 		asid = array_index_nospec(umem.asid, dev->nas);
1594 		ret = vduse_dev_dereg_umem(dev, asid, umem.iova,
1595 					   umem.size);
1596 		mutex_unlock(&dev->domain_lock);
1597 		break;
1598 	}
1599 	case VDUSE_IOTLB_GET_INFO: {
1600 		struct vduse_iova_info info;
1601 		struct vduse_iotlb_entry_v2 entry;
1602 
1603 		ret = -EFAULT;
1604 		if (copy_from_user(&info, argp, sizeof(info)))
1605 			break;
1606 
1607 		if (!is_mem_zero((const char *)info.reserved,
1608 				 sizeof(info.reserved)))
1609 			break;
1610 
1611 		if (dev->api_version < VDUSE_API_VERSION_1) {
1612 			if (info.asid)
1613 				break;
1614 		} else if (info.asid >= dev->nas)
1615 			break;
1616 
1617 		entry.start = info.start;
1618 		entry.last = info.last;
1619 		entry.asid = info.asid;
1620 		ret = vduse_dev_iotlb_entry(dev, &entry, NULL,
1621 					    &info.capability);
1622 		if (ret < 0)
1623 			break;
1624 
1625 		info.start = entry.start;
1626 		info.last = entry.last;
1627 		info.asid = entry.asid;
1628 
1629 		ret = -EFAULT;
1630 		if (copy_to_user(argp, &info, sizeof(info)))
1631 			break;
1632 
1633 		ret = 0;
1634 		break;
1635 	}
1636 	default:
1637 		ret = -ENOIOCTLCMD;
1638 		break;
1639 	}
1640 
1641 	return ret;
1642 }
1643 
#ifdef CONFIG_COMPAT_FOR_U64_ALIGNMENT
/*
 * i386 has different alignment constraints than x86_64,
 * so there are only 3 bytes of padding instead of 7.
 */
struct compat_vduse_iotlb_entry {
	compat_u64 offset;
	compat_u64 start;
	compat_u64 last;
	__u8 perm;
	__u8 padding[3];
};
#define COMPAT_VDUSE_IOTLB_GET_FD	_IOWR(VDUSE_BASE, 0x10, struct compat_vduse_iotlb_entry)

struct compat_vduse_vq_info {
	__u32 index;
	__u32 num;
	compat_u64 desc_addr;
	compat_u64 driver_addr;
	compat_u64 device_addr;
	union {
		struct vduse_vq_state_split split;
		struct vduse_vq_state_packed packed;
	};
	__u8 ready;
	__u8 padding[3];
};
#define COMPAT_VDUSE_VQ_GET_INFO	_IOWR(VDUSE_BASE, 0x15, struct compat_vduse_vq_info)

/*
 * compat ioctl handler of a VDUSE device file.
 *
 * COMPAT_VDUSE_IOTLB_GET_FD and COMPAT_VDUSE_VQ_GET_INFO take argument
 * structures whose size differs from the native ones, so they are handled
 * here and their result is returned directly.  Every other command is
 * forwarded to vduse_dev_ioctl() with the pointer converted by
 * compat_ptr().
 *
 * Return: the result of the handled command, or whatever
 * vduse_dev_ioctl() returns for forwarded commands; -EPERM if the device
 * is marked broken.
 */
static long vduse_dev_compat_ioctl(struct file *file, unsigned int cmd,
				   unsigned long arg)
{
	struct vduse_dev *dev = file->private_data;
	void __user *argp = (void __user *)arg;
	int ret;

	if (unlikely(dev->broken))
		return -EPERM;

	switch (cmd) {
	case COMPAT_VDUSE_IOTLB_GET_FD: {
		struct vduse_iotlb_entry_v2 entry = {0};
		struct file *f = NULL;

		/* Only the (shorter) compat layout is copied in and out. */
		ret = -EFAULT;
		if (copy_from_user(&entry, argp, _IOC_SIZE(cmd)))
			break;

		ret = vduse_dev_iotlb_entry(dev, &entry, &f, NULL);
		if (ret)
			break;

		ret = -EINVAL;
		if (!f)
			break;

		ret = copy_to_user(argp, &entry, _IOC_SIZE(cmd));
		if (ret) {
			ret = -EFAULT;
			fput(f);
			break;
		}
		ret = receive_fd(f, NULL, perm_to_file_flags(entry.perm));
		fput(f);
		break;
	}
	case COMPAT_VDUSE_VQ_GET_INFO: {
		struct vduse_vq_info vq_info = {};
		struct vduse_virtqueue *vq;
		u32 index;

		ret = -EFAULT;
		if (copy_from_user(&vq_info, argp,
				   sizeof(struct compat_vduse_vq_info)))
			break;

		ret = -EINVAL;
		if (vq_info.index >= dev->vq_num)
			break;

		index = array_index_nospec(vq_info.index, dev->vq_num);
		vq = dev->vqs[index];
		vq_info.desc_addr = vq->desc_addr;
		vq_info.driver_addr = vq->driver_addr;
		vq_info.device_addr = vq->device_addr;
		vq_info.num = vq->num;

		if (dev->driver_features & BIT_ULL(VIRTIO_F_RING_PACKED)) {
			vq_info.packed.last_avail_counter =
				vq->state.packed.last_avail_counter;
			vq_info.packed.last_avail_idx =
				vq->state.packed.last_avail_idx;
			vq_info.packed.last_used_counter =
				vq->state.packed.last_used_counter;
			vq_info.packed.last_used_idx =
				vq->state.packed.last_used_idx;
		} else
			vq_info.split.avail_index =
				vq->state.split.avail_index;

		vq_info.ready = vq->ready;

		ret = -EFAULT;
		if (copy_to_user(argp, &vq_info,
		    sizeof(struct compat_vduse_vq_info)))
			break;

		ret = 0;
		break;
	}
	default:
		/*
		 * Not a compat-specific command: let the native handler deal
		 * with it.  Commands handled above must NOT reach this call,
		 * otherwise their result would be replaced by -ENOIOCTLCMD
		 * (and a file descriptor installed above would be leaked).
		 */
		return vduse_dev_ioctl(file, cmd,
				       (unsigned long)compat_ptr(arg));
	}

	return ret;
}
#else
#define vduse_dev_compat_ioctl compat_ptr_ioctl
#endif
1764 
1765 static int vduse_dev_release(struct inode *inode, struct file *file)
1766 {
1767 	struct vduse_dev *dev = file->private_data;
1768 
1769 	mutex_lock(&dev->domain_lock);
1770 	for (int i = 0; i < dev->nas; i++)
1771 		if (dev->as[i].domain)
1772 			vduse_dev_dereg_umem(dev, i, 0,
1773 					     dev->as[i].domain->bounce_size);
1774 	mutex_unlock(&dev->domain_lock);
1775 	spin_lock(&dev->msg_lock);
1776 	/* Make sure the inflight messages can processed after reconncection */
1777 	list_splice_init(&dev->recv_list, &dev->send_list);
1778 	spin_unlock(&dev->msg_lock);
1779 	dev->connected = false;
1780 
1781 	return 0;
1782 }
1783 
1784 static int vduse_dev_open(struct inode *inode, struct file *file)
1785 {
1786 	int ret = -EBUSY;
1787 	struct vduse_dev *dev;
1788 
1789 	mutex_lock(&vduse_lock);
1790 	dev = idr_find(&vduse_idr, iminor(inode));
1791 	if (!dev) {
1792 		mutex_unlock(&vduse_lock);
1793 		return -ENODEV;
1794 	}
1795 
1796 	mutex_lock(&dev->lock);
1797 	if (dev->connected)
1798 		goto unlock;
1799 
1800 	ret = 0;
1801 	dev->connected = true;
1802 	file->private_data = dev;
1803 unlock:
1804 	mutex_unlock(&dev->lock);
1805 	mutex_unlock(&vduse_lock);
1806 
1807 	return ret;
1808 }
1809 
1810 static const struct file_operations vduse_dev_fops = {
1811 	.owner		= THIS_MODULE,
1812 	.open		= vduse_dev_open,
1813 	.release	= vduse_dev_release,
1814 	.read_iter	= vduse_dev_read_iter,
1815 	.write_iter	= vduse_dev_write_iter,
1816 	.poll		= vduse_dev_poll,
1817 	.unlocked_ioctl	= vduse_dev_ioctl,
1818 	.compat_ioctl	= vduse_dev_compat_ioctl,
1819 	.llseek		= noop_llseek,
1820 };
1821 
1822 static ssize_t irq_cb_affinity_show(struct vduse_virtqueue *vq, char *buf)
1823 {
1824 	return sprintf(buf, "%*pb\n", cpumask_pr_args(&vq->irq_affinity));
1825 }
1826 
1827 static ssize_t irq_cb_affinity_store(struct vduse_virtqueue *vq,
1828 				     const char *buf, size_t count)
1829 {
1830 	cpumask_var_t new_value;
1831 	int ret;
1832 
1833 	if (!zalloc_cpumask_var(&new_value, GFP_KERNEL))
1834 		return -ENOMEM;
1835 
1836 	ret = cpumask_parse(buf, new_value);
1837 	if (ret)
1838 		goto free_mask;
1839 
1840 	ret = -EINVAL;
1841 	if (!cpumask_intersects(new_value, cpu_online_mask))
1842 		goto free_mask;
1843 
1844 	cpumask_copy(&vq->irq_affinity, new_value);
1845 	ret = count;
1846 free_mask:
1847 	free_cpumask_var(new_value);
1848 	return ret;
1849 }
1850 
/*
 * A sysfs attribute of a virtqueue kobject: the generic attribute plus
 * show/store callbacks that take the virtqueue directly.  vq_attr_show()
 * and vq_attr_store() recover this structure from the embedded @attr.
 */
1851 struct vq_sysfs_entry {
1852 	struct attribute attr;
1853 	ssize_t (*show)(struct vduse_virtqueue *vq, char *buf);
1854 	ssize_t (*store)(struct vduse_virtqueue *vq, const char *buf,
1855 			 size_t count);
1856 };
1857 
/* "irq_cb_affinity" attribute, backed by irq_cb_affinity_show()/_store(). */
1858 static struct vq_sysfs_entry irq_cb_affinity_attr = __ATTR_RW(irq_cb_affinity);
1859 
/* NULL-terminated attribute list used to build vq_groups for vq_type. */
1860 static struct attribute *vq_attrs[] = {
1861 	&irq_cb_affinity_attr.attr,
1862 	NULL,
1863 };
1864 ATTRIBUTE_GROUPS(vq);
1865 
1866 static ssize_t vq_attr_show(struct kobject *kobj, struct attribute *attr,
1867 			    char *buf)
1868 {
1869 	struct vduse_virtqueue *vq = container_of(kobj,
1870 					struct vduse_virtqueue, kobj);
1871 	struct vq_sysfs_entry *entry = container_of(attr,
1872 					struct vq_sysfs_entry, attr);
1873 
1874 	if (!entry->show)
1875 		return -EIO;
1876 
1877 	return entry->show(vq, buf);
1878 }
1879 
1880 static ssize_t vq_attr_store(struct kobject *kobj, struct attribute *attr,
1881 			     const char *buf, size_t count)
1882 {
1883 	struct vduse_virtqueue *vq = container_of(kobj,
1884 					struct vduse_virtqueue, kobj);
1885 	struct vq_sysfs_entry *entry = container_of(attr,
1886 					struct vq_sysfs_entry, attr);
1887 
1888 	if (!entry->store)
1889 		return -EIO;
1890 
1891 	return entry->store(vq, buf, count);
1892 }
1893 
1894 static const struct sysfs_ops vq_sysfs_ops = {
1895 	.show = vq_attr_show,
1896 	.store = vq_attr_store,
1897 };
1898 
1899 static void vq_release(struct kobject *kobj)
1900 {
1901 	struct vduse_virtqueue *vq = container_of(kobj,
1902 					struct vduse_virtqueue, kobj);
1903 	kfree(vq);
1904 }
1905 
1906 static const struct kobj_type vq_type = {
1907 	.release	= vq_release,
1908 	.sysfs_ops	= &vq_sysfs_ops,
1909 	.default_groups	= vq_groups,
1910 };
1911 
1912 static char *vduse_devnode(const struct device *dev, umode_t *mode)
1913 {
1914 	return kasprintf(GFP_KERNEL, "vduse/%s", dev_name(dev));
1915 }
1916 
1917 static const struct class vduse_class = {
1918 	.name = "vduse",
1919 	.devnode = vduse_devnode,
1920 };
1921 
1922 static void vduse_dev_deinit_vqs(struct vduse_dev *dev)
1923 {
1924 	int i;
1925 
1926 	if (!dev->vqs)
1927 		return;
1928 
1929 	for (i = 0; i < dev->vq_num; i++)
1930 		kobject_put(&dev->vqs[i]->kobj);
1931 	kfree(dev->vqs);
1932 }
1933 
1934 static int vduse_dev_init_vqs(struct vduse_dev *dev, u32 vq_align, u32 vq_num)
1935 {
1936 	int ret, i;
1937 
1938 	dev->vq_align = vq_align;
1939 	dev->vq_num = vq_num;
1940 	dev->vqs = kzalloc_objs(*dev->vqs, dev->vq_num);
1941 	if (!dev->vqs)
1942 		return -ENOMEM;
1943 
1944 	for (i = 0; i < vq_num; i++) {
1945 		dev->vqs[i] = kzalloc_obj(*dev->vqs[i]);
1946 		if (!dev->vqs[i]) {
1947 			ret = -ENOMEM;
1948 			goto err;
1949 		}
1950 
1951 		dev->vqs[i]->index = i;
1952 		dev->vqs[i]->irq_effective_cpu = IRQ_UNBOUND;
1953 		INIT_WORK(&dev->vqs[i]->inject, vduse_vq_irq_inject);
1954 		INIT_WORK(&dev->vqs[i]->kick, vduse_vq_kick_work);
1955 		spin_lock_init(&dev->vqs[i]->kick_lock);
1956 		spin_lock_init(&dev->vqs[i]->irq_lock);
1957 		cpumask_setall(&dev->vqs[i]->irq_affinity);
1958 
1959 		kobject_init(&dev->vqs[i]->kobj, &vq_type);
1960 		ret = kobject_add(&dev->vqs[i]->kobj,
1961 				  &dev->dev->kobj, "vq%d", i);
1962 		if (ret) {
1963 			kfree(dev->vqs[i]);
1964 			goto err;
1965 		}
1966 	}
1967 
1968 	return 0;
1969 err:
1970 	while (i--)
1971 		kobject_put(&dev->vqs[i]->kobj);
1972 	kfree(dev->vqs);
1973 	dev->vqs = NULL;
1974 	return ret;
1975 }
1976 
1977 static struct vduse_dev *vduse_dev_create(void)
1978 {
1979 	struct vduse_dev *dev = kzalloc_obj(*dev);
1980 
1981 	if (!dev)
1982 		return NULL;
1983 
1984 	mutex_init(&dev->lock);
1985 	mutex_init(&dev->domain_lock);
1986 	spin_lock_init(&dev->msg_lock);
1987 	INIT_LIST_HEAD(&dev->send_list);
1988 	INIT_LIST_HEAD(&dev->recv_list);
1989 	spin_lock_init(&dev->irq_lock);
1990 	init_rwsem(&dev->rwsem);
1991 
1992 	INIT_WORK(&dev->inject, vduse_dev_irq_inject);
1993 	init_waitqueue_head(&dev->waitq);
1994 
1995 	return dev;
1996 }
1997 
/*
 * Free the device structure itself.  Nothing the device points to is
 * released here; callers free those resources beforehand.
 */
static void vduse_dev_destroy(struct vduse_dev *dev)
{
	kfree(dev);
}
2002 
2003 static struct vduse_dev *vduse_find_dev(const char *name)
2004 {
2005 	struct vduse_dev *dev;
2006 	int id;
2007 
2008 	idr_for_each_entry(&vduse_idr, dev, id)
2009 		if (!strcmp(dev->name, name))
2010 			return dev;
2011 
2012 	return NULL;
2013 }
2014 
2015 static int vduse_destroy_dev(char *name)
2016 {
2017 	struct vduse_dev *dev = vduse_find_dev(name);
2018 
2019 	if (!dev)
2020 		return -EINVAL;
2021 
2022 	mutex_lock(&dev->lock);
2023 	if (dev->vdev || dev->connected) {
2024 		mutex_unlock(&dev->lock);
2025 		return -EBUSY;
2026 	}
2027 	dev->connected = true;
2028 	mutex_unlock(&dev->lock);
2029 
2030 	vduse_dev_reset(dev);
2031 	device_destroy(&vduse_class, MKDEV(MAJOR(vduse_major), dev->minor));
2032 	idr_remove(&vduse_idr, dev->minor);
2033 	kvfree(dev->config);
2034 	vduse_dev_deinit_vqs(dev);
2035 	for (int i = 0; i < dev->nas; i++) {
2036 		if (dev->as[i].domain)
2037 			vduse_domain_destroy(dev->as[i].domain);
2038 	}
2039 	kfree(dev->as);
2040 	kfree(dev->name);
2041 	kfree(dev->groups);
2042 	vduse_dev_destroy(dev);
2043 	module_put(THIS_MODULE);
2044 
2045 	return 0;
2046 }
2047 
2048 static bool device_is_allowed(u32 device_id)
2049 {
2050 	int i;
2051 
2052 	for (i = 0; i < ARRAY_SIZE(allowed_device_id); i++)
2053 		if (allowed_device_id[i] == device_id)
2054 			return true;
2055 
2056 	return false;
2057 }
2058 
2059 static bool features_is_valid(struct vduse_dev_config *config)
2060 {
2061 	if (!(config->features & BIT_ULL(VIRTIO_F_ACCESS_PLATFORM)))
2062 		return false;
2063 
2064 	/* Now we only support read-only configuration space */
2065 	if ((config->device_id == VIRTIO_ID_BLOCK) &&
2066 			(config->features & BIT_ULL(VIRTIO_BLK_F_CONFIG_WCE)))
2067 		return false;
2068 	else if ((config->device_id == VIRTIO_ID_NET) &&
2069 			(config->features & BIT_ULL(VIRTIO_NET_F_CTRL_VQ)))
2070 		return false;
2071 
2072 	if ((config->device_id == VIRTIO_ID_NET) &&
2073 			!(config->features & BIT_ULL(VIRTIO_F_VERSION_1)))
2074 		return false;
2075 
2076 	return true;
2077 }
2078 
2079 static bool vduse_validate_config(struct vduse_dev_config *config,
2080 				  u64 api_version)
2081 {
2082 	if (!is_mem_zero((const char *)config->reserved,
2083 			 sizeof(config->reserved)))
2084 		return false;
2085 
2086 	if (api_version < VDUSE_API_VERSION_1 &&
2087 	    (config->ngroups || config->nas))
2088 		return false;
2089 
2090 	if (api_version >= VDUSE_API_VERSION_1) {
2091 		if (!config->ngroups || config->ngroups > VDUSE_DEV_MAX_GROUPS)
2092 			return false;
2093 
2094 		if (!config->nas || config->nas > VDUSE_DEV_MAX_AS)
2095 			return false;
2096 	}
2097 
2098 	if (config->vq_align > PAGE_SIZE)
2099 		return false;
2100 
2101 	if (config->config_size > PAGE_SIZE)
2102 		return false;
2103 
2104 	if (config->vq_num > 0xffff)
2105 		return false;
2106 
2107 	if (!config->name[0])
2108 		return false;
2109 
2110 	if (!device_is_allowed(config->device_id))
2111 		return false;
2112 
2113 	if (!features_is_valid(config))
2114 		return false;
2115 
2116 	return true;
2117 }
2118 
2119 static ssize_t msg_timeout_show(struct device *device,
2120 				struct device_attribute *attr, char *buf)
2121 {
2122 	struct vduse_dev *dev = dev_get_drvdata(device);
2123 
2124 	return sysfs_emit(buf, "%u\n", dev->msg_timeout);
2125 }
2126 
2127 static ssize_t msg_timeout_store(struct device *device,
2128 				 struct device_attribute *attr,
2129 				 const char *buf, size_t count)
2130 {
2131 	struct vduse_dev *dev = dev_get_drvdata(device);
2132 	int ret;
2133 
2134 	ret = kstrtouint(buf, 10, &dev->msg_timeout);
2135 	if (ret < 0)
2136 		return ret;
2137 
2138 	return count;
2139 }
2140 
2141 static DEVICE_ATTR_RW(msg_timeout);
2142 
2143 static ssize_t bounce_size_show(struct device *device,
2144 				struct device_attribute *attr, char *buf)
2145 {
2146 	struct vduse_dev *dev = dev_get_drvdata(device);
2147 
2148 	return sysfs_emit(buf, "%u\n", dev->bounce_size);
2149 }
2150 
2151 static ssize_t bounce_size_store(struct device *device,
2152 				 struct device_attribute *attr,
2153 				 const char *buf, size_t count)
2154 {
2155 	struct vduse_dev *dev = dev_get_drvdata(device);
2156 	unsigned int bounce_size;
2157 	int ret;
2158 
2159 	ret = -EPERM;
2160 	mutex_lock(&dev->domain_lock);
2161 	/* Assuming that if the first domain is allocated, all are allocated */
2162 	if (dev->as[0].domain)
2163 		goto unlock;
2164 
2165 	ret = kstrtouint(buf, 10, &bounce_size);
2166 	if (ret < 0)
2167 		goto unlock;
2168 
2169 	ret = -EINVAL;
2170 	if (bounce_size > VDUSE_MAX_BOUNCE_SIZE ||
2171 	    bounce_size < VDUSE_MIN_BOUNCE_SIZE)
2172 		goto unlock;
2173 
2174 	dev->bounce_size = bounce_size & PAGE_MASK;
2175 	ret = count;
2176 unlock:
2177 	mutex_unlock(&dev->domain_lock);
2178 	return ret;
2179 }
2180 
2181 static DEVICE_ATTR_RW(bounce_size);
2182 
/*
 * NULL-terminated list of the per-device sysfs attributes (msg_timeout
 * and bounce_size).  ATTRIBUTE_GROUPS() turns it into vduse_dev_groups,
 * which vduse_create_dev() passes to device_create_with_groups().
 */
2183 static struct attribute *vduse_dev_attrs[] = {
2184 	&dev_attr_msg_timeout.attr,
2185 	&dev_attr_bounce_size.attr,
2186 	NULL
2187 };
2188 
2189 ATTRIBUTE_GROUPS(vduse_dev);
2190 
2191 static int vduse_create_dev(struct vduse_dev_config *config,
2192 			    void *config_buf, u64 api_version)
2193 {
2194 	int ret;
2195 	struct vduse_dev *dev;
2196 
2197 	ret = -EPERM;
2198 	if ((config->device_id == VIRTIO_ID_NET) && !capable(CAP_NET_ADMIN))
2199 		goto err;
2200 
2201 	ret = -EEXIST;
2202 	if (vduse_find_dev(config->name))
2203 		goto err;
2204 
2205 	ret = -ENOMEM;
2206 	dev = vduse_dev_create();
2207 	if (!dev)
2208 		goto err;
2209 
2210 	dev->api_version = api_version;
2211 	dev->device_features = config->features;
2212 	dev->device_id = config->device_id;
2213 	dev->vendor_id = config->vendor_id;
2214 
2215 	dev->nas = (dev->api_version < VDUSE_API_VERSION_1) ? 1 : config->nas;
2216 	dev->as = kzalloc_objs(dev->as[0], dev->nas);
2217 	if (!dev->as)
2218 		goto err_as;
2219 	for (int i = 0; i < dev->nas; i++)
2220 		mutex_init(&dev->as[i].mem_lock);
2221 
2222 	dev->ngroups = (dev->api_version < VDUSE_API_VERSION_1)
2223 		       ? 1
2224 		       : config->ngroups;
2225 	dev->groups = kzalloc_objs(dev->groups[0], dev->ngroups);
2226 	if (!dev->groups)
2227 		goto err_vq_groups;
2228 	for (u32 i = 0; i < dev->ngroups; ++i) {
2229 		dev->groups[i].dev = dev;
2230 		rwlock_init(&dev->groups[i].as_lock);
2231 		dev->groups[i].as = &dev->as[0];
2232 	}
2233 
2234 	dev->name = kstrdup(config->name, GFP_KERNEL);
2235 	if (!dev->name)
2236 		goto err_str;
2237 
2238 	dev->bounce_size = VDUSE_BOUNCE_SIZE;
2239 	dev->config = config_buf;
2240 	dev->config_size = config->config_size;
2241 
2242 	ret = idr_alloc(&vduse_idr, dev, 1, VDUSE_DEV_MAX, GFP_KERNEL);
2243 	if (ret < 0)
2244 		goto err_idr;
2245 
2246 	dev->minor = ret;
2247 	dev->msg_timeout = VDUSE_MSG_DEFAULT_TIMEOUT;
2248 	dev->dev = device_create_with_groups(&vduse_class, NULL,
2249 				MKDEV(MAJOR(vduse_major), dev->minor),
2250 				dev, vduse_dev_groups, "%s", config->name);
2251 	if (IS_ERR(dev->dev)) {
2252 		ret = PTR_ERR(dev->dev);
2253 		goto err_dev;
2254 	}
2255 
2256 	ret = vduse_dev_init_vqs(dev, config->vq_align, config->vq_num);
2257 	if (ret)
2258 		goto err_vqs;
2259 
2260 	__module_get(THIS_MODULE);
2261 
2262 	return 0;
2263 err_vqs:
2264 	device_destroy(&vduse_class, MKDEV(MAJOR(vduse_major), dev->minor));
2265 err_dev:
2266 	idr_remove(&vduse_idr, dev->minor);
2267 err_idr:
2268 	kfree(dev->name);
2269 err_str:
2270 	kfree(dev->groups);
2271 err_vq_groups:
2272 	kfree(dev->as);
2273 err_as:
2274 	vduse_dev_destroy(dev);
2275 err:
2276 	return ret;
2277 }
2278 
2279 static long vduse_ioctl(struct file *file, unsigned int cmd,
2280 			unsigned long arg)
2281 {
2282 	int ret;
2283 	void __user *argp = (void __user *)arg;
2284 	struct vduse_control *control = file->private_data;
2285 
2286 	mutex_lock(&vduse_lock);
2287 	switch (cmd) {
2288 	case VDUSE_GET_API_VERSION:
2289 		if (control->api_version == VDUSE_API_VERSION_NOT_ASKED)
2290 			control->api_version = VDUSE_API_VERSION_1;
2291 		ret = put_user(control->api_version, (u64 __user *)argp);
2292 		break;
2293 	case VDUSE_SET_API_VERSION: {
2294 		u64 api_version;
2295 
2296 		ret = -EFAULT;
2297 		if (get_user(api_version, (u64 __user *)argp))
2298 			break;
2299 
2300 		ret = -EINVAL;
2301 		if (api_version > VDUSE_API_VERSION_1)
2302 			break;
2303 
2304 		ret = 0;
2305 		control->api_version = api_version;
2306 		break;
2307 	}
2308 	case VDUSE_CREATE_DEV: {
2309 		struct vduse_dev_config config;
2310 		unsigned long size = offsetof(struct vduse_dev_config, config);
2311 		void *buf;
2312 
2313 		ret = -EFAULT;
2314 		if (copy_from_user(&config, argp, size))
2315 			break;
2316 
2317 		ret = -EINVAL;
2318 		if (control->api_version == VDUSE_API_VERSION_NOT_ASKED)
2319 			control->api_version = VDUSE_API_VERSION;
2320 		if (!vduse_validate_config(&config, control->api_version))
2321 			break;
2322 
2323 		buf = vmemdup_user(argp + size, config.config_size);
2324 		if (IS_ERR(buf)) {
2325 			ret = PTR_ERR(buf);
2326 			break;
2327 		}
2328 		config.name[VDUSE_NAME_MAX - 1] = '\0';
2329 		ret = vduse_create_dev(&config, buf, control->api_version);
2330 		if (ret)
2331 			kvfree(buf);
2332 		break;
2333 	}
2334 	case VDUSE_DESTROY_DEV: {
2335 		char name[VDUSE_NAME_MAX];
2336 
2337 		ret = -EFAULT;
2338 		if (copy_from_user(name, argp, VDUSE_NAME_MAX))
2339 			break;
2340 
2341 		name[VDUSE_NAME_MAX - 1] = '\0';
2342 		ret = vduse_destroy_dev(name);
2343 		break;
2344 	}
2345 	default:
2346 		ret = -EINVAL;
2347 		break;
2348 	}
2349 	mutex_unlock(&vduse_lock);
2350 
2351 	return ret;
2352 }
2353 
2354 static int vduse_release(struct inode *inode, struct file *file)
2355 {
2356 	struct vduse_control *control = file->private_data;
2357 
2358 	kfree(control);
2359 	return 0;
2360 }
2361 
2362 static int vduse_open(struct inode *inode, struct file *file)
2363 {
2364 	struct vduse_control *control;
2365 
2366 	control = kmalloc_obj(struct vduse_control);
2367 	if (!control)
2368 		return -ENOMEM;
2369 
2370 	control->api_version = VDUSE_API_VERSION_NOT_ASKED;
2371 	file->private_data = control;
2372 
2373 	return 0;
2374 }
2375 
2376 static const struct file_operations vduse_ctrl_fops = {
2377 	.owner		= THIS_MODULE,
2378 	.open		= vduse_open,
2379 	.release	= vduse_release,
2380 	.unlocked_ioctl	= vduse_ioctl,
2381 	.compat_ioctl	= compat_ptr_ioctl,
2382 	.llseek		= noop_llseek,
2383 };
2384 
2385 struct vduse_mgmt_dev {
2386 	struct vdpa_mgmt_dev mgmt_dev;
2387 	struct device dev;
2388 };
2389 
2390 static struct vduse_mgmt_dev *vduse_mgmt;
2391 
2392 static int vduse_dev_init_vdpa(struct vduse_dev *dev, const char *name)
2393 {
2394 	struct vduse_vdpa *vdev;
2395 
2396 	if (dev->vdev)
2397 		return -EEXIST;
2398 
2399 	vdev = vdpa_alloc_device(struct vduse_vdpa, vdpa, dev->dev,
2400 				 &vduse_vdpa_config_ops, &vduse_map_ops,
2401 				 dev->ngroups, dev->nas, name, true);
2402 	if (IS_ERR(vdev))
2403 		return PTR_ERR(vdev);
2404 
2405 	dev->vdev = vdev;
2406 	vdev->dev = dev;
2407 	vdev->vdpa.mdev = &vduse_mgmt->mgmt_dev;
2408 
2409 	return 0;
2410 }
2411 
2412 static int vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
2413 			const struct vdpa_dev_set_config *config)
2414 {
2415 	struct vduse_dev *dev;
2416 	size_t domain_bounce_size;
2417 	int ret, i;
2418 
2419 	mutex_lock(&vduse_lock);
2420 	dev = vduse_find_dev(name);
2421 	if (!dev || !vduse_dev_is_ready(dev)) {
2422 		mutex_unlock(&vduse_lock);
2423 		return -EINVAL;
2424 	}
2425 	ret = vduse_dev_init_vdpa(dev, name);
2426 	mutex_unlock(&vduse_lock);
2427 	if (ret)
2428 		return ret;
2429 
2430 	mutex_lock(&dev->domain_lock);
2431 	ret = 0;
2432 
2433 	domain_bounce_size = dev->bounce_size / dev->nas;
2434 	for (i = 0; i < dev->nas; ++i) {
2435 		dev->as[i].domain = vduse_domain_create(VDUSE_IOVA_SIZE - 1,
2436 							domain_bounce_size);
2437 		if (!dev->as[i].domain) {
2438 			ret = -ENOMEM;
2439 			goto err;
2440 		}
2441 	}
2442 
2443 	mutex_unlock(&dev->domain_lock);
2444 
2445 	ret = _vdpa_register_device(&dev->vdev->vdpa, dev->vq_num);
2446 	if (ret)
2447 		goto err_register;
2448 
2449 	return 0;
2450 
2451 err_register:
2452 	mutex_lock(&dev->domain_lock);
2453 
2454 err:
2455 	for (int j = 0; j < i; j++) {
2456 		if (dev->as[j].domain) {
2457 			vduse_domain_destroy(dev->as[j].domain);
2458 			dev->as[j].domain = NULL;
2459 		}
2460 	}
2461 	mutex_unlock(&dev->domain_lock);
2462 
2463 	put_device(&dev->vdev->vdpa.dev);
2464 
2465 	return ret;
2466 }
2467 
/*
 * .dev_del callback of the VDUSE management device: unregisters @dev from
 * the vDPA bus.  @mdev is unused.
 */
static void vdpa_dev_del(struct vdpa_mgmt_dev *mdev, struct vdpa_device *dev)
{
	_vdpa_unregister_device(dev);
}
2472 
2473 static const struct vdpa_mgmtdev_ops vdpa_dev_mgmtdev_ops = {
2474 	.dev_add = vdpa_dev_add,
2475 	.dev_del = vdpa_dev_del,
2476 };
2477 
2478 static struct virtio_device_id id_table[] = {
2479 	{ VIRTIO_ID_BLOCK, VIRTIO_DEV_ANY_ID },
2480 	{ VIRTIO_ID_NET, VIRTIO_DEV_ANY_ID },
2481 	{ 0 },
2482 };
2483 
2484 static void vduse_mgmtdev_release(struct device *dev)
2485 {
2486 	struct vduse_mgmt_dev *mgmt_dev;
2487 
2488 	mgmt_dev = container_of(dev, struct vduse_mgmt_dev, dev);
2489 	kfree(mgmt_dev);
2490 }
2491 
2492 static int vduse_mgmtdev_init(void)
2493 {
2494 	int ret;
2495 
2496 	vduse_mgmt = kzalloc_obj(*vduse_mgmt);
2497 	if (!vduse_mgmt)
2498 		return -ENOMEM;
2499 
2500 	ret = dev_set_name(&vduse_mgmt->dev, "vduse");
2501 	if (ret) {
2502 		kfree(vduse_mgmt);
2503 		return ret;
2504 	}
2505 
2506 	vduse_mgmt->dev.release = vduse_mgmtdev_release;
2507 
2508 	ret = device_register(&vduse_mgmt->dev);
2509 	if (ret)
2510 		goto dev_reg_err;
2511 
2512 	vduse_mgmt->mgmt_dev.id_table = id_table;
2513 	vduse_mgmt->mgmt_dev.ops = &vdpa_dev_mgmtdev_ops;
2514 	vduse_mgmt->mgmt_dev.device = &vduse_mgmt->dev;
2515 	ret = vdpa_mgmtdev_register(&vduse_mgmt->mgmt_dev);
2516 	if (ret)
2517 		device_unregister(&vduse_mgmt->dev);
2518 
2519 	return ret;
2520 
2521 dev_reg_err:
2522 	put_device(&vduse_mgmt->dev);
2523 	return ret;
2524 }
2525 
2526 static void vduse_mgmtdev_exit(void)
2527 {
2528 	vdpa_mgmtdev_unregister(&vduse_mgmt->mgmt_dev);
2529 	device_unregister(&vduse_mgmt->dev);
2530 }
2531 
2532 static int vduse_init(void)
2533 {
2534 	int ret;
2535 	struct device *dev;
2536 
2537 	ret = class_register(&vduse_class);
2538 	if (ret)
2539 		return ret;
2540 
2541 	ret = alloc_chrdev_region(&vduse_major, 0, VDUSE_DEV_MAX, "vduse");
2542 	if (ret)
2543 		goto err_chardev_region;
2544 
2545 	/* /dev/vduse/control */
2546 	cdev_init(&vduse_ctrl_cdev, &vduse_ctrl_fops);
2547 	vduse_ctrl_cdev.owner = THIS_MODULE;
2548 	ret = cdev_add(&vduse_ctrl_cdev, vduse_major, 1);
2549 	if (ret)
2550 		goto err_ctrl_cdev;
2551 
2552 	dev = device_create(&vduse_class, NULL, vduse_major, NULL, "control");
2553 	if (IS_ERR(dev)) {
2554 		ret = PTR_ERR(dev);
2555 		goto err_device;
2556 	}
2557 
2558 	/* /dev/vduse/$DEVICE */
2559 	cdev_init(&vduse_cdev, &vduse_dev_fops);
2560 	vduse_cdev.owner = THIS_MODULE;
2561 	ret = cdev_add(&vduse_cdev, MKDEV(MAJOR(vduse_major), 1),
2562 		       VDUSE_DEV_MAX - 1);
2563 	if (ret)
2564 		goto err_cdev;
2565 
2566 	ret = -ENOMEM;
2567 	vduse_irq_wq = alloc_workqueue("vduse-irq",
2568 				WQ_HIGHPRI | WQ_SYSFS | WQ_UNBOUND, 0);
2569 	if (!vduse_irq_wq)
2570 		goto err_wq;
2571 
2572 	vduse_irq_bound_wq = alloc_workqueue("vduse-irq-bound",
2573 					     WQ_HIGHPRI | WQ_PERCPU, 0);
2574 	if (!vduse_irq_bound_wq)
2575 		goto err_bound_wq;
2576 
2577 	ret = vduse_domain_init();
2578 	if (ret)
2579 		goto err_domain;
2580 
2581 	ret = vduse_mgmtdev_init();
2582 	if (ret)
2583 		goto err_mgmtdev;
2584 
2585 	return 0;
2586 err_mgmtdev:
2587 	vduse_domain_exit();
2588 err_domain:
2589 	destroy_workqueue(vduse_irq_bound_wq);
2590 err_bound_wq:
2591 	destroy_workqueue(vduse_irq_wq);
2592 err_wq:
2593 	cdev_del(&vduse_cdev);
2594 err_cdev:
2595 	device_destroy(&vduse_class, vduse_major);
2596 err_device:
2597 	cdev_del(&vduse_ctrl_cdev);
2598 err_ctrl_cdev:
2599 	unregister_chrdev_region(vduse_major, VDUSE_DEV_MAX);
2600 err_chardev_region:
2601 	class_unregister(&vduse_class);
2602 	return ret;
2603 }
2604 module_init(vduse_init);
2605 
2606 static void vduse_exit(void)
2607 {
2608 	vduse_mgmtdev_exit();
2609 	vduse_domain_exit();
2610 	destroy_workqueue(vduse_irq_bound_wq);
2611 	destroy_workqueue(vduse_irq_wq);
2612 	cdev_del(&vduse_cdev);
2613 	device_destroy(&vduse_class, vduse_major);
2614 	cdev_del(&vduse_ctrl_cdev);
2615 	unregister_chrdev_region(vduse_major, VDUSE_DEV_MAX);
2616 	class_unregister(&vduse_class);
2617 	idr_destroy(&vduse_idr);
2618 }
2619 module_exit(vduse_exit);
2620 
/* Module metadata; the strings are the DRV_* macros defined at the top of this file. */
MODULE_DESCRIPTION(DRV_DESC);
MODULE_AUTHOR(DRV_AUTHOR);
MODULE_LICENSE(DRV_LICENSE);
2624