/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Return the registered virtio transport ops (guest or vhost backend).
 *
 * NOTE(review): if vsock_core_get_transport() returns NULL, container_of()
 * only yields NULL when 'transport' is the first member of struct
 * virtio_transport — callers that NULL-check the result (see
 * virtio_transport_reset_no_sock()) rely on that; confirm the struct layout.
 */
static const struct virtio_transport *virtio_transport_get_ops(void)
{
	const struct vsock_transport *t = vsock_core_get_transport();

	return container_of(t, struct virtio_transport, transport);
}

/* Allocate a packet and fill in its little-endian virtio header from @info
 * and the given source/destination CID:port pairs.  If @info->msg is set and
 * @len > 0, a payload buffer of @len bytes is allocated and the user data is
 * copied in.  Ownership of the returned packet passes to the caller, who
 * releases it with virtio_transport_free_pkt().  Returns NULL on allocation
 * or copy failure.
 */
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	/* Header fields are stored little-endian per the virtio spec */
	pkt->hdr.type		= cpu_to_le16(info->type);
	pkt->hdr.op		= cpu_to_le16(info->op);
	pkt->hdr.src_cid	= cpu_to_le64(src_cid);
	pkt->hdr.dst_cid	= cpu_to_le64(dst_cid);
	pkt->hdr.src_port	= cpu_to_le32(src_port);
	pkt->hdr.dst_port	= cpu_to_le32(dst_port);
	pkt->hdr.flags		= cpu_to_le32(info->flags);
	pkt->len		= len;
	pkt->hdr.len		= cpu_to_le32(len);
	pkt->reply		= info->reply;
	pkt->vsk		= info->vsk;

	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;
		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}

/* Packet capture */

/* Build an skb snapshot of @opaque (a struct virtio_vsock_pkt *) for
 * delivery to vsockmon taps: af_vsockmon_hdr, then the raw virtio header,
 * then the payload.  Called in atomic context, hence GFP_ATOMIC.
 * Returns NULL if skb allocation fails (capture is best-effort).
 */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_pkt *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;

	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt->hdr.src_cid;
	hdr->src_port = pkt->hdr.src_port;
	hdr->dst_cid = pkt->hdr.dst_cid;
	hdr->dst_port = pkt->hdr.dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	/* Map virtio-vsock ops onto the coarser vsockmon op categories */
	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

	if (pkt->len) {
		skb_put_data(skb, pkt->buf, pkt->len);
	}

	return skb;
}

/* Hand @pkt to any registered vsockmon taps (packet capture). */
void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
	vsock_deliver_tap(virtio_transport_build_skb, pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

/* Allocate and transmit one packet described by @info on behalf of @vsk.
 *
 * Resolves source/destination addresses (info->remote_cid overrides the
 * socket's remote address when non-zero), caps the payload at the rx buffer
 * size, and consumes tx credit before sending.  May send fewer bytes than
 * info->pkt_len; returns the number of bytes sent, or a negative errno.
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	src_cid = vm_sockets_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid	= vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		/* Return the credit reserved above since nothing was sent */
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return virtio_transport_get_ops()->send_pkt(pkt);
}

/* Account a newly queued rx packet.  Caller holds vvs->rx_lock. */
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes += pkt->len;
}

/* Account a fully consumed rx packet: it no longer counts as pending data
 * and its bytes are added to fwd_cnt so the peer regains credit.
 * Caller holds vvs->rx_lock.
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}

/* Stamp the outgoing packet header with our current flow-control state
 * (fwd_cnt/buf_alloc) so the peer can update its view of our credit.
 */
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->tx_lock);
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

/* Reserve up to @credit bytes of tx credit against the peer's advertised
 * buffer space.  Returns the amount actually reserved (possibly 0); the
 * caller must return any unused portion with virtio_transport_put_credit().
 */
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	spin_lock_bh(&vvs->tx_lock);
	/* Available credit = peer buffer size minus bytes in flight */
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

/* Return @credit bytes of previously reserved tx credit. */
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

/* Send a CREDIT_UPDATE control packet so the peer learns how much data we
 * have consumed.  NOTE(review): the @hdr parameter is unused here; callers
 * in this file pass NULL.
 */
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
						int type,
						struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Copy up to @len bytes of queued stream data into @msg, consuming packets
 * from the rx queue in order.  Packets read only partially stay at the head
 * of the queue with pkt->off advanced.  Sends a credit update to the peer
 * when done.  Returns the number of bytes copied, or a negative errno if
 * nothing was copied before a fault.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			/* Packet fully consumed: update flow control
			 * accounting and release it.
			 */
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	spin_unlock_bh(&vvs->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	/* Report partial progress if any bytes were copied before the fault */
	if (total)
		err = total;
	return err;
}

/* Stream receive entry point.  MSG_PEEK is not implemented. */
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return -EOPNOTSUPP;

	return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

/* Datagram sockets are not supported by this transport. */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

/* Number of bytes currently queued for reading on @vsk. */
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

/* Free tx space as seen by the peer's credit advertisement.
 * Caller holds vvs->tx_lock.
 */
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

/* Locked wrapper around virtio_transport_has_space(). */
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

/* Allocate and initialize the per-socket transport state (vsk->trans).
 * Buffer-size settings are inherited from the parent socket @psk (accept
 * path) or set to defaults.  Returns 0 or -ENOMEM.
 */
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->buf_size	= ptrans->buf_size;
		vvs->buf_size_min = ptrans->buf_size_min;
		vvs->buf_size_max = ptrans->buf_size_max;
		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	} else {
		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
	}

	vvs->buf_alloc = vvs->buf_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

