1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * common code for virtio vsock 4 * 5 * Copyright (C) 2013-2015 Red Hat, Inc. 6 * Author: Asias He <asias@redhat.com> 7 * Stefan Hajnoczi <stefanha@redhat.com> 8 */ 9 #include <linux/spinlock.h> 10 #include <linux/module.h> 11 #include <linux/sched/signal.h> 12 #include <linux/ctype.h> 13 #include <linux/list.h> 14 #include <linux/virtio_vsock.h> 15 #include <uapi/linux/vsockmon.h> 16 17 #include <net/sock.h> 18 #include <net/af_vsock.h> 19 20 #define CREATE_TRACE_POINTS 21 #include <trace/events/vsock_virtio_transport_common.h> 22 23 /* How long to wait for graceful shutdown of a connection */ 24 #define VSOCK_CLOSE_TIMEOUT (8 * HZ) 25 26 /* Threshold for detecting small packets to copy */ 27 #define GOOD_COPY_LEN 128 28 29 static const struct virtio_transport * 30 virtio_transport_get_ops(struct vsock_sock *vsk) 31 { 32 const struct vsock_transport *t = vsock_core_get_transport(vsk); 33 34 if (WARN_ON(!t)) 35 return NULL; 36 37 return container_of(t, struct virtio_transport, transport); 38 } 39 40 static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops, 41 struct virtio_vsock_pkt_info *info, 42 size_t pkt_len) 43 { 44 struct iov_iter *iov_iter; 45 46 if (!info->msg) 47 return false; 48 49 iov_iter = &info->msg->msg_iter; 50 51 if (iov_iter->iov_offset) 52 return false; 53 54 /* We can't send whole iov. */ 55 if (iov_iter->count > pkt_len) 56 return false; 57 58 /* Check that transport can send data in zerocopy mode. */ 59 t_ops = virtio_transport_get_ops(info->vsk); 60 61 if (t_ops->can_msgzerocopy) { 62 int pages_in_iov = iov_iter_npages(iov_iter, MAX_SKB_FRAGS); 63 int pages_to_send = min(pages_in_iov, MAX_SKB_FRAGS); 64 65 /* +1 is for packet header. */ 66 return t_ops->can_msgzerocopy(pages_to_send + 1); 67 } 68 69 return true; 70 } 71 72 static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk, 73 struct sk_buff *skb, 74 struct msghdr *msg, 75 bool zerocopy) 76 { 77 struct ubuf_info *uarg; 78 79 if (msg->msg_ubuf) { 80 uarg = msg->msg_ubuf; 81 net_zcopy_get(uarg); 82 } else { 83 struct iov_iter *iter = &msg->msg_iter; 84 struct ubuf_info_msgzc *uarg_zc; 85 86 uarg = msg_zerocopy_realloc(sk_vsock(vsk), 87 iter->count, 88 NULL); 89 if (!uarg) 90 return -1; 91 92 uarg_zc = uarg_to_msgzc(uarg); 93 uarg_zc->zerocopy = zerocopy ? 1 : 0; 94 } 95 96 skb_zcopy_init(skb, uarg); 97 98 return 0; 99 } 100 101 static int virtio_transport_fill_skb(struct sk_buff *skb, 102 struct virtio_vsock_pkt_info *info, 103 size_t len, 104 bool zcopy) 105 { 106 if (zcopy) 107 return __zerocopy_sg_from_iter(info->msg, NULL, skb, 108 &info->msg->msg_iter, 109 len); 110 111 return memcpy_from_msg(skb_put(skb, len), info->msg, len); 112 } 113 114 static void virtio_transport_init_hdr(struct sk_buff *skb, 115 struct virtio_vsock_pkt_info *info, 116 size_t payload_len, 117 u32 src_cid, 118 u32 src_port, 119 u32 dst_cid, 120 u32 dst_port) 121 { 122 struct virtio_vsock_hdr *hdr; 123 124 hdr = virtio_vsock_hdr(skb); 125 hdr->type = cpu_to_le16(info->type); 126 hdr->op = cpu_to_le16(info->op); 127 hdr->src_cid = cpu_to_le64(src_cid); 128 hdr->dst_cid = cpu_to_le64(dst_cid); 129 hdr->src_port = cpu_to_le32(src_port); 130 hdr->dst_port = cpu_to_le32(dst_port); 131 hdr->flags = cpu_to_le32(info->flags); 132 hdr->len = cpu_to_le32(payload_len); 133 } 134 135 static void virtio_transport_copy_nonlinear_skb(const struct sk_buff *skb, 136 void *dst, 137 size_t len) 138 { 139 struct iov_iter iov_iter = { 0 }; 140 struct kvec kvec; 141 size_t to_copy; 142 143 kvec.iov_base = dst; 144 kvec.iov_len = len; 145 146 iov_iter.iter_type = ITER_KVEC; 147 iov_iter.kvec = &kvec; 148 iov_iter.nr_segs = 1; 149 150 to_copy = min_t(size_t, len, skb->len); 151 152 skb_copy_datagram_iter(skb, VIRTIO_VSOCK_SKB_CB(skb)->offset, 153 &iov_iter, to_copy); 154 } 155 156 /* Packet capture */ 157 static struct sk_buff *virtio_transport_build_skb(void *opaque) 158 { 159 struct virtio_vsock_hdr *pkt_hdr; 160 struct sk_buff *pkt = opaque; 161 struct af_vsockmon_hdr *hdr; 162 struct sk_buff *skb; 163 size_t payload_len; 164 165 /* A packet could be split to fit the RX buffer, so we can retrieve 166 * the payload length from the header and the buffer pointer taking 167 * care of the offset in the original packet. 168 */ 169 pkt_hdr = virtio_vsock_hdr(pkt); 170 payload_len = pkt->len; 171 172 skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len, 173 GFP_ATOMIC); 174 if (!skb) 175 return NULL; 176 177 hdr = skb_put(skb, sizeof(*hdr)); 178 179 /* pkt->hdr is little-endian so no need to byteswap here */ 180 hdr->src_cid = pkt_hdr->src_cid; 181 hdr->src_port = pkt_hdr->src_port; 182 hdr->dst_cid = pkt_hdr->dst_cid; 183 hdr->dst_port = pkt_hdr->dst_port; 184 185 hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO); 186 hdr->len = cpu_to_le16(sizeof(*pkt_hdr)); 187 memset(hdr->reserved, 0, sizeof(hdr->reserved)); 188 189 switch (le16_to_cpu(pkt_hdr->op)) { 190 case VIRTIO_VSOCK_OP_REQUEST: 191 case VIRTIO_VSOCK_OP_RESPONSE: 192 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT); 193 break; 194 case VIRTIO_VSOCK_OP_RST: 195 case VIRTIO_VSOCK_OP_SHUTDOWN: 196 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT); 197 break; 198 case VIRTIO_VSOCK_OP_RW: 199 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD); 200 break; 201 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 202 case VIRTIO_VSOCK_OP_CREDIT_REQUEST: 203 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); 204 break; 205 default: 206 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN); 207 break; 208 } 209 210 skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr)); 211 212 if (payload_len) { 213 if (skb_is_nonlinear(pkt)) { 214 void *data = skb_put(skb, payload_len); 215 216 virtio_transport_copy_nonlinear_skb(pkt, data, payload_len); 217 } else { 218 skb_put_data(skb, pkt->data, payload_len); 219 } 220 } 221 222 return skb; 223 } 224 225 void virtio_transport_deliver_tap_pkt(struct sk_buff *skb) 226 { 227 if (virtio_vsock_skb_tap_delivered(skb)) 228 return; 229 230 vsock_deliver_tap(virtio_transport_build_skb, skb); 231 virtio_vsock_skb_set_tap_delivered(skb); 232 } 233 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); 234 235 static u16 virtio_transport_get_type(struct sock *sk) 236 { 237 if (sk->sk_type == SOCK_STREAM) 238 return VIRTIO_VSOCK_TYPE_STREAM; 239 else 240 return VIRTIO_VSOCK_TYPE_SEQPACKET; 241 } 242 243 /* Returns new sk_buff on success, otherwise returns NULL. */ 244 static struct sk_buff *virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info, 245 size_t payload_len, 246 bool zcopy, 247 u32 src_cid, 248 u32 src_port, 249 u32 dst_cid, 250 u32 dst_port) 251 { 252 struct vsock_sock *vsk; 253 struct sk_buff *skb; 254 size_t skb_len; 255 256 skb_len = VIRTIO_VSOCK_SKB_HEADROOM; 257 258 if (!zcopy) 259 skb_len += payload_len; 260 261 skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL); 262 if (!skb) 263 return NULL; 264 265 virtio_transport_init_hdr(skb, info, payload_len, src_cid, src_port, 266 dst_cid, dst_port); 267 268 vsk = info->vsk; 269 270 /* If 'vsk' != NULL then payload is always present, so we 271 * will never call '__zerocopy_sg_from_iter()' below without 272 * setting skb owner in 'skb_set_owner_w()'. The only case 273 * when 'vsk' == NULL is VIRTIO_VSOCK_OP_RST control message 274 * without payload. 275 */ 276 WARN_ON_ONCE(!(vsk && (info->msg && payload_len)) && zcopy); 277 278 /* Set owner here, because '__zerocopy_sg_from_iter()' uses 279 * owner of skb without check to update 'sk_wmem_alloc'. 280 */ 281 if (vsk) 282 skb_set_owner_w(skb, sk_vsock(vsk)); 283 284 if (info->msg && payload_len > 0) { 285 int err; 286 287 err = virtio_transport_fill_skb(skb, info, payload_len, zcopy); 288 if (err) 289 goto out; 290 291 if (msg_data_left(info->msg) == 0 && 292 info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) { 293 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 294 295 hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM); 296 297 if (info->msg->msg_flags & MSG_EOR) 298 hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR); 299 } 300 } 301 302 if (info->reply) 303 virtio_vsock_skb_set_reply(skb); 304 305 trace_virtio_transport_alloc_pkt(src_cid, src_port, 306 dst_cid, dst_port, 307 payload_len, 308 info->type, 309 info->op, 310 info->flags, 311 zcopy); 312 313 return skb; 314 out: 315 kfree_skb(skb); 316 return NULL; 317 } 318 319 /* This function can only be used on connecting/connected sockets, 320 * since a socket assigned to a transport is required. 321 * 322 * Do not use on listener sockets! 323 */ 324 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, 325 struct virtio_vsock_pkt_info *info) 326 { 327 u32 max_skb_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; 328 u32 src_cid, src_port, dst_cid, dst_port; 329 const struct virtio_transport *t_ops; 330 struct virtio_vsock_sock *vvs; 331 u32 pkt_len = info->pkt_len; 332 bool can_zcopy = false; 333 u32 rest_len; 334 int ret; 335 336 info->type = virtio_transport_get_type(sk_vsock(vsk)); 337 338 t_ops = virtio_transport_get_ops(vsk); 339 if (unlikely(!t_ops)) 340 return -EFAULT; 341 342 src_cid = t_ops->transport.get_local_cid(); 343 src_port = vsk->local_addr.svm_port; 344 if (!info->remote_cid) { 345 dst_cid = vsk->remote_addr.svm_cid; 346 dst_port = vsk->remote_addr.svm_port; 347 } else { 348 dst_cid = info->remote_cid; 349 dst_port = info->remote_port; 350 } 351 352 vvs = vsk->trans; 353 354 /* virtio_transport_get_credit might return less than pkt_len credit */ 355 pkt_len = virtio_transport_get_credit(vvs, pkt_len); 356 357 /* Do not send zero length OP_RW pkt */ 358 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) 359 return pkt_len; 360 361 if (info->msg) { 362 /* If zerocopy is not enabled by 'setsockopt()', we behave as 363 * there is no MSG_ZEROCOPY flag set. 364 */ 365 if (!sock_flag(sk_vsock(vsk), SOCK_ZEROCOPY)) 366 info->msg->msg_flags &= ~MSG_ZEROCOPY; 367 368 if (info->msg->msg_flags & MSG_ZEROCOPY) 369 can_zcopy = virtio_transport_can_zcopy(t_ops, info, pkt_len); 370 371 if (can_zcopy) 372 max_skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, 373 (MAX_SKB_FRAGS * PAGE_SIZE)); 374 } 375 376 rest_len = pkt_len; 377 378 do { 379 struct sk_buff *skb; 380 size_t skb_len; 381 382 skb_len = min(max_skb_len, rest_len); 383 384 skb = virtio_transport_alloc_skb(info, skb_len, can_zcopy, 385 src_cid, src_port, 386 dst_cid, dst_port); 387 if (!skb) { 388 ret = -ENOMEM; 389 break; 390 } 391 392 /* We process buffer part by part, allocating skb on 393 * each iteration. If this is last skb for this buffer 394 * and MSG_ZEROCOPY mode is in use - we must allocate 395 * completion for the current syscall. 396 */ 397 if (info->msg && info->msg->msg_flags & MSG_ZEROCOPY && 398 skb_len == rest_len && info->op == VIRTIO_VSOCK_OP_RW) { 399 if (virtio_transport_init_zcopy_skb(vsk, skb, 400 info->msg, 401 can_zcopy)) { 402 ret = -ENOMEM; 403 break; 404 } 405 } 406 407 virtio_transport_inc_tx_pkt(vvs, skb); 408 409 ret = t_ops->send_pkt(skb); 410 if (ret < 0) 411 break; 412 413 /* Both virtio and vhost 'send_pkt()' returns 'skb_len', 414 * but for reliability use 'ret' instead of 'skb_len'. 415 * Also if partial send happens (e.g. 'ret' != 'skb_len') 416 * somehow, we break this loop, but account such returned 417 * value in 'virtio_transport_put_credit()'. 418 */ 419 rest_len -= ret; 420 421 if (WARN_ONCE(ret != skb_len, 422 "'send_pkt()' returns %i, but %zu expected\n", 423 ret, skb_len)) 424 break; 425 } while (rest_len); 426 427 virtio_transport_put_credit(vvs, rest_len); 428 429 /* Return number of bytes, if any data has been sent. */ 430 if (rest_len != pkt_len) 431 ret = pkt_len - rest_len; 432 433 return ret; 434 } 435 436 static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, 437 u32 len) 438 { 439 if (vvs->rx_bytes + len > vvs->buf_alloc) 440 return false; 441 442 vvs->rx_bytes += len; 443 return true; 444 } 445 446 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, 447 u32 len) 448 { 449 vvs->rx_bytes -= len; 450 vvs->fwd_cnt += len; 451 } 452 453 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb) 454 { 455 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 456 457 spin_lock_bh(&vvs->rx_lock); 458 vvs->last_fwd_cnt = vvs->fwd_cnt; 459 hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt); 460 hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc); 461 spin_unlock_bh(&vvs->rx_lock); 462 } 463 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); 464 465 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) 466 { 467 u32 ret; 468 469 if (!credit) 470 return 0; 471 472 spin_lock_bh(&vvs->tx_lock); 473 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 474 if (ret > credit) 475 ret = credit; 476 vvs->tx_cnt += ret; 477 spin_unlock_bh(&vvs->tx_lock); 478 479 return ret; 480 } 481 EXPORT_SYMBOL_GPL(virtio_transport_get_credit); 482 483 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) 484 { 485 if (!credit) 486 return; 487 488 spin_lock_bh(&vvs->tx_lock); 489 vvs->tx_cnt -= credit; 490 spin_unlock_bh(&vvs->tx_lock); 491 } 492 EXPORT_SYMBOL_GPL(virtio_transport_put_credit); 493 494 static int virtio_transport_send_credit_update(struct vsock_sock *vsk) 495 { 496 struct virtio_vsock_pkt_info info = { 497 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, 498 .vsk = vsk, 499 }; 500 501 return virtio_transport_send_pkt_info(vsk, &info); 502 } 503 504 static ssize_t 505 virtio_transport_stream_do_peek(struct vsock_sock *vsk, 506 struct msghdr *msg, 507 size_t len) 508 { 509 struct virtio_vsock_sock *vvs = vsk->trans; 510 struct sk_buff *skb; 511 size_t total = 0; 512 int err; 513 514 spin_lock_bh(&vvs->rx_lock); 515 516 skb_queue_walk(&vvs->rx_queue, skb) { 517 size_t bytes; 518 519 bytes = len - total; 520 if (bytes > skb->len) 521 bytes = skb->len; 522 523 spin_unlock_bh(&vvs->rx_lock); 524 525 /* sk_lock is held by caller so no one else can dequeue. 526 * Unlock rx_lock since skb_copy_datagram_iter() may sleep. 527 */ 528 err = skb_copy_datagram_iter(skb, VIRTIO_VSOCK_SKB_CB(skb)->offset, 529 &msg->msg_iter, bytes); 530 if (err) 531 goto out; 532 533 total += bytes; 534 535 spin_lock_bh(&vvs->rx_lock); 536 537 if (total == len) 538 break; 539 } 540 541 spin_unlock_bh(&vvs->rx_lock); 542 543 return total; 544 545 out: 546 if (total) 547 err = total; 548 return err; 549 } 550 551 static ssize_t 552 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, 553 struct msghdr *msg, 554 size_t len) 555 { 556 struct virtio_vsock_sock *vvs = vsk->trans; 557 size_t bytes, total = 0; 558 struct sk_buff *skb; 559 int err = -EFAULT; 560 u32 free_space; 561 562 spin_lock_bh(&vvs->rx_lock); 563 564 if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes, 565 "rx_queue is empty, but rx_bytes is non-zero\n")) { 566 spin_unlock_bh(&vvs->rx_lock); 567 return err; 568 } 569 570 while (total < len && !skb_queue_empty(&vvs->rx_queue)) { 571 skb = skb_peek(&vvs->rx_queue); 572 573 bytes = min_t(size_t, len - total, 574 skb->len - VIRTIO_VSOCK_SKB_CB(skb)->offset); 575 576 /* sk_lock is held by caller so no one else can dequeue. 577 * Unlock rx_lock since skb_copy_datagram_iter() may sleep. 578 */ 579 spin_unlock_bh(&vvs->rx_lock); 580 581 err = skb_copy_datagram_iter(skb, 582 VIRTIO_VSOCK_SKB_CB(skb)->offset, 583 &msg->msg_iter, bytes); 584 if (err) 585 goto out; 586 587 spin_lock_bh(&vvs->rx_lock); 588 589 total += bytes; 590 591 VIRTIO_VSOCK_SKB_CB(skb)->offset += bytes; 592 593 if (skb->len == VIRTIO_VSOCK_SKB_CB(skb)->offset) { 594 u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len); 595 596 virtio_transport_dec_rx_pkt(vvs, pkt_len); 597 __skb_unlink(skb, &vvs->rx_queue); 598 consume_skb(skb); 599 } 600 } 601 602 free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt); 603 604 spin_unlock_bh(&vvs->rx_lock); 605 606 /* To reduce the number of credit update messages, 607 * don't update credits as long as lots of space is available. 608 * Note: the limit chosen here is arbitrary. Setting the limit 609 * too high causes extra messages. Too low causes transmitter 610 * stalls. As stalls are in theory more expensive than extra 611 * messages, we set the limit to a high value. TODO: experiment 612 * with different values. 613 */ 614 if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) 615 virtio_transport_send_credit_update(vsk); 616 617 return total; 618 619 out: 620 if (total) 621 err = total; 622 return err; 623 } 624 625 static ssize_t 626 virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk, 627 struct msghdr *msg) 628 { 629 struct virtio_vsock_sock *vvs = vsk->trans; 630 struct sk_buff *skb; 631 size_t total, len; 632 633 spin_lock_bh(&vvs->rx_lock); 634 635 if (!vvs->msg_count) { 636 spin_unlock_bh(&vvs->rx_lock); 637 return 0; 638 } 639 640 total = 0; 641 len = msg_data_left(msg); 642 643 skb_queue_walk(&vvs->rx_queue, skb) { 644 struct virtio_vsock_hdr *hdr; 645 646 if (total < len) { 647 size_t bytes; 648 int err; 649 650 bytes = len - total; 651 if (bytes > skb->len) 652 bytes = skb->len; 653 654 spin_unlock_bh(&vvs->rx_lock); 655 656 /* sk_lock is held by caller so no one else can dequeue. 657 * Unlock rx_lock since skb_copy_datagram_iter() may sleep. 658 */ 659 err = skb_copy_datagram_iter(skb, VIRTIO_VSOCK_SKB_CB(skb)->offset, 660 &msg->msg_iter, bytes); 661 if (err) 662 return err; 663 664 spin_lock_bh(&vvs->rx_lock); 665 } 666 667 total += skb->len; 668 hdr = virtio_vsock_hdr(skb); 669 670 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) { 671 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) 672 msg->msg_flags |= MSG_EOR; 673 674 break; 675 } 676 } 677 678 spin_unlock_bh(&vvs->rx_lock); 679 680 return total; 681 } 682 683 static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk, 684 struct msghdr *msg, 685 int flags) 686 { 687 struct virtio_vsock_sock *vvs = vsk->trans; 688 int dequeued_len = 0; 689 size_t user_buf_len = msg_data_left(msg); 690 bool msg_ready = false; 691 struct sk_buff *skb; 692 693 spin_lock_bh(&vvs->rx_lock); 694 695 if (vvs->msg_count == 0) { 696 spin_unlock_bh(&vvs->rx_lock); 697 return 0; 698 } 699 700 while (!msg_ready) { 701 struct virtio_vsock_hdr *hdr; 702 size_t pkt_len; 703 704 skb = __skb_dequeue(&vvs->rx_queue); 705 if (!skb) 706 break; 707 hdr = virtio_vsock_hdr(skb); 708 pkt_len = (size_t)le32_to_cpu(hdr->len); 709 710 if (dequeued_len >= 0) { 711 size_t bytes_to_copy; 712 713 bytes_to_copy = min(user_buf_len, pkt_len); 714 715 if (bytes_to_copy) { 716 int err; 717 718 /* sk_lock is held by caller so no one else can dequeue. 719 * Unlock rx_lock since skb_copy_datagram_iter() may sleep. 720 */ 721 spin_unlock_bh(&vvs->rx_lock); 722 723 err = skb_copy_datagram_iter(skb, 0, 724 &msg->msg_iter, 725 bytes_to_copy); 726 if (err) { 727 /* Copy of message failed. Rest of 728 * fragments will be freed without copy. 729 */ 730 dequeued_len = err; 731 } else { 732 user_buf_len -= bytes_to_copy; 733 } 734 735 spin_lock_bh(&vvs->rx_lock); 736 } 737 738 if (dequeued_len >= 0) 739 dequeued_len += pkt_len; 740 } 741 742 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) { 743 msg_ready = true; 744 vvs->msg_count--; 745 746 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) 747 msg->msg_flags |= MSG_EOR; 748 } 749 750 virtio_transport_dec_rx_pkt(vvs, pkt_len); 751 kfree_skb(skb); 752 } 753 754 spin_unlock_bh(&vvs->rx_lock); 755 756 virtio_transport_send_credit_update(vsk); 757 758 return dequeued_len; 759 } 760 761 ssize_t 762 virtio_transport_stream_dequeue(struct vsock_sock *vsk, 763 struct msghdr *msg, 764 size_t len, int flags) 765 { 766 if (flags & MSG_PEEK) 767 return virtio_transport_stream_do_peek(vsk, msg, len); 768 else 769 return virtio_transport_stream_do_dequeue(vsk, msg, len); 770 } 771 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); 772 773 ssize_t 774 virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, 775 struct msghdr *msg, 776 int flags) 777 { 778 if (flags & MSG_PEEK) 779 return virtio_transport_seqpacket_do_peek(vsk, msg); 780 else 781 return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags); 782 } 783 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); 784 785 int 786 virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, 787 struct msghdr *msg, 788 size_t len) 789 { 790 struct virtio_vsock_sock *vvs = vsk->trans; 791 792 spin_lock_bh(&vvs->tx_lock); 793 794 if (len > vvs->peer_buf_alloc) { 795 spin_unlock_bh(&vvs->tx_lock); 796 return -EMSGSIZE; 797 } 798 799 spin_unlock_bh(&vvs->tx_lock); 800 801 return virtio_transport_stream_enqueue(vsk, msg, len); 802 } 803 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue); 804 805 int 806 virtio_transport_dgram_dequeue(struct vsock_sock *vsk, 807 struct msghdr *msg, 808 size_t len, int flags) 809 { 810 return -EOPNOTSUPP; 811 } 812 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); 813 814 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) 815 { 816 struct virtio_vsock_sock *vvs = vsk->trans; 817 s64 bytes; 818 819 spin_lock_bh(&vvs->rx_lock); 820 bytes = vvs->rx_bytes; 821 spin_unlock_bh(&vvs->rx_lock); 822 823 return bytes; 824 } 825 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); 826 827 u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk) 828 { 829 struct virtio_vsock_sock *vvs = vsk->trans; 830 u32 msg_count; 831 832 spin_lock_bh(&vvs->rx_lock); 833 msg_count = vvs->msg_count; 834 spin_unlock_bh(&vvs->rx_lock); 835 836 return msg_count; 837 } 838 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data); 839 840 static s64 virtio_transport_has_space(struct vsock_sock *vsk) 841 { 842 struct virtio_vsock_sock *vvs = vsk->trans; 843 s64 bytes; 844 845 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 846 if (bytes < 0) 847 bytes = 0; 848 849 return bytes; 850 } 851 852 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) 853 { 854 struct virtio_vsock_sock *vvs = vsk->trans; 855 s64 bytes; 856 857 spin_lock_bh(&vvs->tx_lock); 858 bytes = virtio_transport_has_space(vsk); 859 spin_unlock_bh(&vvs->tx_lock); 860 861 return bytes; 862 } 863 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); 864 865 int virtio_transport_do_socket_init(struct vsock_sock *vsk, 866 struct vsock_sock *psk) 867 { 868 struct virtio_vsock_sock *vvs; 869 870 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); 871 if (!vvs) 872 return -ENOMEM; 873 874 vsk->trans = vvs; 875 vvs->vsk = vsk; 876 if (psk && psk->trans) { 877 struct virtio_vsock_sock *ptrans = psk->trans; 878 879 vvs->peer_buf_alloc = ptrans->peer_buf_alloc; 880 } 881 882 if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE) 883 vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE; 884 885 vvs->buf_alloc = vsk->buffer_size; 886 887 spin_lock_init(&vvs->rx_lock); 888 spin_lock_init(&vvs->tx_lock); 889 skb_queue_head_init(&vvs->rx_queue); 890 891 return 0; 892 } 893 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); 894 895 /* sk_lock held by the caller */ 896 void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) 897 { 898 struct virtio_vsock_sock *vvs = vsk->trans; 899 900 if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE) 901 *val = VIRTIO_VSOCK_MAX_BUF_SIZE; 902 903 vvs->buf_alloc = *val; 904 905 virtio_transport_send_credit_update(vsk); 906 } 907 EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); 908 909 int 910 virtio_transport_notify_poll_in(struct vsock_sock *vsk, 911 size_t target, 912 bool *data_ready_now) 913 { 914 *data_ready_now = vsock_stream_has_data(vsk) >= target; 915 916 return 0; 917 } 918 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); 919 920 int 921 virtio_transport_notify_poll_out(struct vsock_sock *vsk, 922 size_t target, 923 bool *space_avail_now) 924 { 925 s64 free_space; 926 927 free_space = vsock_stream_has_space(vsk); 928 if (free_space > 0) 929 *space_avail_now = true; 930 else if (free_space == 0) 931 *space_avail_now = false; 932 933 return 0; 934 } 935 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); 936 937 int virtio_transport_notify_recv_init(struct vsock_sock *vsk, 938 size_t target, struct vsock_transport_recv_notify_data *data) 939 { 940 return 0; 941 } 942 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); 943 944 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, 945 size_t target, struct vsock_transport_recv_notify_data *data) 946 { 947 return 0; 948 } 949 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); 950 951 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, 952 size_t target, struct vsock_transport_recv_notify_data *data) 953 { 954 return 0; 955 } 956 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); 957 958 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, 959 size_t target, ssize_t copied, bool data_read, 960 struct vsock_transport_recv_notify_data *data) 961 { 962 return 0; 963 } 964 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); 965 966 int virtio_transport_notify_send_init(struct vsock_sock *vsk, 967 struct vsock_transport_send_notify_data *data) 968 { 969 return 0; 970 } 971 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); 972 973 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, 974 struct vsock_transport_send_notify_data *data) 975 { 976 return 0; 977 } 978 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); 979 980 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, 981 struct vsock_transport_send_notify_data *data) 982 { 983 return 0; 984 } 985 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); 986 987 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, 988 ssize_t written, struct vsock_transport_send_notify_data *data) 989 { 990 return 0; 991 } 992 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); 993 994 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) 995 { 996 return vsk->buffer_size; 997 } 998 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); 999 1000 bool virtio_transport_stream_is_active(struct vsock_sock *vsk) 1001 { 1002 return true; 1003 } 1004 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); 1005 1006 bool virtio_transport_stream_allow(u32 cid, u32 port) 1007 { 1008 return true; 1009 } 1010 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); 1011 1012 int virtio_transport_dgram_bind(struct vsock_sock *vsk, 1013 struct sockaddr_vm *addr) 1014 { 1015 return -EOPNOTSUPP; 1016 } 1017 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); 1018 1019 bool virtio_transport_dgram_allow(u32 cid, u32 port) 1020 { 1021 return false; 1022 } 1023 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); 1024 1025 int virtio_transport_connect(struct vsock_sock *vsk) 1026 { 1027 struct virtio_vsock_pkt_info info = { 1028 .op = VIRTIO_VSOCK_OP_REQUEST, 1029 .vsk = vsk, 1030 }; 1031 1032 return virtio_transport_send_pkt_info(vsk, &info); 1033 } 1034 EXPORT_SYMBOL_GPL(virtio_transport_connect); 1035 1036 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) 1037 { 1038 struct virtio_vsock_pkt_info info = { 1039 .op = VIRTIO_VSOCK_OP_SHUTDOWN, 1040 .flags = (mode & RCV_SHUTDOWN ? 1041 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | 1042 (mode & SEND_SHUTDOWN ? 1043 VIRTIO_VSOCK_SHUTDOWN_SEND : 0), 1044 .vsk = vsk, 1045 }; 1046 1047 return virtio_transport_send_pkt_info(vsk, &info); 1048 } 1049 EXPORT_SYMBOL_GPL(virtio_transport_shutdown); 1050 1051 int 1052 virtio_transport_dgram_enqueue(struct vsock_sock *vsk, 1053 struct sockaddr_vm *remote_addr, 1054 struct msghdr *msg, 1055 size_t dgram_len) 1056 { 1057 return -EOPNOTSUPP; 1058 } 1059 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); 1060 1061 ssize_t 1062 virtio_transport_stream_enqueue(struct vsock_sock *vsk, 1063 struct msghdr *msg, 1064 size_t len) 1065 { 1066 struct virtio_vsock_pkt_info info = { 1067 .op = VIRTIO_VSOCK_OP_RW, 1068 .msg = msg, 1069 .pkt_len = len, 1070 .vsk = vsk, 1071 }; 1072 1073 return virtio_transport_send_pkt_info(vsk, &info); 1074 } 1075 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); 1076 1077 void virtio_transport_destruct(struct vsock_sock *vsk) 1078 { 1079 struct virtio_vsock_sock *vvs = vsk->trans; 1080 1081 kfree(vvs); 1082 } 1083 EXPORT_SYMBOL_GPL(virtio_transport_destruct); 1084 1085 static int virtio_transport_reset(struct vsock_sock *vsk, 1086 struct sk_buff *skb) 1087 { 1088 struct virtio_vsock_pkt_info info = { 1089 .op = VIRTIO_VSOCK_OP_RST, 1090 .reply = !!skb, 1091 .vsk = vsk, 1092 }; 1093 1094 /* Send RST only if the original pkt is not a RST pkt */ 1095 if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST) 1096 return 0; 1097 1098 return virtio_transport_send_pkt_info(vsk, &info); 1099 } 1100 1101 /* Normally packets are associated with a socket. There may be no socket if an 1102 * attempt was made to connect to a socket that does not exist. 1103 */ 1104 static int virtio_transport_reset_no_sock(const struct virtio_transport *t, 1105 struct sk_buff *skb) 1106 { 1107 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1108 struct virtio_vsock_pkt_info info = { 1109 .op = VIRTIO_VSOCK_OP_RST, 1110 .type = le16_to_cpu(hdr->type), 1111 .reply = true, 1112 }; 1113 struct sk_buff *reply; 1114 1115 /* Send RST only if the original pkt is not a RST pkt */ 1116 if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) 1117 return 0; 1118 1119 if (!t) 1120 return -ENOTCONN; 1121 1122 reply = virtio_transport_alloc_skb(&info, 0, false, 1123 le64_to_cpu(hdr->dst_cid), 1124 le32_to_cpu(hdr->dst_port), 1125 le64_to_cpu(hdr->src_cid), 1126 le32_to_cpu(hdr->src_port)); 1127 if (!reply) 1128 return -ENOMEM; 1129 1130 return t->send_pkt(reply); 1131 } 1132 1133 /* This function should be called with sk_lock held and SOCK_DONE set */ 1134 static void virtio_transport_remove_sock(struct vsock_sock *vsk) 1135 { 1136 struct virtio_vsock_sock *vvs = vsk->trans; 1137 1138 /* We don't need to take rx_lock, as the socket is closing and we are 1139 * removing it. 1140 */ 1141 __skb_queue_purge(&vvs->rx_queue); 1142 vsock_remove_sock(vsk); 1143 } 1144 1145 static void virtio_transport_wait_close(struct sock *sk, long timeout) 1146 { 1147 if (timeout) { 1148 DEFINE_WAIT_FUNC(wait, woken_wake_function); 1149 1150 add_wait_queue(sk_sleep(sk), &wait); 1151 1152 do { 1153 if (sk_wait_event(sk, &timeout, 1154 sock_flag(sk, SOCK_DONE), &wait)) 1155 break; 1156 } while (!signal_pending(current) && timeout); 1157 1158 remove_wait_queue(sk_sleep(sk), &wait); 1159 } 1160 } 1161 1162 static void virtio_transport_do_close(struct vsock_sock *vsk, 1163 bool cancel_timeout) 1164 { 1165 struct sock *sk = sk_vsock(vsk); 1166 1167 sock_set_flag(sk, SOCK_DONE); 1168 vsk->peer_shutdown = SHUTDOWN_MASK; 1169 if (vsock_stream_has_data(vsk) <= 0) 1170 sk->sk_state = TCP_CLOSING; 1171 sk->sk_state_change(sk); 1172 1173 if (vsk->close_work_scheduled && 1174 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { 1175 vsk->close_work_scheduled = false; 1176 1177 virtio_transport_remove_sock(vsk); 1178 1179 /* Release refcnt obtained when we scheduled the timeout */ 1180 sock_put(sk); 1181 } 1182 } 1183 1184 static void virtio_transport_close_timeout(struct work_struct *work) 1185 { 1186 struct vsock_sock *vsk = 1187 container_of(work, struct vsock_sock, close_work.work); 1188 struct sock *sk = sk_vsock(vsk); 1189 1190 sock_hold(sk); 1191 lock_sock(sk); 1192 1193 if (!sock_flag(sk, SOCK_DONE)) { 1194 (void)virtio_transport_reset(vsk, NULL); 1195 1196 virtio_transport_do_close(vsk, false); 1197 } 1198 1199 vsk->close_work_scheduled = false; 1200 1201 release_sock(sk); 1202 sock_put(sk); 1203 } 1204 1205 /* User context, vsk->sk is locked */ 1206 static bool virtio_transport_close(struct vsock_sock *vsk) 1207 { 1208 struct sock *sk = &vsk->sk; 1209 1210 if (!(sk->sk_state == TCP_ESTABLISHED || 1211 sk->sk_state == TCP_CLOSING)) 1212 return true; 1213 1214 /* Already received SHUTDOWN from peer, reply with RST */ 1215 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { 1216 (void)virtio_transport_reset(vsk, NULL); 1217 return true; 1218 } 1219 1220 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) 1221 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); 1222 1223 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) 1224 virtio_transport_wait_close(sk, sk->sk_lingertime); 1225 1226 if (sock_flag(sk, SOCK_DONE)) { 1227 return true; 1228 } 1229 1230 sock_hold(sk); 1231 INIT_DELAYED_WORK(&vsk->close_work, 1232 virtio_transport_close_timeout); 1233 vsk->close_work_scheduled = true; 1234 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); 1235 return false; 1236 } 1237 1238 void virtio_transport_release(struct vsock_sock *vsk) 1239 { 1240 struct sock *sk = &vsk->sk; 1241 bool remove_sock = true; 1242 1243 if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) 1244 remove_sock = virtio_transport_close(vsk); 1245 1246 if (remove_sock) { 1247 sock_set_flag(sk, SOCK_DONE); 1248 virtio_transport_remove_sock(vsk); 1249 } 1250 } 1251 EXPORT_SYMBOL_GPL(virtio_transport_release); 1252 1253 static int 1254 virtio_transport_recv_connecting(struct sock *sk, 1255 struct sk_buff *skb) 1256 { 1257 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1258 struct vsock_sock *vsk = vsock_sk(sk); 1259 int skerr; 1260 int err; 1261 1262 switch (le16_to_cpu(hdr->op)) { 1263 case VIRTIO_VSOCK_OP_RESPONSE: 1264 sk->sk_state = TCP_ESTABLISHED; 1265 sk->sk_socket->state = SS_CONNECTED; 1266 vsock_insert_connected(vsk); 1267 sk->sk_state_change(sk); 1268 break; 1269 case VIRTIO_VSOCK_OP_INVALID: 1270 break; 1271 case VIRTIO_VSOCK_OP_RST: 1272 skerr = ECONNRESET; 1273 err = 0; 1274 goto destroy; 1275 default: 1276 skerr = EPROTO; 1277 err = -EINVAL; 1278 goto destroy; 1279 } 1280 return 0; 1281 1282 destroy: 1283 virtio_transport_reset(vsk, skb); 1284 sk->sk_state = TCP_CLOSE; 1285 sk->sk_err = skerr; 1286 sk_error_report(sk); 1287 return err; 1288 } 1289 1290 static void 1291 virtio_transport_recv_enqueue(struct vsock_sock *vsk, 1292 struct sk_buff *skb) 1293 { 1294 struct virtio_vsock_sock *vvs = vsk->trans; 1295 bool can_enqueue, free_pkt = false; 1296 struct virtio_vsock_hdr *hdr; 1297 u32 len; 1298 1299 hdr = virtio_vsock_hdr(skb); 1300 len = le32_to_cpu(hdr->len); 1301 1302 spin_lock_bh(&vvs->rx_lock); 1303 1304 can_enqueue = virtio_transport_inc_rx_pkt(vvs, len); 1305 if (!can_enqueue) { 1306 free_pkt = true; 1307 goto out; 1308 } 1309 1310 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) 1311 vvs->msg_count++; 1312 1313 /* Try to copy small packets into the buffer of last packet queued, 1314 * to avoid wasting memory queueing the entire buffer with a small 1315 * payload. 1316 */ 1317 if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) { 1318 struct virtio_vsock_hdr *last_hdr; 1319 struct sk_buff *last_skb; 1320 1321 last_skb = skb_peek_tail(&vvs->rx_queue); 1322 last_hdr = virtio_vsock_hdr(last_skb); 1323 1324 /* If there is space in the last packet queued, we copy the 1325 * new packet in its buffer. We avoid this if the last packet 1326 * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is 1327 * delimiter of SEQPACKET message, so 'pkt' is the first packet 1328 * of a new message. 1329 */ 1330 if (skb->len < skb_tailroom(last_skb) && 1331 !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) { 1332 memcpy(skb_put(last_skb, skb->len), skb->data, skb->len); 1333 free_pkt = true; 1334 last_hdr->flags |= hdr->flags; 1335 le32_add_cpu(&last_hdr->len, len); 1336 goto out; 1337 } 1338 } 1339 1340 __skb_queue_tail(&vvs->rx_queue, skb); 1341 1342 out: 1343 spin_unlock_bh(&vvs->rx_lock); 1344 if (free_pkt) 1345 kfree_skb(skb); 1346 } 1347 1348 static int 1349 virtio_transport_recv_connected(struct sock *sk, 1350 struct sk_buff *skb) 1351 { 1352 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1353 struct vsock_sock *vsk = vsock_sk(sk); 1354 int err = 0; 1355 1356 switch (le16_to_cpu(hdr->op)) { 1357 case VIRTIO_VSOCK_OP_RW: 1358 virtio_transport_recv_enqueue(vsk, skb); 1359 vsock_data_ready(sk); 1360 return err; 1361 case VIRTIO_VSOCK_OP_CREDIT_REQUEST: 1362 virtio_transport_send_credit_update(vsk); 1363 break; 1364 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 1365 sk->sk_write_space(sk); 1366 break; 1367 case VIRTIO_VSOCK_OP_SHUTDOWN: 1368 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) 1369 vsk->peer_shutdown |= RCV_SHUTDOWN; 1370 if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) 1371 vsk->peer_shutdown |= SEND_SHUTDOWN; 1372 if (vsk->peer_shutdown == SHUTDOWN_MASK && 1373 vsock_stream_has_data(vsk) <= 0 && 1374 !sock_flag(sk, SOCK_DONE)) { 1375 (void)virtio_transport_reset(vsk, NULL); 1376 virtio_transport_do_close(vsk, true); 1377 } 1378 if (le32_to_cpu(virtio_vsock_hdr(skb)->flags)) 1379 sk->sk_state_change(sk); 1380 break; 1381 case VIRTIO_VSOCK_OP_RST: 1382 virtio_transport_do_close(vsk, true); 1383 break; 1384 default: 1385 err = -EINVAL; 1386 break; 1387 } 1388 1389 kfree_skb(skb); 1390 return err; 1391 } 1392 1393 static void 1394 virtio_transport_recv_disconnecting(struct sock *sk, 1395 struct sk_buff *skb) 1396 { 1397 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1398 struct vsock_sock *vsk = vsock_sk(sk); 1399 1400 if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) 1401 virtio_transport_do_close(vsk, true); 1402 } 1403 1404 static int 1405 virtio_transport_send_response(struct vsock_sock *vsk, 1406 struct sk_buff *skb) 1407 { 1408 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1409 struct virtio_vsock_pkt_info info = { 1410 .op = VIRTIO_VSOCK_OP_RESPONSE, 1411 .remote_cid = le64_to_cpu(hdr->src_cid), 1412 .remote_port = le32_to_cpu(hdr->src_port), 1413 .reply = true, 1414 .vsk = vsk, 1415 }; 1416 1417 return virtio_transport_send_pkt_info(vsk, &info); 1418 } 1419 1420 static bool virtio_transport_space_update(struct sock *sk, 1421 struct sk_buff *skb) 1422 { 1423 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1424 struct vsock_sock *vsk = vsock_sk(sk); 1425 struct virtio_vsock_sock *vvs = vsk->trans; 1426 bool space_available; 1427 1428 /* Listener sockets are not associated with any transport, so we are 1429 * not able to take the state to see if there is space available in the 1430 * remote peer, but since they are only used to receive requests, we 1431 * can assume that there is always space available in the other peer. 1432 */ 1433 if (!vvs) 1434 return true; 1435 1436 /* buf_alloc and fwd_cnt is always included in the hdr */ 1437 spin_lock_bh(&vvs->tx_lock); 1438 vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc); 1439 vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt); 1440 space_available = virtio_transport_has_space(vsk); 1441 spin_unlock_bh(&vvs->tx_lock); 1442 return space_available; 1443 } 1444 1445 /* Handle server socket */ 1446 static int 1447 virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, 1448 struct virtio_transport *t) 1449 { 1450 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1451 struct vsock_sock *vsk = vsock_sk(sk); 1452 struct vsock_sock *vchild; 1453 struct sock *child; 1454 int ret; 1455 1456 if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) { 1457 virtio_transport_reset_no_sock(t, skb); 1458 return -EINVAL; 1459 } 1460 1461 if (sk_acceptq_is_full(sk)) { 1462 virtio_transport_reset_no_sock(t, skb); 1463 return -ENOMEM; 1464 } 1465 1466 child = vsock_create_connected(sk); 1467 if (!child) { 1468 virtio_transport_reset_no_sock(t, skb); 1469 return -ENOMEM; 1470 } 1471 1472 sk_acceptq_added(sk); 1473 1474 lock_sock_nested(child, SINGLE_DEPTH_NESTING); 1475 1476 child->sk_state = TCP_ESTABLISHED; 1477 1478 vchild = vsock_sk(child); 1479 vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid), 1480 le32_to_cpu(hdr->dst_port)); 1481 vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid), 1482 le32_to_cpu(hdr->src_port)); 1483 1484 ret = vsock_assign_transport(vchild, vsk); 1485 /* Transport assigned (looking at remote_addr) must be the same 1486 * where we received the request. 1487 */ 1488 if (ret || vchild->transport != &t->transport) { 1489 release_sock(child); 1490 virtio_transport_reset_no_sock(t, skb); 1491 sock_put(child); 1492 return ret; 1493 } 1494 1495 if (virtio_transport_space_update(child, skb)) 1496 child->sk_write_space(child); 1497 1498 vsock_insert_connected(vchild); 1499 vsock_enqueue_accept(sk, child); 1500 virtio_transport_send_response(vchild, skb); 1501 1502 release_sock(child); 1503 1504 sk->sk_data_ready(sk); 1505 return 0; 1506 } 1507 1508 static bool virtio_transport_valid_type(u16 type) 1509 { 1510 return (type == VIRTIO_VSOCK_TYPE_STREAM) || 1511 (type == VIRTIO_VSOCK_TYPE_SEQPACKET); 1512 } 1513 1514 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex 1515 * lock. 1516 */ 1517 void virtio_transport_recv_pkt(struct virtio_transport *t, 1518 struct sk_buff *skb) 1519 { 1520 struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); 1521 struct sockaddr_vm src, dst; 1522 struct vsock_sock *vsk; 1523 struct sock *sk; 1524 bool space_available; 1525 1526 vsock_addr_init(&src, le64_to_cpu(hdr->src_cid), 1527 le32_to_cpu(hdr->src_port)); 1528 vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid), 1529 le32_to_cpu(hdr->dst_port)); 1530 1531 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, 1532 dst.svm_cid, dst.svm_port, 1533 le32_to_cpu(hdr->len), 1534 le16_to_cpu(hdr->type), 1535 le16_to_cpu(hdr->op), 1536 le32_to_cpu(hdr->flags), 1537 le32_to_cpu(hdr->buf_alloc), 1538 le32_to_cpu(hdr->fwd_cnt)); 1539 1540 if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) { 1541 (void)virtio_transport_reset_no_sock(t, skb); 1542 goto free_pkt; 1543 } 1544 1545 /* The socket must be in connected or bound table 1546 * otherwise send reset back 1547 */ 1548 sk = vsock_find_connected_socket(&src, &dst); 1549 if (!sk) { 1550 sk = vsock_find_bound_socket(&dst); 1551 if (!sk) { 1552 (void)virtio_transport_reset_no_sock(t, skb); 1553 goto free_pkt; 1554 } 1555 } 1556 1557 if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) { 1558 (void)virtio_transport_reset_no_sock(t, skb); 1559 sock_put(sk); 1560 goto free_pkt; 1561 } 1562 1563 if (!skb_set_owner_sk_safe(skb, sk)) { 1564 WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n"); 1565 goto free_pkt; 1566 } 1567 1568 vsk = vsock_sk(sk); 1569 1570 lock_sock(sk); 1571 1572 /* Check if sk has been closed before lock_sock */ 1573 if (sock_flag(sk, SOCK_DONE)) { 1574 (void)virtio_transport_reset_no_sock(t, skb); 1575 release_sock(sk); 1576 sock_put(sk); 1577 goto free_pkt; 1578 } 1579 1580 space_available = virtio_transport_space_update(sk, skb); 1581 1582 /* Update CID in case it has changed after a transport reset event */ 1583 if (vsk->local_addr.svm_cid != VMADDR_CID_ANY) 1584 vsk->local_addr.svm_cid = dst.svm_cid; 1585 1586 if (space_available) 1587 sk->sk_write_space(sk); 1588 1589 switch (sk->sk_state) { 1590 case TCP_LISTEN: 1591 virtio_transport_recv_listen(sk, skb, t); 1592 kfree_skb(skb); 1593 break; 1594 case TCP_SYN_SENT: 1595 virtio_transport_recv_connecting(sk, skb); 1596 kfree_skb(skb); 1597 break; 1598 case TCP_ESTABLISHED: 1599 virtio_transport_recv_connected(sk, skb); 1600 break; 1601 case TCP_CLOSING: 1602 virtio_transport_recv_disconnecting(sk, skb); 1603 kfree_skb(skb); 1604 break; 1605 default: 1606 (void)virtio_transport_reset_no_sock(t, skb); 1607 kfree_skb(skb); 1608 break; 1609 } 1610 1611 release_sock(sk); 1612 1613 /* Release refcnt obtained when we fetched this socket out of the 1614 * bound or connected list. 1615 */ 1616 sock_put(sk); 1617 return; 1618 1619 free_pkt: 1620 kfree_skb(skb); 1621 } 1622 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); 1623 1624 /* Remove skbs found in a queue that have a vsk that matches. 1625 * 1626 * Each skb is freed. 1627 * 1628 * Returns the count of skbs that were reply packets. 1629 */ 1630 int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue) 1631 { 1632 struct sk_buff_head freeme; 1633 struct sk_buff *skb, *tmp; 1634 int cnt = 0; 1635 1636 skb_queue_head_init(&freeme); 1637 1638 spin_lock_bh(&queue->lock); 1639 skb_queue_walk_safe(queue, skb, tmp) { 1640 if (vsock_sk(skb->sk) != vsk) 1641 continue; 1642 1643 __skb_unlink(skb, queue); 1644 __skb_queue_tail(&freeme, skb); 1645 1646 if (virtio_vsock_skb_reply(skb)) 1647 cnt++; 1648 } 1649 spin_unlock_bh(&queue->lock); 1650 1651 __skb_queue_purge(&freeme); 1652 1653 return cnt; 1654 } 1655 EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs); 1656 1657 int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor) 1658 { 1659 struct virtio_vsock_sock *vvs = vsk->trans; 1660 struct sock *sk = sk_vsock(vsk); 1661 struct sk_buff *skb; 1662 int off = 0; 1663 int err; 1664 1665 spin_lock_bh(&vvs->rx_lock); 1666 /* Use __skb_recv_datagram() for race-free handling of the receive. It 1667 * works for types other than dgrams. 1668 */ 1669 skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err); 1670 spin_unlock_bh(&vvs->rx_lock); 1671 1672 if (!skb) 1673 return err; 1674 1675 return recv_actor(sk, skb); 1676 } 1677 EXPORT_SYMBOL_GPL(virtio_transport_read_skb); 1678 1679 MODULE_LICENSE("GPL v2"); 1680 MODULE_AUTHOR("Asias He"); 1681 MODULE_DESCRIPTION("common code for virtio vsock"); 1682