// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID	2
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
	VHOST_VSOCK_FEATURES = VHOST_FEATURES |
			       (1ULL << VIRTIO_F_ACCESS_PLATFORM)
};

enum {
	VHOST_VSOCK_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
	struct vhost_dev dev;
	struct vhost_virtqueue vqs[2];

	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
	struct hlist_node hash;

	struct vhost_work send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;	/* host->guest pending packets */

	atomic_t queued_replies;

	u32 guest_cid;
};

static u32 vhost_transport_get_local_cid(void)
{
	return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
		u32 other_cid = vsock->guest_cid;

		/* Skip instances that have no CID yet */
		if (other_cid == 0)
			continue;

		if (other_cid == guest_cid)
			return vsock;
	}

	return NULL;
}

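/* Deliver packets queued on send_pkt_list (host->guest) into the guest's
 * RX virtqueue.  A packet that does not fit into a single guest buffer is
 * copied out piecewise: the remainder is requeued and sent with the next
 * available buffer.  The loop bails out once vhost_exceeds_weight() says
 * this virtqueue has used up its byte/packet budget, so other work is not
 * starved.
 */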
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
			    struct vhost_virtqueue *vq)
{
	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
	int pkts = 0, total_len = 0;
	bool added = false;
	bool restart_tx = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	if (!vq_meta_prefetch(vq))
		goto out;

	/* Avoid further vmexits, we're already processing the virtqueue */
	vhost_disable_notify(&vsock->dev, vq);

	do {
		struct virtio_vsock_pkt *pkt;
		struct iov_iter iov_iter;
		unsigned out, in;
		size_t nbytes;
		size_t iov_len, payload_len;
		int head;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			vhost_enable_notify(&vsock->dev, vq);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (head == vq->num) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);

			/* We cannot finish yet if more buffers snuck in while
			 * re-enabling notify.
			 */
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		if (out) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
			break;
		}

		iov_len = iov_length(&vq->iov[out], in);
		if (iov_len < sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
			break;
		}

		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
		payload_len = pkt->len - pkt->off;

		/* If the packet is greater than the space available in the
		 * buffer, we split it using multiple buffers.
		 */
		if (payload_len > iov_len - sizeof(pkt->hdr))
			payload_len = iov_len - sizeof(pkt->hdr);

		/* Set the correct length in the header */
		pkt->hdr.len = cpu_to_le32(payload_len);

		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
		if (nbytes != sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt hdr\n");
			break;
		}

		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
				      &iov_iter);
		if (nbytes != payload_len) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt buf\n");
			break;
		}

		/* Deliver to monitoring devices all packets that we
		 * will transmit.
		 */
		virtio_transport_deliver_tap_pkt(pkt);

		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
		added = true;

		pkt->off += payload_len;
		total_len += payload_len;

		/* If we didn't send all the payload we can requeue the packet
		 * to send it with the next available buffer.
		 */
		if (pkt->off < pkt->len) {
			/* We are queueing the same virtio_vsock_pkt to handle
			 * the remaining bytes, and we want to deliver it
			 * to monitoring devices in the next iteration.
			 */
			pkt->tap_delivered = false;

			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
		} else {
			if (pkt->reply) {
				int val;

				val = atomic_dec_return(&vsock->queued_replies);

				/* Do we have resources to resume tx
				 * processing?
				 */
				if (val + 1 == tx_vq->num)
					restart_tx = true;
			}

			virtio_transport_free_pkt(pkt);
		}
	} while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);

	if (restart_tx)
		vhost_poll_queue(&tx_vq->poll);
}

static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;

	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
	vq = &vsock->vqs[VSOCK_VQ_RX];

	vhost_transport_do_send_pkt(vsock, vq);
}

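/* Queue a packet for transmission to the guest whose CID matches
 * pkt->hdr.dst_cid and kick the send worker.  Returns the packet length,
 * or -ENODEV if no vhost_vsock instance owns that CID.  Reply packets are
 * counted in queued_replies so that TX processing can be throttled while
 * too many replies are outstanding.
 */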
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct vhost_vsock *vsock;
	int len = pkt->len;

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
	if (!vsock) {
		rcu_read_unlock();
		virtio_transport_free_pkt(pkt);
		return -ENODEV;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	rcu_read_unlock();
	return len;
}

static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct vhost_vsock *vsock;
	struct virtio_vsock_pkt *pkt, *n;
	int cnt = 0;
	int ret = -ENODEV;
	LIST_HEAD(freeme);

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
	if (!vsock)
		goto out;

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
		if (pkt->vsk != vsk)
			continue;
		list_move(&pkt->list, &freeme);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	list_for_each_entry_safe(pkt, n, &freeme, list) {
		if (pkt->reply)
			cnt++;
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	if (cnt) {
		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
			vhost_poll_queue(&tx_vq->poll);
	}

	ret = 0;
out:
	rcu_read_unlock();
	return ret;
}

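/* Build a virtio_vsock_pkt from the guest-provided descriptors: copy the
 * header and, for stream packets with a payload, the payload itself.
 * Returns NULL on malformed descriptors, oversized payloads or allocation
 * failure.
 */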
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
		      unsigned int out, unsigned int in)
{
	struct virtio_vsock_pkt *pkt;
	struct iov_iter iov_iter;
	size_t nbytes;
	size_t len;

	if (in != 0) {
		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
		return NULL;
	}

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	len = iov_length(vq->iov, out);
	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
	if (nbytes != sizeof(pkt->hdr)) {
		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
		       sizeof(pkt->hdr), nbytes);
		kfree(pkt);
		return NULL;
	}

	if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
		pkt->len = le32_to_cpu(pkt->hdr.len);

	/* No payload */
	if (!pkt->len)
		return pkt;

	/* The pkt is too big */
	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
	if (!pkt->buf) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf_len = pkt->len;

	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
	if (nbytes != pkt->len) {
		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
		       pkt->len, nbytes);
		virtio_transport_free_pkt(pkt);
		return NULL;
	}

	return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < vq->num;
}

static struct virtio_transport vhost_transport = {
	.transport = {
		.module = THIS_MODULE,

		.get_local_cid = vhost_transport_get_local_cid,

		.init = virtio_transport_do_socket_init,
		.destruct = virtio_transport_destruct,
		.release = virtio_transport_release,
		.connect = virtio_transport_connect,
		.shutdown = virtio_transport_shutdown,
		.cancel_pkt = vhost_transport_cancel_pkt,

		.dgram_enqueue = virtio_transport_dgram_enqueue,
		.dgram_dequeue = virtio_transport_dgram_dequeue,
		.dgram_bind = virtio_transport_dgram_bind,
		.dgram_allow = virtio_transport_dgram_allow,

		.stream_enqueue = virtio_transport_stream_enqueue,
		.stream_dequeue = virtio_transport_stream_dequeue,
		.stream_has_data = virtio_transport_stream_has_data,
		.stream_has_space = virtio_transport_stream_has_space,
		.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
		.stream_is_active = virtio_transport_stream_is_active,
		.stream_allow = virtio_transport_stream_allow,

		.notify_poll_in = virtio_transport_notify_poll_in,
		.notify_poll_out = virtio_transport_notify_poll_out,
		.notify_recv_init = virtio_transport_notify_recv_init,
		.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init = virtio_transport_notify_send_init,
		.notify_send_pre_block = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size = virtio_transport_notify_buffer_size,
	},

	.send_pkt = vhost_transport_send_pkt,
};

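/* Guest->host path: handle a kick on the guest's TX virtqueue.  Each
 * descriptor is turned into a virtio_vsock_pkt; correctly addressed
 * packets (src_cid == guest CID, dst_cid == host CID) are passed to the
 * common virtio transport receive path, everything else is dropped.
 * Processing stops early when too many replies are already queued.
 */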
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);
	struct virtio_vsock_pkt *pkt;
	int head, pkts = 0, total_len = 0;
	unsigned int out, in;
	bool added = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	if (!vq_meta_prefetch(vq))
		goto out;

	vhost_disable_notify(&vsock->dev, vq);
	do {
		u32 len;

		if (!vhost_vsock_more_replies(vsock)) {
			/* Stop tx until the device processes already
			 * pending replies.  Leave tx virtqueue
			 * callbacks disabled.
			 */
			goto no_more_replies;
		}

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = vhost_vsock_alloc_pkt(vq, out, in);
		if (!pkt) {
			vq_err(vq, "Faulted on pkt\n");
			continue;
		}

		len = pkt->len;

		/* Deliver to monitoring devices all received packets */
		virtio_transport_deliver_tap_pkt(pkt);

		/* Only accept correctly addressed packets */
		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
		    le64_to_cpu(pkt->hdr.dst_cid) ==
		    vhost_transport_get_local_cid())
			virtio_transport_recv_pkt(&vhost_transport, pkt);
		else
			virtio_transport_free_pkt(pkt);

		len += sizeof(pkt->hdr);
		vhost_add_used(vq, head, len);
		total_len += len;
		added = true;
	} while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);
}

static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	vhost_transport_do_send_pkt(vsock, vq);
}

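/* VHOST_VSOCK_SET_RUNNING(1): attach this vhost_vsock instance as the
 * backend of both virtqueues after checking ownership and ring access.
 * On failure every backend pointer set so far is cleared again.
 */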
static int vhost_vsock_start(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq;
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);

		if (!vhost_vq_access_ok(vq)) {
			ret = -EFAULT;
			goto err_vq;
		}

		if (!vhost_vq_get_backend(vq)) {
			vhost_vq_set_backend(vq, vsock);
			ret = vhost_vq_init_access(vq);
			if (ret)
				goto err_vq;
		}

		mutex_unlock(&vq->mutex);
	}

	/* Some packets may have been queued before the device was started,
	 * let's kick the send worker to send them.
	 */
	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	mutex_unlock(&vsock->dev.mutex);
	return 0;

err_vq:
	vhost_vq_set_backend(vq, NULL);
	mutex_unlock(&vq->mutex);

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}
err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static int vhost_vsock_stop(struct vhost_vsock *vsock)
{
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		struct vhost_virtqueue *vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}

err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
	kvfree(vsock);
}

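/* open() on /dev/vhost-vsock: allocate a vhost_vsock instance (falling
 * back to vmalloc for this large struct), wire up the TX/RX kick handlers
 * and the send worker, and initialize the vhost device.  The instance has
 * no CID until userspace issues VHOST_VSOCK_SET_GUEST_CID.
 */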
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
	struct vhost_virtqueue **vqs;
	struct vhost_vsock *vsock;
	int ret;

	/* This struct is large and allocation could fail, fall back to vmalloc
	 * if there is no other way.
	 */
	vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
	if (!vsock)
		return -ENOMEM;

	vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->guest_cid = 0; /* no CID assigned yet */

	atomic_set(&vsock->queued_replies, 0);

	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
		       VHOST_VSOCK_WEIGHT, true, NULL);

	file->private_data = vsock;
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
	return 0;

out:
	vhost_vsock_free(vsock);
	return ret;
}

static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
		if (vsock->vqs[i].handle_kick)
			vhost_poll_flush(&vsock->vqs[i].poll);
	vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

static void vhost_vsock_reset_orphans(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	/* If the peer is still valid, no need to reset connection */
	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
		return;

	/* If the close timeout is pending, let it expire.  This avoids races
	 * with the timeout callback.
	 */
	if (vsk->close_work_scheduled)
		return;

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = ECONNRESET;
	sk->sk_error_report(sk);
}

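/* release(): unhash the instance so no new packets can be routed to it,
 * wait for RCU readers, reset sockets orphaned by its disappearance, stop
 * the virtqueues, flush outstanding work, drop any still-queued packets
 * and finally free the device.
 */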
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
	struct vhost_vsock *vsock = file->private_data;

	mutex_lock(&vhost_vsock_mutex);
	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);
	mutex_unlock(&vhost_vsock_mutex);

	/* Wait for other CPUs to finish using vsock */
	synchronize_rcu();

	/* Iterating over all connections for all CIDs to find orphans is
	 * inefficient.  Room for improvement here. */
	vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

	vhost_vsock_stop(vsock);
	vhost_vsock_flush(vsock);
	vhost_dev_stop(&vsock->dev);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		struct virtio_vsock_pkt *pkt;

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_dev_cleanup(&vsock->dev);
	kfree(vsock->dev.vqs);
	vhost_vsock_free(vsock);
	return 0;
}

static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
	struct vhost_vsock *other;

	/* Refuse reserved CIDs */
	if (guest_cid <= VMADDR_CID_HOST ||
	    guest_cid == U32_MAX)
		return -EINVAL;

	/* 64-bit CIDs are not yet supported */
	if (guest_cid > U32_MAX)
		return -EINVAL;

	/* Refuse if CID is assigned to the guest->host transport (i.e. nested
	 * VM), to make the loopback work.
	 */
	if (vsock_find_cid(guest_cid))
		return -EADDRINUSE;

	/* Refuse if CID is already in use */
	mutex_lock(&vhost_vsock_mutex);
	other = vhost_vsock_get(guest_cid);
	if (other && other != vsock) {
		mutex_unlock(&vhost_vsock_mutex);
		return -EADDRINUSE;
	}

	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);

	vsock->guest_cid = guest_cid;
	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
	mutex_unlock(&vhost_vsock_mutex);

	return 0;
}

static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	if (features & ~VHOST_VSOCK_FEATURES)
		return -EOPNOTSUPP;

	mutex_lock(&vsock->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&vsock->dev)) {
		goto err;
	}

	if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) {
		if (vhost_init_device_iotlb(&vsock->dev, true))
			goto err;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];
		mutex_lock(&vq->mutex);
		vq->acked_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&vsock->dev.mutex);
	return 0;

err:
	mutex_unlock(&vsock->dev.mutex);
	return -EFAULT;
}

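/* Device ioctl entry point: the vsock-specific requests are handled here,
 * everything else falls through to the generic vhost ioctls.
 *
 * An illustrative (not exhaustive) userspace bring-up sequence might look
 * roughly like:
 *
 *	fd = open("/dev/vhost-vsock", O_RDWR);
 *	ioctl(fd, VHOST_SET_OWNER, NULL);
 *	ioctl(fd, VHOST_GET_FEATURES, &features);
 *	ioctl(fd, VHOST_SET_FEATURES, &features);
 *	// set up the memory table and VHOST_SET_VRING_* for each virtqueue
 *	ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid);	// u64
 *	ioctl(fd, VHOST_VSOCK_SET_RUNNING, &running);		// int, 1 = start
 */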
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
				  unsigned long arg)
{
	struct vhost_vsock *vsock = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 guest_cid;
	u64 features;
	int start;
	int r;

	switch (ioctl) {
	case VHOST_VSOCK_SET_GUEST_CID:
		if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
			return -EFAULT;
		return vhost_vsock_set_cid(vsock, guest_cid);
	case VHOST_VSOCK_SET_RUNNING:
		if (copy_from_user(&start, argp, sizeof(start)))
			return -EFAULT;
		if (start)
			return vhost_vsock_start(vsock);
		else
			return vhost_vsock_stop(vsock);
	case VHOST_GET_FEATURES:
		features = VHOST_VSOCK_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		return vhost_vsock_set_features(vsock, features);
	case VHOST_GET_BACKEND_FEATURES:
		features = VHOST_VSOCK_BACKEND_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_BACKEND_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		if (features & ~VHOST_VSOCK_BACKEND_FEATURES)
			return -EOPNOTSUPP;
		vhost_set_backend_features(&vsock->dev, features);
		return 0;
	default:
		mutex_lock(&vsock->dev.mutex);
		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
		else
			vhost_vsock_flush(vsock);
		mutex_unlock(&vsock->dev.mutex);
		return r;
	}
}

static ssize_t vhost_vsock_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
	struct file *file = iocb->ki_filp;
	struct vhost_vsock *vsock = file->private_data;
	struct vhost_dev *dev = &vsock->dev;
	int noblock = file->f_flags & O_NONBLOCK;

	return vhost_chr_read_iter(dev, to, noblock);
}

static ssize_t vhost_vsock_chr_write_iter(struct kiocb *iocb,
					  struct iov_iter *from)
{
	struct file *file = iocb->ki_filp;
	struct vhost_vsock *vsock = file->private_data;
	struct vhost_dev *dev = &vsock->dev;

	return vhost_chr_write_iter(dev, from);
}

static __poll_t vhost_vsock_chr_poll(struct file *file, poll_table *wait)
{
	struct vhost_vsock *vsock = file->private_data;
	struct vhost_dev *dev = &vsock->dev;

	return vhost_chr_poll(file, dev, wait);
}

static const struct file_operations vhost_vsock_fops = {
	.owner = THIS_MODULE,
	.open = vhost_vsock_dev_open,
	.release = vhost_vsock_dev_release,
	.llseek = noop_llseek,
	.unlocked_ioctl = vhost_vsock_dev_ioctl,
	.compat_ioctl = compat_ptr_ioctl,
	.read_iter = vhost_vsock_chr_read_iter,
	.write_iter = vhost_vsock_chr_write_iter,
	.poll = vhost_vsock_chr_poll,
};

static struct miscdevice vhost_vsock_misc = {
	.minor = VHOST_VSOCK_MINOR,
	.name = "vhost-vsock",
	.fops = &vhost_vsock_fops,
};

static int __init vhost_vsock_init(void)
{
	int ret;

	ret = vsock_core_register(&vhost_transport.transport,
				  VSOCK_TRANSPORT_F_H2G);
	if (ret < 0)
		return ret;
	return misc_register(&vhost_vsock_misc);
};

static void __exit vhost_vsock_exit(void)
{
	misc_deregister(&vhost_vsock_misc);
	vsock_core_unregister(&vhost_transport.transport);
};

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock ");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");