/* Getters for the SO_VM_SOCKETS_BUFFER_* socket options. */
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);

/* Set the buffer size, clamping to VIRTIO_VSOCK_MAX_BUF_SIZE and pulling
 * min/max along so that min <= size <= max stays true.  Also updates
 * buf_alloc, which is advertised to the peer as our rx credit.
 */
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size_min)
		vvs->buf_size_min = val;
	if (val > vvs->buf_size_max)
		vvs->buf_size_max = val;
	vvs->buf_size = val;
	vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

/* Set the minimum buffer size, raising buf_size if needed. */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

/* Set the maximum buffer size, lowering buf_size if needed. */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);

/* poll() helper: report whether read would not block. */
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	if (vsock_stream_has_data(vsk))
		*data_ready_now = true;
	else
		*data_ready_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

/* poll() helper: report whether write would not block.
 * NOTE(review): *space_avail_now is left unmodified when free_space is
 * negative — callers appear to pre-initialize it; confirm at call sites.
 */
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

/* The notify_* callbacks below are no-ops for this transport; the vsock
 * core requires them but virtio needs no per-operation bookkeeping.
 */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

/* Receive high-water mark: the configured buffer size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

/* Streams on this transport are always considered active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

/* No per-address restrictions on stream connections. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

/* Datagrams are unsupported: binding always fails. */
int
virtio_transport_dgram_bind(struct vsock_sock *vsk,
			    struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

/* Datagrams are unsupported: never allowed. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

/* Initiate a connection: send an OP_REQUEST to the remote address. */
int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

/* Send an OP_SHUTDOWN with flags reflecting @mode (RCV/SEND directions). */
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

/* Datagrams are unsupported. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

/* Send stream payload: wrap @msg in an OP_RW packet (credit permitting). */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

/* Release per-socket transport state allocated in do_socket_init(). */
void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

/* Send an OP_RST on @vsk.  @pkt, if non-NULL, is the packet being
 * responded to; no RST is sent in reply to an RST (avoids RST storms).
 */
static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
	const struct virtio_transport *t;
	struct virtio_vsock_pkt *reply;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	/* Swap src/dst: the reply goes back to the packet's sender */
	reply = virtio_transport_alloc_pkt(&info, 0,
					   le64_to_cpu(pkt->hdr.dst_cid),
					   le32_to_cpu(pkt->hdr.dst_port),
					   le64_to_cpu(pkt->hdr.src_cid),
					   le32_to_cpu(pkt->hdr.src_port));
	if (!reply)
		return -ENOMEM;

	t = virtio_transport_get_ops();
	if (!t) {
		/* No transport registered (e.g. module unloading) */
		virtio_transport_free_pkt(reply);
		return -ENOTCONN;
	}

	return t->send_pkt(reply);
}

/* Block up to @timeout jiffies waiting for SOCK_DONE (peer close),
 * aborting early on a pending signal.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

/* Mark the connection fully shut down and, if a close timeout had been
 * scheduled, cancel it (when @cancel_timeout) and drop the reference the
 * timeout held.  Caller holds the socket lock.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}

/* Delayed-work handler: graceful close timed out, force-reset the
 * connection if the peer never acknowledged.
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

/* User context, vsk->sk is locked */
/* Attempt a graceful close.  Returns true when the socket can be removed
 * immediately, false when a close timeout was scheduled and removal is
 * deferred to virtio_transport_do_close()/close_timeout().
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	/* Hold a reference for the delayed work; released in do_close() or
	 * when the timeout fires.
	 */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

/* Socket release: initiate close for stream sockets, drop any packets
 * still queued for receive, and remove the socket unless close was
 * deferred to the timeout path.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt, *tmp;
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock(sk);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);

	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

/* Handle a packet on a TCP_SYN_SENT socket (connect in progress).
 * RESPONSE completes the connection; RST or anything unexpected tears
 * it down.  Caller holds the socket lock and frees @pkt.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}

/* Handle a packet on an established connection.  OP_RW packets are queued
 * on the rx list (ownership of @pkt transfers to the queue — note the early
 * return that skips the free at the bottom); all other ops consume @pkt.
 * Caller holds the socket lock.
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Credit state was already updated in space_update();
		 * just wake up writers.
		 */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = TCP_CLOSING;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}

/* Handle a packet while the socket is closing: only RST matters. */
static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

/* Reply to a connection REQUEST with an OP_RESPONSE from the new child
 * socket to the requester (the original packet's source).
 */
static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	/* Only a connection REQUEST is valid on a listening socket */
	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	/* Nested: the listener's lock is already held by our caller */
	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	/* Child's local address is the packet's destination, remote is the
	 * packet's source.
	 */
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

/* Refresh our view of the peer's credit (buf_alloc/fwd_cnt carried in every
 * header) and report whether tx space is now available.
 */
static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
/* Main receive entry point: look up the target socket (connected table
 * first, then bound), update credit, and dispatch on sk_state.  @pkt is
 * always consumed: freed here on every path except TCP_ESTABLISHED, where
 * recv_connected() takes ownership (it may queue OP_RW packets).
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	space_available = virtio_transport_space_update(sk, pkt);

	lock_sock(sk);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_ESTABLISHED:
		/* recv_connected() owns and frees (or queues) pkt */
		virtio_transport_recv_connected(sk, pkt);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		virtio_transport_free_pkt(pkt);
		break;
	}
	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

/* Free a packet and its payload buffer (counterpart to alloc_pkt()). */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");