// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/file.h>
#include <linux/slab.h>
#include <linux/sched/clock.h>
#include <linux/sched/signal.h>
#include <linux/vmalloc.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>
#include <linux/if_tap.h>
#include <linux/if_vlan.h>
#include <linux/skb_array.h>
#include <linux/skbuff.h>

#include <net/sock.h>
#include <net/xdp.h>

#include "vhost.h"

static int experimental_zcopytx = 0;
module_param(experimental_zcopytx, int, 0444);
MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
		 " 1 - Enable; 0 - Disable");

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with small
 * pkts.
 */
#define VHOST_NET_PKT_WEIGHT 256

/* MAX number of TX used buffers for outstanding zerocopy */
#define VHOST_MAX_PEND 128
#define VHOST_GOODCOPY_LEN 256

/*
 * For transmit, used buffer len is unused; we override it to track buffer
 * status internally; used for zerocopy tx only.
 */
/* Lower device DMA failed */
#define VHOST_DMA_FAILED_LEN	((__force __virtio32)3)
/* Lower device DMA done */
#define VHOST_DMA_DONE_LEN	((__force __virtio32)2)
/* Lower device DMA in progress */
#define VHOST_DMA_IN_PROGRESS	((__force __virtio32)1)
/* Buffer unused */
#define VHOST_DMA_CLEAR_LEN	((__force __virtio32)0)

#define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN)

static const u64 vhost_net_features[VIRTIO_FEATURES_DWORDS] = {
	VHOST_FEATURES |
	(1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
	(1ULL << VIRTIO_NET_F_MRG_RXBUF) |
	(1ULL << VIRTIO_F_ACCESS_PLATFORM) |
	(1ULL << VIRTIO_F_RING_RESET) |
	(1ULL << VIRTIO_F_IN_ORDER),
	VIRTIO_BIT(VIRTIO_NET_F_GUEST_UDP_TUNNEL_GSO) |
	VIRTIO_BIT(VIRTIO_NET_F_HOST_UDP_TUNNEL_GSO),
};

enum {
	VHOST_NET_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
};

enum {
	VHOST_NET_VQ_RX = 0,
	VHOST_NET_VQ_TX = 1,
	VHOST_NET_VQ_MAX = 2,
};

struct vhost_net_ubuf_ref {
	/* refcount follows semantics similar to kref:
	 *    0: object is released
	 *    1: no outstanding ubufs
	 *  >1: outstanding ubufs
	 */
	atomic_t refcount;
	wait_queue_head_t wait;
	struct vhost_virtqueue *vq;
};

#define VHOST_NET_BATCH 64
struct vhost_net_buf {
	void **queue;
	int tail;
	int head;
};

struct vhost_net_virtqueue {
	struct vhost_virtqueue vq;
	size_t vhost_hlen;
	size_t sock_hlen;
	/* vhost zerocopy support fields below: */
	/* last used idx for outstanding DMA zerocopy buffers */
	int upend_idx;
	/* For TX, first used idx for DMA done zerocopy buffers
	 * For RX, number of batched heads
	 */
	int done_idx;
	/* Number of XDP frames batched */
	int batched_xdp;
	/* an array of userspace buffers info */
	struct ubuf_info_msgzc *ubuf_info;
	/* Reference counting for outstanding ubufs.
	 * Protected by vq mutex. Writers must also take device mutex. */
	struct vhost_net_ubuf_ref *ubufs;
	struct ptr_ring *rx_ring;
	struct vhost_net_buf rxq;
	/* Batched XDP buffs */
	struct xdp_buff *xdp;
};

struct vhost_net {
	struct vhost_dev dev;
	struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
	struct vhost_poll poll[VHOST_NET_VQ_MAX];
	/* Number of TX recently submitted.
	 * Protected by tx vq lock. */
	unsigned tx_packets;
	/* Number of times zerocopy TX recently failed.
	 * Protected by tx vq lock. */
	unsigned tx_zcopy_err;
	/* Flush in progress. Protected by tx vq lock. */
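	/* tx_flush is set by vhost_net_flush() while it waits for outstanding
	 * zerocopy DMAs to complete; vhost_net_tx_select_zcopy() does not
	 * start new zerocopy transmissions while it is set.
	 */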
	bool tx_flush;
	/* Private page frag cache */
	struct page_frag_cache pf_cache;
};

static unsigned vhost_net_zcopy_mask __read_mostly;

static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
{
	if (rxq->tail != rxq->head)
		return rxq->queue[rxq->head];
	else
		return NULL;
}

static int vhost_net_buf_get_size(struct vhost_net_buf *rxq)
{
	return rxq->tail - rxq->head;
}

static int vhost_net_buf_is_empty(struct vhost_net_buf *rxq)
{
	return rxq->tail == rxq->head;
}

static void *vhost_net_buf_consume(struct vhost_net_buf *rxq)
{
	void *ret = vhost_net_buf_get_ptr(rxq);
	++rxq->head;
	return ret;
}

static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
{
	struct vhost_net_buf *rxq = &nvq->rxq;

	rxq->head = 0;
	rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue,
					     VHOST_NET_BATCH);
	return rxq->tail;
}

static void vhost_net_buf_unproduce(struct vhost_net_virtqueue *nvq)
{
	struct vhost_net_buf *rxq = &nvq->rxq;

	if (nvq->rx_ring && !vhost_net_buf_is_empty(rxq)) {
		ptr_ring_unconsume(nvq->rx_ring, rxq->queue + rxq->head,
				   vhost_net_buf_get_size(rxq),
				   tun_ptr_free);
		rxq->head = rxq->tail = 0;
	}
}

static int vhost_net_buf_peek_len(void *ptr)
{
	if (tun_is_xdp_frame(ptr)) {
		struct xdp_frame *xdpf = tun_ptr_to_xdp(ptr);

		return xdpf->len;
	}

	return __skb_array_len_with_tag(ptr);
}

static int vhost_net_buf_peek(struct vhost_net_virtqueue *nvq)
{
	struct vhost_net_buf *rxq = &nvq->rxq;

	if (!vhost_net_buf_is_empty(rxq))
		goto out;

	if (!vhost_net_buf_produce(nvq))
		return 0;

out:
	return vhost_net_buf_peek_len(vhost_net_buf_get_ptr(rxq));
}

static void vhost_net_buf_init(struct vhost_net_buf *rxq)
{
	rxq->head = rxq->tail = 0;
}

static void vhost_net_enable_zcopy(int vq)
{
	vhost_net_zcopy_mask |= 0x1 << vq;
}

static struct vhost_net_ubuf_ref *
vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
{
	struct vhost_net_ubuf_ref *ubufs;
	/* No zero copy backend? Nothing to count. */
	if (!zcopy)
		return NULL;
	ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
	if (!ubufs)
		return ERR_PTR(-ENOMEM);
	atomic_set(&ubufs->refcount, 1);
	init_waitqueue_head(&ubufs->wait);
	ubufs->vq = vq;
	return ubufs;
}

static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
{
	int r = atomic_sub_return(1, &ubufs->refcount);
	if (unlikely(!r))
		wake_up(&ubufs->wait);
	return r;
}

static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
{
	vhost_net_ubuf_put(ubufs);
	wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
}

static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
{
	vhost_net_ubuf_put_and_wait(ubufs);
	kfree(ubufs);
}

static void vhost_net_clear_ubuf_info(struct vhost_net *n)
{
	int i;

	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
		kfree(n->vqs[i].ubuf_info);
		n->vqs[i].ubuf_info = NULL;
	}
}

static int vhost_net_set_ubuf_info(struct vhost_net *n)
{
	bool zcopy;
	int i;

	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
		zcopy = vhost_net_zcopy_mask & (0x1 << i);
		if (!zcopy)
			continue;
		n->vqs[i].ubuf_info =
			kmalloc_array(UIO_MAXIOV,
				      sizeof(*n->vqs[i].ubuf_info),
				      GFP_KERNEL);
		if (!n->vqs[i].ubuf_info)
			goto err;
	}
	return 0;

err:
	vhost_net_clear_ubuf_info(n);
	return -ENOMEM;
}

static void vhost_net_vq_reset(struct vhost_net *n)
{
	int i;

	vhost_net_clear_ubuf_info(n);

	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
		n->vqs[i].done_idx = 0;
		n->vqs[i].upend_idx = 0;
		n->vqs[i].ubufs = NULL;
		n->vqs[i].vhost_hlen = 0;
		n->vqs[i].sock_hlen = 0;
		vhost_net_buf_init(&n->vqs[i].rxq);
	}

}

static void vhost_net_tx_packet(struct vhost_net *net)
{
	++net->tx_packets;
	if (net->tx_packets < 1024)
		return;
	net->tx_packets = 0;
	net->tx_zcopy_err = 0;
}

static void vhost_net_tx_err(struct vhost_net *net)
{
	++net->tx_zcopy_err;
}

static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
{
	/* TX flush waits for outstanding DMAs to be done.
	 * Don't start new DMAs.
	 */
	return !net->tx_flush &&
		net->tx_packets / 64 >= net->tx_zcopy_err;
}

static bool vhost_sock_zcopy(struct socket *sock)
{
	return unlikely(experimental_zcopytx) &&
		sock_flag(sock->sk, SOCK_ZEROCOPY);
}

static bool vhost_sock_xdp(struct socket *sock)
{
	return sock_flag(sock->sk, SOCK_XDP);
}

/* In case of DMA done not in order in lower device driver for some reason.
 * upend_idx is used to track end of used idx, done_idx is used to track head
 * of used idx. Once lower device DMA done contiguously, we will signal KVM
 * guest used idx.
 */
static void vhost_zerocopy_signal_used(struct vhost_net *net,
				       struct vhost_virtqueue *vq)
{
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	int i, add;
	int j = 0;

	for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
		if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
			vhost_net_tx_err(net);
		if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
			vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
			++j;
		} else
			break;
	}
	while (j) {
		add = min(UIO_MAXIOV - nvq->done_idx, j);
		vhost_add_used_and_signal_n(vq->dev, vq,
					    &vq->heads[nvq->done_idx],
					    NULL, add);
		nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
		j -= add;
	}
}

static void vhost_zerocopy_complete(struct sk_buff *skb,
				    struct ubuf_info *ubuf_base, bool success)
{
	struct ubuf_info_msgzc *ubuf = uarg_to_msgzc(ubuf_base);
	struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
	struct vhost_virtqueue *vq = ubufs->vq;
	int cnt;

	rcu_read_lock_bh();

	/* set len to mark this desc buffers done DMA */
	vq->heads[ubuf->desc].len = success ?
		VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
	cnt = vhost_net_ubuf_put(ubufs);

	/*
	 * Trigger polling thread if guest stopped submitting new buffers:
	 * in this case, the refcount after decrement will eventually reach 1.
	 * We also trigger polling periodically after each 16 packets
	 * (the value 16 here is more or less arbitrary, it's tuned to trigger
	 * less than 10% of times).
	 */
	if (cnt <= 1 || !(cnt % 16))
		vhost_poll_queue(&vq->poll);

	rcu_read_unlock_bh();
}

static const struct ubuf_info_ops vhost_ubuf_ops = {
	.complete = vhost_zerocopy_complete,
};

static inline unsigned long busy_clock(void)
{
	return local_clock() >> 10;
}

static bool vhost_can_busy_poll(unsigned long endtime)
{
	return likely(!need_resched() && !time_after(busy_clock(), endtime) &&
		      !signal_pending(current));
}

static void vhost_net_disable_vq(struct vhost_net *n,
				 struct vhost_virtqueue *vq)
{
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
	if (!vhost_vq_get_backend(vq))
		return;
	vhost_poll_stop(poll);
}

static int vhost_net_enable_vq(struct vhost_net *n,
			       struct vhost_virtqueue *vq)
{
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
	struct socket *sock;

	sock = vhost_vq_get_backend(vq);
	if (!sock)
		return 0;

	return vhost_poll_start(poll, sock->file);
}

static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq,
				  unsigned int count)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	struct vhost_dev *dev = vq->dev;

	if (!nvq->done_idx)
		return;

	vhost_add_used_and_signal_n(dev, vq, vq->heads,
				    vq->nheads, count);
	nvq->done_idx = 0;
}

static void vhost_tx_batch(struct vhost_net *net,
			   struct vhost_net_virtqueue *nvq,
			   struct socket *sock,
			   struct msghdr *msghdr)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
	struct tun_msg_ctl ctl = {
		.type = TUN_MSG_PTR,
		.num = nvq->batched_xdp,
		.ptr = nvq->xdp,
	};
	int i, err;

	if (in_order) {
		vq->heads[0].len = 0;
		vq->nheads[0] = nvq->done_idx;
	}

	if (nvq->batched_xdp == 0)
		goto signal_used;

	msghdr->msg_control = &ctl;
	msghdr->msg_controllen = sizeof(ctl);
	err = sock->ops->sendmsg(sock, msghdr, 0);
	if (unlikely(err < 0)) {
		vq_err(&nvq->vq, "Fail to batch sending packets\n");

		/* free pages owned by XDP; since this is an unlikely error path,
		 * keep it simple and avoid more complex bulk update for the
		 * used pages
		 */
		for (i = 0; i < nvq->batched_xdp; ++i)
			put_page(virt_to_head_page(nvq->xdp[i].data));
		nvq->batched_xdp = 0;
		nvq->done_idx = 0;
		return;
	}

signal_used:
	vhost_net_signal_used(nvq, in_order ? 1 : nvq->done_idx);
	nvq->batched_xdp = 0;
}

static int sock_has_rx_data(struct socket *sock)
{
	if (unlikely(!sock))
		return 0;

	if (sock->ops->peek_len)
		return sock->ops->peek_len(sock);

	return skb_queue_empty(&sock->sk->sk_receive_queue);
}

static void vhost_net_busy_poll_try_queue(struct vhost_net *net,
					  struct vhost_virtqueue *vq)
{
	if (!vhost_vq_avail_empty(&net->dev, vq)) {
		vhost_poll_queue(&vq->poll);
	} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
		vhost_disable_notify(&net->dev, vq);
		vhost_poll_queue(&vq->poll);
	}
}

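/* Busy poll the paired virtqueue: the TX handler watches the RX socket and
 * virtqueue (and vice versa) for the configured busyloop timeout, so that
 * neither direction has to wait for a guest kick while the other one is
 * being serviced.
 */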
static void vhost_net_busy_poll(struct vhost_net *net,
				struct vhost_virtqueue *rvq,
				struct vhost_virtqueue *tvq,
				bool *busyloop_intr,
				bool poll_rx)
{
	unsigned long busyloop_timeout;
	unsigned long endtime;
	struct socket *sock;
	struct vhost_virtqueue *vq = poll_rx ? tvq : rvq;

	/* Try to hold the vq mutex of the paired virtqueue. We can't
	 * use mutex_lock() here since we could not guarantee a
	 * consistent lock ordering.
	 */
	if (!mutex_trylock(&vq->mutex))
		return;

	vhost_disable_notify(&net->dev, vq);
	sock = vhost_vq_get_backend(rvq);

	busyloop_timeout = poll_rx ? rvq->busyloop_timeout :
				     tvq->busyloop_timeout;

	preempt_disable();
	endtime = busy_clock() + busyloop_timeout;

	while (vhost_can_busy_poll(endtime)) {
		if (vhost_vq_has_work(vq)) {
			*busyloop_intr = true;
			break;
		}

		if ((sock_has_rx_data(sock) &&
		     !vhost_vq_avail_empty(&net->dev, rvq)) ||
		    !vhost_vq_avail_empty(&net->dev, tvq))
			break;

		cpu_relax();
	}

	preempt_enable();

	if (poll_rx || sock_has_rx_data(sock))
		vhost_net_busy_poll_try_queue(net, vq);
	else if (!poll_rx) /* On tx here, sock has no rx data. */
		vhost_enable_notify(&net->dev, rvq);

	mutex_unlock(&vq->mutex);
}

static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
				    struct vhost_net_virtqueue *tnvq,
				    unsigned int *out_num, unsigned int *in_num,
				    struct msghdr *msghdr, bool *busyloop_intr)
{
	struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
	struct vhost_virtqueue *rvq = &rnvq->vq;
	struct vhost_virtqueue *tvq = &tnvq->vq;

	int r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
				  out_num, in_num, NULL, NULL);

	if (r == tvq->num && tvq->busyloop_timeout) {
		/* Flush batched packets first */
		if (!vhost_sock_zcopy(vhost_vq_get_backend(tvq)))
			vhost_tx_batch(net, tnvq,
				       vhost_vq_get_backend(tvq),
				       msghdr);

		vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, false);

		r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
				      out_num, in_num, NULL, NULL);
	}

	return r;
}

static bool vhost_exceeds_maxpend(struct vhost_net *net)
{
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
	struct vhost_virtqueue *vq = &nvq->vq;

	return (nvq->upend_idx + UIO_MAXIOV - nvq->done_idx) % UIO_MAXIOV >
	       min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2);
}

static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter,
			    size_t hdr_size, int out)
{
	/* Skip header. TODO: support TSO. */
	size_t len = iov_length(vq->iov, out);

	iov_iter_init(iter, ITER_SOURCE, vq->iov, out, len);
	iov_iter_advance(iter, hdr_size);

	return iov_iter_count(iter);
}

static int get_tx_bufs(struct vhost_net *net,
		       struct vhost_net_virtqueue *nvq,
		       struct msghdr *msg,
		       unsigned int *out, unsigned int *in,
		       size_t *len, bool *busyloop_intr)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	int ret;

	ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr);

	if (ret < 0 || ret == vq->num)
		return ret;

	if (*in) {
		vq_err(vq, "Unexpected descriptor format for TX: out %d, in %d\n",
		       *out, *in);
		return -EFAULT;
	}

	/* Sanity check */
	*len = init_iov_iter(vq, &msg->msg_iter, nvq->vhost_hlen, *out);
	if (*len == 0) {
		vq_err(vq, "Unexpected header len for TX: %zd expected %zd\n",
		       *len, nvq->vhost_hlen);
		return -EFAULT;
	}

	return ret;
}

static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len)
{
	return total_len < VHOST_NET_WEIGHT &&
	       !vhost_vq_avail_empty(vq->dev, vq);
}

#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)

static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
			       struct iov_iter *from)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	struct vhost_net *net = container_of(vq->dev, struct vhost_net,
					     dev);
	struct socket *sock = vhost_vq_get_backend(vq);
	struct virtio_net_hdr *gso;
	struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
	size_t len = iov_iter_count(from);
	int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0;
	int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
	int pad = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + nvq->sock_hlen);
	int sock_hlen = nvq->sock_hlen;
	void *buf;
	int copied;
	int ret;

	if (unlikely(len < nvq->sock_hlen))
		return -EFAULT;

	if (SKB_DATA_ALIGN(len + pad) +
	    SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) > PAGE_SIZE)
		return -ENOSPC;

	buflen += SKB_DATA_ALIGN(len + pad);
	buf = page_frag_alloc_align(&net->pf_cache, buflen, GFP_KERNEL,
				    SMP_CACHE_BYTES);
	if (unlikely(!buf))
		return -ENOMEM;

	copied = copy_from_iter(buf + pad - sock_hlen, len, from);
	if (copied != len) {
		ret = -EFAULT;
		goto err;
	}

	gso = buf + pad - sock_hlen;

	if (!sock_hlen)
		memset(buf, 0, pad);

	if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
	    vhost16_to_cpu(vq, gso->csum_start) +
	    vhost16_to_cpu(vq, gso->csum_offset) + 2 >
	    vhost16_to_cpu(vq, gso->hdr_len)) {
		gso->hdr_len = cpu_to_vhost16(vq,
			       vhost16_to_cpu(vq, gso->csum_start) +
			       vhost16_to_cpu(vq, gso->csum_offset) + 2);

		if (vhost16_to_cpu(vq, gso->hdr_len) > len) {
			ret = -EINVAL;
			goto err;
		}
	}

	/* pad contains sock_hlen */
	memcpy(buf, buf + pad - sock_hlen, sock_hlen);

	xdp_init_buff(xdp, buflen, NULL);
	xdp_prepare_buff(xdp, buf, pad, len - sock_hlen, true);

	++nvq->batched_xdp;

	return 0;

err:
	page_frag_free(buf);
	return ret;
}

static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
{
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
	struct vhost_virtqueue *vq = &nvq->vq;
	unsigned out, in;
	int head;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_flags = MSG_DONTWAIT,
	};
	size_t len, total_len = 0;
	int err;
	int sent_pkts = 0;
	bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
	bool busyloop_intr;
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);

	do {
		busyloop_intr = false;
		if (nvq->done_idx == VHOST_NET_BATCH)
			vhost_tx_batch(net, nvq, sock, &msg);

		head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
				   &busyloop_intr);
		/* On error, stop handling until the next kick. */
		if (unlikely(head < 0))
			break;
		/* Nothing new?  Wait for eventfd to tell us they refilled. */
		if (head == vq->num) {
			/* Kicks are disabled at this point, break loop and
			 * process any remaining batched packets. Queue will
			 * be re-enabled afterwards.
			 */
			break;
		}

		total_len += len;

		/* For simplicity, TX batching is only enabled if
		 * sndbuf is unlimited.
		 */
		if (sock_can_batch) {
			err = vhost_net_build_xdp(nvq, &msg.msg_iter);
			if (!err) {
				goto done;
			} else if (unlikely(err != -ENOSPC)) {
				vhost_tx_batch(net, nvq, sock, &msg);
				vhost_discard_vq_desc(vq, 1);
				vhost_net_enable_vq(net, vq);
				break;
			}

			if (nvq->batched_xdp) {
				/* We can't build XDP buff, go for single
				 * packet path but let's flush batched
				 * packets.
				 */
				vhost_tx_batch(net, nvq, sock, &msg);
			}
			msg.msg_control = NULL;
		} else {
			if (tx_can_batch(vq, total_len))
				msg.msg_flags |= MSG_MORE;
			else
				msg.msg_flags &= ~MSG_MORE;
		}

		err = sock->ops->sendmsg(sock, &msg, len);
		if (unlikely(err < 0)) {
			if (err == -EAGAIN || err == -ENOMEM || err == -ENOBUFS) {
				vhost_discard_vq_desc(vq, 1);
				vhost_net_enable_vq(net, vq);
				break;
			}
			pr_debug("Fail to send packet: err %d", err);
		} else if (unlikely(err != len))
			pr_debug("Truncated TX packet: len %d != %zd\n",
				 err, len);
done:
		if (in_order) {
			vq->heads[0].id = cpu_to_vhost32(vq, head);
		} else {
			vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
			vq->heads[nvq->done_idx].len = 0;
		}
		++nvq->done_idx;
	} while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));

	/* Kicks are still disabled, dispatch any remaining batched msgs. */
	vhost_tx_batch(net, nvq, sock, &msg);

	if (unlikely(busyloop_intr))
		/* If interrupted while doing busy polling, requeue the
		 * handler to be fair handle_rx as well as other tasks
		 * waiting on cpu.
		 */
		vhost_poll_queue(&vq->poll);
	else
		/* All of our work has been completed; however, before
		 * leaving the TX handler, do one last check for work,
		 * and requeue handler if necessary. If there is no work,
		 * queue will be reenabled.
		 */
		vhost_net_busy_poll_try_queue(net, vq);
}

static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
{
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
	struct vhost_virtqueue *vq = &nvq->vq;
	unsigned out, in;
	int head;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_flags = MSG_DONTWAIT,
	};
	struct tun_msg_ctl ctl;
	size_t len, total_len = 0;
	int err;
	struct vhost_net_ubuf_ref *ubufs;
	struct ubuf_info_msgzc *ubuf;
	bool zcopy_used;
	int sent_pkts = 0;

	do {
		bool busyloop_intr;

		/* Release DMAs done buffers first */
		vhost_zerocopy_signal_used(net, vq);

		busyloop_intr = false;
		head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
				   &busyloop_intr);
		/* On error, stop handling until the next kick. */
		if (unlikely(head < 0))
			break;
		/* Nothing new?  Wait for eventfd to tell us they refilled. */
		if (head == vq->num) {
			if (unlikely(busyloop_intr)) {
				vhost_poll_queue(&vq->poll);
			} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				vhost_disable_notify(&net->dev, vq);
				continue;
			}
			break;
		}

		zcopy_used = len >= VHOST_GOODCOPY_LEN
			     && !vhost_exceeds_maxpend(net)
			     && vhost_net_tx_select_zcopy(net);

		/* use msg_control to pass vhost zerocopy ubuf info to skb */
		if (zcopy_used) {
			ubuf = nvq->ubuf_info + nvq->upend_idx;
			vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
			vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
			ubuf->ctx = nvq->ubufs;
			ubuf->desc = nvq->upend_idx;
			ubuf->ubuf.ops = &vhost_ubuf_ops;
			ubuf->ubuf.flags = SKBFL_ZEROCOPY_FRAG;
			refcount_set(&ubuf->ubuf.refcnt, 1);
			msg.msg_control = &ctl;
			ctl.type = TUN_MSG_UBUF;
			ctl.ptr = &ubuf->ubuf;
			msg.msg_controllen = sizeof(ctl);
			ubufs = nvq->ubufs;
			atomic_inc(&ubufs->refcount);
			nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
		} else {
			msg.msg_control = NULL;
			ubufs = NULL;
		}
		total_len += len;
		if (tx_can_batch(vq, total_len) &&
		    likely(!vhost_exceeds_maxpend(net))) {
			msg.msg_flags |= MSG_MORE;
		} else {
			msg.msg_flags &= ~MSG_MORE;
		}

		err = sock->ops->sendmsg(sock, &msg, len);
		if (unlikely(err < 0)) {
			bool retry = err == -EAGAIN || err == -ENOMEM || err == -ENOBUFS;

			if (zcopy_used) {
				if (vq->heads[ubuf->desc].len == VHOST_DMA_IN_PROGRESS)
					vhost_net_ubuf_put(ubufs);
				if (retry)
					nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
						% UIO_MAXIOV;
				else
					vq->heads[ubuf->desc].len = VHOST_DMA_DONE_LEN;
			}
			if (retry) {
				vhost_discard_vq_desc(vq, 1);
				vhost_net_enable_vq(net, vq);
				break;
			}
			pr_debug("Fail to send packet: err %d", err);
		} else if (unlikely(err != len))
			pr_debug("Truncated TX packet: len %d != %zd\n",
				 err, len);
		if (!zcopy_used)
			vhost_add_used_and_signal(&net->dev, vq, head, 0);
		else
			vhost_zerocopy_signal_used(net, vq);
		vhost_net_tx_packet(net);
	} while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
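/* handle_tx() dispatches to the zerocopy or the copy path depending on
 * whether zerocopy TX is enabled for the backend socket, see
 * vhost_sock_zcopy().
 */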
static void handle_tx(struct vhost_net *net)
{
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
	struct vhost_virtqueue *vq = &nvq->vq;
	struct socket *sock;

	mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_TX);
	sock = vhost_vq_get_backend(vq);
	if (!sock)
		goto out;

	if (!vq_meta_prefetch(vq))
		goto out;

	vhost_disable_notify(&net->dev, vq);
	vhost_net_disable_vq(net, vq);

	if (vhost_sock_zcopy(sock))
		handle_tx_zerocopy(net, sock);
	else
		handle_tx_copy(net, sock);

out:
	mutex_unlock(&vq->mutex);
}

static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk)
{
	struct sk_buff *head;
	int len = 0;
	unsigned long flags;

	if (rvq->rx_ring)
		return vhost_net_buf_peek(rvq);

	spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
	head = skb_peek(&sk->sk_receive_queue);
	if (likely(head)) {
		len = head->len;
		if (skb_vlan_tag_present(head))
			len += VLAN_HLEN;
	}

	spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
	return len;
}

static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
				      bool *busyloop_intr, unsigned int count)
{
	struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
	struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
	struct vhost_virtqueue *rvq = &rnvq->vq;
	struct vhost_virtqueue *tvq = &tnvq->vq;
	int len = peek_head_len(rnvq, sk);

	if (!len && rvq->busyloop_timeout) {
		/* Flush batched heads first */
		vhost_net_signal_used(rnvq, count);
		/* Both tx vq and rx socket were polled here */
		vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, true);

		len = peek_head_len(rnvq, sk);
	}

	return len;
}

/* This is a multi-buffer version of vhost_get_vq_desc, that works if
 *	vq has read descriptors only.
 * @nvq		- the relevant vhost_net virtqueue
 * @datalen	- data length we'll be reading
 * @iovcount	- returned count of io vectors we fill
 * @log		- vhost log
 * @log_num	- log offset
 * @quota	- headcount quota, 1 for big buffer
 *	returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_net_virtqueue *nvq,
		       struct vring_used_elem *heads,
		       u16 *nheads,
		       int datalen,
		       unsigned *iovcount,
		       struct vhost_log *log,
		       unsigned *log_num,
		       unsigned int quota)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
	unsigned int out, in;
	int seg = 0;
	int headcount = 0;
	unsigned d;
	int r, nlogs = 0;
	/* len is always initialized before use since we are always called with
	 * datalen > 0.
	 */
	u32 len;

	while (datalen > 0 && headcount < quota) {
		if (unlikely(seg >= UIO_MAXIOV)) {
			r = -ENOBUFS;
			goto err;
		}
		r = vhost_get_vq_desc(vq, vq->iov + seg,
				      ARRAY_SIZE(vq->iov) - seg, &out,
				      &in, log, log_num);
		if (unlikely(r < 0))
			goto err;

		d = r;
		if (d == vq->num) {
			r = 0;
			goto err;
		}
		if (unlikely(out || in <= 0)) {
			vq_err(vq, "unexpected descriptor format for RX: "
				"out %d, in %d\n", out, in);
			r = -EINVAL;
			goto err;
		}
		if (unlikely(log)) {
			nlogs += *log_num;
			log += *log_num;
		}
		len = iov_length(vq->iov + seg, in);
		if (!in_order) {
			heads[headcount].id = cpu_to_vhost32(vq, d);
			heads[headcount].len = cpu_to_vhost32(vq, len);
		}
		++headcount;
		datalen -= len;
		seg += in;
	}

	*iovcount = seg;
	if (unlikely(log))
		*log_num = nlogs;

	/* Detect overrun */
	if (unlikely(datalen > 0)) {
		r = UIO_MAXIOV + 1;
		goto err;
	}

	if (!in_order)
		heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
	else {
		heads[0].len = cpu_to_vhost32(vq, len + datalen);
		heads[0].id = cpu_to_vhost32(vq, d);
		nheads[0] = headcount;
	}

	return headcount;
err:
	vhost_discard_vq_desc(vq, headcount);
	return r;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
{
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
	struct vhost_virtqueue *vq = &nvq->vq;
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
	unsigned int count = 0;
	unsigned in, log;
	struct vhost_log *vq_log;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
		.msg_controllen = 0,
		.msg_flags = MSG_DONTWAIT,
	};
	struct virtio_net_hdr hdr = {
		.flags = 0,
		.gso_type = VIRTIO_NET_HDR_GSO_NONE
	};
	size_t total_len = 0;
	int err, mergeable;
	s16 headcount;
	size_t vhost_hlen, sock_hlen;
	size_t vhost_len, sock_len;
	bool busyloop_intr = false;
	bool set_num_buffers;
	struct socket *sock;
	struct iov_iter fixup;
	__virtio16 num_buffers;
	int recv_pkts = 0;

	mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_RX);
	sock = vhost_vq_get_backend(vq);
	if (!sock)
		goto out;

	if (!vq_meta_prefetch(vq))
		goto out;

	vhost_disable_notify(&net->dev, vq);
	vhost_net_disable_vq(net, vq);

	vhost_hlen = nvq->vhost_hlen;
	sock_hlen = nvq->sock_hlen;

	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
		vq->log : NULL;
	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
	set_num_buffers = mergeable ||
			  vhost_has_feature(vq, VIRTIO_F_VERSION_1);

	do {
		sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
						      &busyloop_intr, count);
		if (!sock_len)
			break;
		sock_len += sock_hlen;
		vhost_len = sock_len + vhost_hlen;
		headcount = get_rx_bufs(nvq, vq->heads + count,
					vq->nheads + count,
					vhost_len, &in, vq_log, &log,
					likely(mergeable) ? UIO_MAXIOV : 1);
		/* On error, stop handling until the next kick. */
		if (unlikely(headcount < 0))
			goto out;
		/* OK, now we need to know about added descriptors. */
		if (!headcount) {
			if (unlikely(busyloop_intr)) {
				vhost_poll_queue(&vq->poll);
			} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				/* They have slipped one in as we were
				 * doing that: check again. */
				vhost_disable_notify(&net->dev, vq);
				continue;
			}
			/* Nothing new?  Wait for eventfd to tell us
			 * they refilled. */
			goto out;
		}
		busyloop_intr = false;
		if (nvq->rx_ring)
			msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
		/* On overrun, truncate and discard */
		if (unlikely(headcount > UIO_MAXIOV)) {
			iov_iter_init(&msg.msg_iter, ITER_DEST, vq->iov, 1, 1);
			err = sock->ops->recvmsg(sock, &msg,
						 1, MSG_DONTWAIT | MSG_TRUNC);
			pr_debug("Discarded rx packet: len %zd\n", sock_len);
			continue;
		}
		/* We don't need to be notified again. */
		iov_iter_init(&msg.msg_iter, ITER_DEST, vq->iov, in, vhost_len);
		fixup = msg.msg_iter;
		if (unlikely((vhost_hlen))) {
			/* We will supply the header ourselves
			 * TODO: support TSO.
			 */
			iov_iter_advance(&msg.msg_iter, vhost_hlen);
		}
		err = sock->ops->recvmsg(sock, &msg,
					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
		/* Userspace might have consumed the packet meanwhile:
		 * it's not supposed to do this usually, but might be hard
		 * to prevent. Discard data we got (if any) and keep going. */
		if (unlikely(err != sock_len)) {
			pr_debug("Discarded rx packet: len %d, expected %zd\n",
				 err, sock_len);
			vhost_discard_vq_desc(vq, headcount);
			continue;
		}
		/* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
		if (unlikely(vhost_hlen)) {
			if (copy_to_iter(&hdr, sizeof(hdr),
					 &fixup) != sizeof(hdr)) {
				vq_err(vq, "Unable to write vnet_hdr "
				       "at addr %p\n", vq->iov->iov_base);
				goto out;
			}
		} else {
			/* Header came from socket; we'll need to patch
			 * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
			 */
			iov_iter_advance(&fixup, sizeof(hdr));
		}
		/* TODO: Should check and handle checksum. */

		num_buffers = cpu_to_vhost16(vq, headcount);
		if (likely(set_num_buffers) &&
		    copy_to_iter(&num_buffers, sizeof num_buffers,
				 &fixup) != sizeof num_buffers) {
			vq_err(vq, "Failed num_buffers write");
			vhost_discard_vq_desc(vq, headcount);
			goto out;
		}
		nvq->done_idx += headcount;
		count += in_order ? 1 : headcount;
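		/* With VIRTIO_F_IN_ORDER each batch of buffers is reported as
		 * a single used entry, so "count" tracks used entries while
		 * done_idx keeps counting individual buffers.
		 */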
		if (nvq->done_idx > VHOST_NET_BATCH) {
			vhost_net_signal_used(nvq, count);
			count = 0;
		}
		if (unlikely(vq_log))
			vhost_log_write(vq, vq_log, log, vhost_len,
					vq->iov, in);
		total_len += vhost_len;
	} while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));

	if (unlikely(busyloop_intr))
		vhost_poll_queue(&vq->poll);
	else if (!sock_len)
		vhost_net_enable_vq(net, vq);
out:
	vhost_net_signal_used(nvq, count);
	mutex_unlock(&vq->mutex);
}

static void handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

	handle_tx(net);
}

static void handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

	handle_rx(net);
}

static void handle_tx_net(struct vhost_work *work)
{
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_TX].work);
	handle_tx(net);
}

static void handle_rx_net(struct vhost_work *work)
{
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_RX].work);
	handle_rx(net);
}

static int vhost_net_open(struct inode *inode, struct file *f)
{
	struct vhost_net *n;
	struct vhost_dev *dev;
	struct vhost_virtqueue **vqs;
	void **queue;
	struct xdp_buff *xdp;
	int i;

	n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
	if (!n)
		return -ENOMEM;
	vqs = kmalloc_array(VHOST_NET_VQ_MAX, sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		kvfree(n);
		return -ENOMEM;
	}

	queue = kmalloc_array(VHOST_NET_BATCH, sizeof(void *),
			      GFP_KERNEL);
	if (!queue) {
		kfree(vqs);
		kvfree(n);
		return -ENOMEM;
	}
	n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;

	xdp = kmalloc_array(VHOST_NET_BATCH, sizeof(*xdp), GFP_KERNEL);
	if (!xdp) {
		kfree(vqs);
		kvfree(n);
		kfree(queue);
		return -ENOMEM;
	}
	n->vqs[VHOST_NET_VQ_TX].xdp = xdp;

	dev = &n->dev;
	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
	n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
	n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
		n->vqs[i].ubufs = NULL;
		n->vqs[i].ubuf_info = NULL;
		n->vqs[i].upend_idx = 0;
		n->vqs[i].done_idx = 0;
		n->vqs[i].batched_xdp = 0;
		n->vqs[i].vhost_hlen = 0;
		n->vqs[i].sock_hlen = 0;
		n->vqs[i].rx_ring = NULL;
		vhost_net_buf_init(&n->vqs[i].rxq);
	}
	vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
		       UIO_MAXIOV + VHOST_NET_BATCH,
		       VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT, true,
		       NULL);

	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev,
			vqs[VHOST_NET_VQ_TX]);
	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev,
			vqs[VHOST_NET_VQ_RX]);

	f->private_data = n;
	page_frag_cache_init(&n->pf_cache);

	return 0;
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
					struct vhost_virtqueue *vq)
{
	struct socket *sock;
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);

	mutex_lock(&vq->mutex);
	sock = vhost_vq_get_backend(vq);
	vhost_net_disable_vq(n, vq);
	vhost_vq_set_backend(vq, NULL);
	vhost_net_buf_unproduce(nvq);
	nvq->rx_ring = NULL;
	mutex_unlock(&vq->mutex);
	return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
			   struct socket **rx_sock)
{
	*tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
}

static void vhost_net_flush(struct vhost_net *n)
{
	vhost_dev_flush(&n->dev);
	if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
		n->tx_flush = true;
		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
		/* Wait for all lower device DMAs done. */
		vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
		n->tx_flush = false;
		atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
	}
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
	struct vhost_net *n = f->private_data;
	struct socket *tx_sock;
	struct socket *rx_sock;

	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	vhost_dev_stop(&n->dev);
	vhost_dev_cleanup(&n->dev);
	vhost_net_vq_reset(n);
	if (tx_sock)
		sockfd_put(tx_sock);
	if (rx_sock)
		sockfd_put(rx_sock);
	/* Make sure no callbacks are outstanding */
	synchronize_rcu();
	/* We do an extra flush before freeing memory,
	 * since jobs can re-queue themselves. */
	vhost_net_flush(n);
	kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
	kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
	kfree(n->dev.vqs);
	page_frag_cache_drain(&n->pf_cache);
	kvfree(n);
	return 0;
}

static struct socket *get_raw_socket(int fd)
{
	int r;
	struct socket *sock = sockfd_lookup(fd, &r);

	if (!sock)
		return ERR_PTR(-ENOTSOCK);

	/* Parameter checking */
	if (sock->sk->sk_type != SOCK_RAW) {
		r = -ESOCKTNOSUPPORT;
		goto err;
	}

	if (sock->sk->sk_family != AF_PACKET) {
		r = -EPFNOSUPPORT;
		goto err;
	}
	return sock;
err:
	sockfd_put(sock);
	return ERR_PTR(r);
}

static struct ptr_ring *get_tap_ptr_ring(struct file *file)
{
	struct ptr_ring *ring;
	ring = tun_get_tx_ring(file);
	if (!IS_ERR(ring))
		goto out;
	ring = tap_get_ptr_ring(file);
	if (!IS_ERR(ring))
		goto out;
	ring = NULL;
out:
	return ring;
}

static struct socket *get_tap_socket(int fd)
{
	struct file *file = fget(fd);
	struct socket *sock;

	if (!file)
		return ERR_PTR(-EBADF);
	sock = tun_get_socket(file);
	if (!IS_ERR(sock))
		return sock;
	sock = tap_get_socket(file);
	if (IS_ERR(sock))
		fput(file);
	return sock;
}

static struct socket *get_socket(int fd)
{
	struct socket *sock;

	/* special case to disable backend */
	if (fd == -1)
		return NULL;
	sock = get_raw_socket(fd);
	if (!IS_ERR(sock))
		return sock;
	sock = get_tap_socket(fd);
	if (!IS_ERR(sock))
		return sock;
	return ERR_PTR(-ENOTSOCK);
}

static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
	struct socket *sock, *oldsock;
	struct vhost_virtqueue *vq;
	struct vhost_net_virtqueue *nvq;
	struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
	int r;

	mutex_lock(&n->dev.mutex);
	r = vhost_dev_check_owner(&n->dev);
	if (r)
		goto err;

	if (index >= VHOST_NET_VQ_MAX) {
		r = -ENOBUFS;
		goto err;
	}
	vq = &n->vqs[index].vq;
	nvq = &n->vqs[index];
	mutex_lock(&vq->mutex);

	if (fd == -1)
		vhost_clear_msg(&n->dev);

	/* Verify that ring has been setup correctly. */
	if (!vhost_vq_access_ok(vq)) {
		r = -EFAULT;
		goto err_vq;
	}
	sock = get_socket(fd);
	if (IS_ERR(sock)) {
		r = PTR_ERR(sock);
		goto err_vq;
	}

	/* start polling new socket */
	oldsock = vhost_vq_get_backend(vq);
	if (sock != oldsock) {
		ubufs = vhost_net_ubuf_alloc(vq,
					     sock && vhost_sock_zcopy(sock));
		if (IS_ERR(ubufs)) {
			r = PTR_ERR(ubufs);
			goto err_ubufs;
		}

		vhost_net_disable_vq(n, vq);
		vhost_vq_set_backend(vq, sock);
		vhost_net_buf_unproduce(nvq);
		r = vhost_vq_init_access(vq);
		if (r)
			goto err_used;
		r = vhost_net_enable_vq(n, vq);
		if (r)
			goto err_used;
		if (index == VHOST_NET_VQ_RX) {
			if (sock)
				nvq->rx_ring = get_tap_ptr_ring(sock->file);
			else
				nvq->rx_ring = NULL;
		}

		oldubufs = nvq->ubufs;
		nvq->ubufs = ubufs;

		n->tx_packets = 0;
		n->tx_zcopy_err = 0;
		n->tx_flush = false;
	}

	mutex_unlock(&vq->mutex);

	if (oldubufs) {
		vhost_net_ubuf_put_wait_and_free(oldubufs);
		mutex_lock(&vq->mutex);
		vhost_zerocopy_signal_used(n, vq);
		mutex_unlock(&vq->mutex);
	}

	if (oldsock) {
		vhost_dev_flush(&n->dev);
		sockfd_put(oldsock);
	}

	mutex_unlock(&n->dev.mutex);
	return 0;

err_used:
	vhost_vq_set_backend(vq, oldsock);
	vhost_net_enable_vq(n, vq);
	if (ubufs)
		vhost_net_ubuf_put_wait_and_free(ubufs);
err_ubufs:
	if (sock)
		sockfd_put(sock);
err_vq:
	mutex_unlock(&vq->mutex);
err:
	mutex_unlock(&n->dev.mutex);
	return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
	struct socket *tx_sock = NULL;
	struct socket *rx_sock = NULL;
	long err;
	struct vhost_iotlb *umem;

	mutex_lock(&n->dev.mutex);
	err = vhost_dev_check_owner(&n->dev);
	if (err)
		goto done;
	umem = vhost_dev_reset_owner_prepare();
	if (!umem) {
		err = -ENOMEM;
		goto done;
	}
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	vhost_dev_stop(&n->dev);
	vhost_dev_reset_owner(&n->dev, umem);
	vhost_net_vq_reset(n);
done:
	mutex_unlock(&n->dev.mutex);
	if (tx_sock)
		sockfd_put(tx_sock);
	if (rx_sock)
		sockfd_put(rx_sock);
	return err;
}

static int vhost_net_set_features(struct vhost_net *n, const u64 *features)
{
	size_t vhost_hlen, sock_hlen, hdr_len;
	int i;

	hdr_len = virtio_features_test_bit(features, VIRTIO_NET_F_MRG_RXBUF) ||
		  virtio_features_test_bit(features, VIRTIO_F_VERSION_1) ?
			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
			sizeof(struct virtio_net_hdr);

	if (virtio_features_test_bit(features,
				     VIRTIO_NET_F_HOST_UDP_TUNNEL_GSO) ||
	    virtio_features_test_bit(features,
				     VIRTIO_NET_F_GUEST_UDP_TUNNEL_GSO))
		hdr_len = sizeof(struct virtio_net_hdr_v1_hash_tunnel);

	if (virtio_features_test_bit(features, VHOST_NET_F_VIRTIO_NET_HDR)) {
		/* vhost provides vnet_hdr */
		vhost_hlen = hdr_len;
		sock_hlen = 0;
	} else {
		/* socket provides vnet_hdr */
		vhost_hlen = 0;
		sock_hlen = hdr_len;
	}
	mutex_lock(&n->dev.mutex);
	if (virtio_features_test_bit(features, VHOST_F_LOG_ALL) &&
	    !vhost_log_access_ok(&n->dev))
		goto out_unlock;

	if (virtio_features_test_bit(features, VIRTIO_F_ACCESS_PLATFORM)) {
		if (vhost_init_device_iotlb(&n->dev))
			goto out_unlock;
	}

	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
		mutex_lock(&n->vqs[i].vq.mutex);
		virtio_features_copy(n->vqs[i].vq.acked_features_array,
				     features);
		n->vqs[i].vhost_hlen = vhost_hlen;
		n->vqs[i].sock_hlen = sock_hlen;
		mutex_unlock(&n->vqs[i].vq.mutex);
	}
	mutex_unlock(&n->dev.mutex);
	return 0;

out_unlock:
	mutex_unlock(&n->dev.mutex);
	return -EFAULT;
}

static long vhost_net_set_owner(struct vhost_net *n)
{
	int r;

	mutex_lock(&n->dev.mutex);
	if (vhost_dev_has_owner(&n->dev)) {
		r = -EBUSY;
		goto out;
	}
	r = vhost_net_set_ubuf_info(n);
	if (r)
		goto out;
	r = vhost_dev_set_owner(&n->dev);
	if (r)
		vhost_net_clear_ubuf_info(n);
	vhost_net_flush(n);
out:
	mutex_unlock(&n->dev.mutex);
	return r;
}

static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
			    unsigned long arg)
{
	u64 all_features[VIRTIO_FEATURES_DWORDS];
	struct vhost_net *n = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 __user *featurep = argp;
	struct vhost_vring_file backend;
	u64 features, count, copied;
	int r, i;

	switch (ioctl) {
	case VHOST_NET_SET_BACKEND:
		if (copy_from_user(&backend, argp, sizeof backend))
			return -EFAULT;
		return vhost_net_set_backend(n, backend.index, backend.fd);
	case VHOST_GET_FEATURES:
		features = vhost_net_features[0];
		if (copy_to_user(featurep, &features, sizeof features))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, featurep, sizeof features))
			return -EFAULT;
		if (features & ~vhost_net_features[0])
			return -EOPNOTSUPP;

		virtio_features_from_u64(all_features, features);
		return vhost_net_set_features(n, all_features);
	case VHOST_GET_FEATURES_ARRAY:
		if (copy_from_user(&count, featurep, sizeof(count)))
			return -EFAULT;

		/* Copy the net features, up to the user-provided buffer size */
		argp += sizeof(u64);
		copied = min(count, VIRTIO_FEATURES_DWORDS);
		if (copy_to_user(argp, vhost_net_features,
				 copied * sizeof(u64)))
			return -EFAULT;

		/* Zero the trailing space provided by user-space, if any */
		if (clear_user(argp, size_mul(count - copied, sizeof(u64))))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES_ARRAY:
		if (copy_from_user(&count, featurep, sizeof(count)))
			return -EFAULT;

		virtio_features_zero(all_features);
		argp += sizeof(u64);
		copied = min(count, VIRTIO_FEATURES_DWORDS);
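		/* Copy in only as many feature dwords as the kernel knows
		 * about; any extra dwords supplied by user-space are checked
		 * below and must be zero.
		 */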
		if (copy_from_user(all_features, argp, copied * sizeof(u64)))
			return -EFAULT;

		/*
		 * Any feature specified by user-space above
		 * VIRTIO_FEATURES_MAX is not supported by definition.
		 */
		for (i = copied; i < count; ++i) {
			if (copy_from_user(&features, featurep + 1 + i,
					   sizeof(features)))
				return -EFAULT;
			if (features)
				return -EOPNOTSUPP;
		}

		for (i = 0; i < VIRTIO_FEATURES_DWORDS; i++)
			if (all_features[i] & ~vhost_net_features[i])
				return -EOPNOTSUPP;

		return vhost_net_set_features(n, all_features);
	case VHOST_GET_BACKEND_FEATURES:
		features = VHOST_NET_BACKEND_FEATURES;
		if (copy_to_user(featurep, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_BACKEND_FEATURES:
		if (copy_from_user(&features, featurep, sizeof(features)))
			return -EFAULT;
		if (features & ~VHOST_NET_BACKEND_FEATURES)
			return -EOPNOTSUPP;
		vhost_set_backend_features(&n->dev, features);
		return 0;
	case VHOST_RESET_OWNER:
		return vhost_net_reset_owner(n);
	case VHOST_SET_OWNER:
		return vhost_net_set_owner(n);
	default:
		mutex_lock(&n->dev.mutex);
		r = vhost_dev_ioctl(&n->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&n->dev, ioctl, argp);
		else
			vhost_net_flush(n);
		mutex_unlock(&n->dev.mutex);
		return r;
	}
}

static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
	struct file *file = iocb->ki_filp;
	struct vhost_net *n = file->private_data;
	struct vhost_dev *dev = &n->dev;
	int noblock = file->f_flags & O_NONBLOCK;

	return vhost_chr_read_iter(dev, to, noblock);
}

static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
					struct iov_iter *from)
{
	struct file *file = iocb->ki_filp;
	struct vhost_net *n = file->private_data;
	struct vhost_dev *dev = &n->dev;

	return vhost_chr_write_iter(dev, from);
}

static __poll_t vhost_net_chr_poll(struct file *file, poll_table *wait)
{
	struct vhost_net *n = file->private_data;
	struct vhost_dev *dev = &n->dev;

	return vhost_chr_poll(file, dev, wait);
}

static const struct file_operations vhost_net_fops = {
	.owner          = THIS_MODULE,
	.release        = vhost_net_release,
	.read_iter      = vhost_net_chr_read_iter,
	.write_iter     = vhost_net_chr_write_iter,
	.poll           = vhost_net_chr_poll,
	.unlocked_ioctl = vhost_net_ioctl,
	.compat_ioctl   = compat_ptr_ioctl,
	.open           = vhost_net_open,
	.llseek		= noop_llseek,
};

static struct miscdevice vhost_net_misc = {
	.minor = VHOST_NET_MINOR,
	.name = "vhost-net",
	.fops = &vhost_net_fops,
};

static int __init vhost_net_init(void)
{
	if (experimental_zcopytx)
		vhost_net_enable_zcopy(VHOST_NET_VQ_TX);
	return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);

static void __exit vhost_net_exit(void)
{
	misc_deregister(&vhost_net_misc);
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
MODULE_ALIAS("devname:vhost-net");