1 /* 2 * common code for virtio vsock 3 * 4 * Copyright (C) 2013-2015 Red Hat, Inc. 5 * Author: Asias He <asias@redhat.com> 6 * Stefan Hajnoczi <stefanha@redhat.com> 7 * 8 * This work is licensed under the terms of the GNU GPL, version 2. 9 */ 10 #include <linux/spinlock.h> 11 #include <linux/module.h> 12 #include <linux/sched/signal.h> 13 #include <linux/ctype.h> 14 #include <linux/list.h> 15 #include <linux/virtio.h> 16 #include <linux/virtio_ids.h> 17 #include <linux/virtio_config.h> 18 #include <linux/virtio_vsock.h> 19 #include <uapi/linux/vsockmon.h> 20 21 #include <net/sock.h> 22 #include <net/af_vsock.h> 23 24 #define CREATE_TRACE_POINTS 25 #include <trace/events/vsock_virtio_transport_common.h> 26 27 /* How long to wait for graceful shutdown of a connection */ 28 #define VSOCK_CLOSE_TIMEOUT (8 * HZ) 29 30 static const struct virtio_transport *virtio_transport_get_ops(void) 31 { 32 const struct vsock_transport *t = vsock_core_get_transport(); 33 34 return container_of(t, struct virtio_transport, transport); 35 } 36 37 static struct virtio_vsock_pkt * 38 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, 39 size_t len, 40 u32 src_cid, 41 u32 src_port, 42 u32 dst_cid, 43 u32 dst_port) 44 { 45 struct virtio_vsock_pkt *pkt; 46 int err; 47 48 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); 49 if (!pkt) 50 return NULL; 51 52 pkt->hdr.type = cpu_to_le16(info->type); 53 pkt->hdr.op = cpu_to_le16(info->op); 54 pkt->hdr.src_cid = cpu_to_le64(src_cid); 55 pkt->hdr.dst_cid = cpu_to_le64(dst_cid); 56 pkt->hdr.src_port = cpu_to_le32(src_port); 57 pkt->hdr.dst_port = cpu_to_le32(dst_port); 58 pkt->hdr.flags = cpu_to_le32(info->flags); 59 pkt->len = len; 60 pkt->hdr.len = cpu_to_le32(len); 61 pkt->reply = info->reply; 62 pkt->vsk = info->vsk; 63 64 if (info->msg && len > 0) { 65 pkt->buf = kmalloc(len, GFP_KERNEL); 66 if (!pkt->buf) 67 goto out_pkt; 68 err = memcpy_from_msg(pkt->buf, info->msg, len); 69 if (err) 70 goto out; 71 } 72 73 
trace_virtio_transport_alloc_pkt(src_cid, src_port, 74 dst_cid, dst_port, 75 len, 76 info->type, 77 info->op, 78 info->flags); 79 80 return pkt; 81 82 out: 83 kfree(pkt->buf); 84 out_pkt: 85 kfree(pkt); 86 return NULL; 87 } 88 89 /* Packet capture */ 90 static struct sk_buff *virtio_transport_build_skb(void *opaque) 91 { 92 struct virtio_vsock_pkt *pkt = opaque; 93 unsigned char *t_hdr, *payload; 94 struct af_vsockmon_hdr *hdr; 95 struct sk_buff *skb; 96 97 skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len, 98 GFP_ATOMIC); 99 if (!skb) 100 return NULL; 101 102 hdr = (struct af_vsockmon_hdr *)skb_put(skb, sizeof(*hdr)); 103 104 /* pkt->hdr is little-endian so no need to byteswap here */ 105 hdr->src_cid = pkt->hdr.src_cid; 106 hdr->src_port = pkt->hdr.src_port; 107 hdr->dst_cid = pkt->hdr.dst_cid; 108 hdr->dst_port = pkt->hdr.dst_port; 109 110 hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO); 111 hdr->len = cpu_to_le16(sizeof(pkt->hdr)); 112 memset(hdr->reserved, 0, sizeof(hdr->reserved)); 113 114 switch (le16_to_cpu(pkt->hdr.op)) { 115 case VIRTIO_VSOCK_OP_REQUEST: 116 case VIRTIO_VSOCK_OP_RESPONSE: 117 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT); 118 break; 119 case VIRTIO_VSOCK_OP_RST: 120 case VIRTIO_VSOCK_OP_SHUTDOWN: 121 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT); 122 break; 123 case VIRTIO_VSOCK_OP_RW: 124 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD); 125 break; 126 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 127 case VIRTIO_VSOCK_OP_CREDIT_REQUEST: 128 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); 129 break; 130 default: 131 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN); 132 break; 133 } 134 135 t_hdr = skb_put(skb, sizeof(pkt->hdr)); 136 memcpy(t_hdr, &pkt->hdr, sizeof(pkt->hdr)); 137 138 if (pkt->len) { 139 payload = skb_put(skb, pkt->len); 140 memcpy(payload, pkt->buf, pkt->len); 141 } 142 143 return skb; 144 } 145 146 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt) 147 { 148 vsock_deliver_tap(virtio_transport_build_skb, pkt); 149 } 
150 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); 151 152 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, 153 struct virtio_vsock_pkt_info *info) 154 { 155 u32 src_cid, src_port, dst_cid, dst_port; 156 struct virtio_vsock_sock *vvs; 157 struct virtio_vsock_pkt *pkt; 158 u32 pkt_len = info->pkt_len; 159 160 src_cid = vm_sockets_get_local_cid(); 161 src_port = vsk->local_addr.svm_port; 162 if (!info->remote_cid) { 163 dst_cid = vsk->remote_addr.svm_cid; 164 dst_port = vsk->remote_addr.svm_port; 165 } else { 166 dst_cid = info->remote_cid; 167 dst_port = info->remote_port; 168 } 169 170 vvs = vsk->trans; 171 172 /* we can send less than pkt_len bytes */ 173 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) 174 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; 175 176 /* virtio_transport_get_credit might return less than pkt_len credit */ 177 pkt_len = virtio_transport_get_credit(vvs, pkt_len); 178 179 /* Do not send zero length OP_RW pkt */ 180 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) 181 return pkt_len; 182 183 pkt = virtio_transport_alloc_pkt(info, pkt_len, 184 src_cid, src_port, 185 dst_cid, dst_port); 186 if (!pkt) { 187 virtio_transport_put_credit(vvs, pkt_len); 188 return -ENOMEM; 189 } 190 191 virtio_transport_inc_tx_pkt(vvs, pkt); 192 193 return virtio_transport_get_ops()->send_pkt(pkt); 194 } 195 196 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, 197 struct virtio_vsock_pkt *pkt) 198 { 199 vvs->rx_bytes += pkt->len; 200 } 201 202 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, 203 struct virtio_vsock_pkt *pkt) 204 { 205 vvs->rx_bytes -= pkt->len; 206 vvs->fwd_cnt += pkt->len; 207 } 208 209 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) 210 { 211 spin_lock_bh(&vvs->tx_lock); 212 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); 213 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); 214 spin_unlock_bh(&vvs->tx_lock); 215 } 216 
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); 217 218 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) 219 { 220 u32 ret; 221 222 spin_lock_bh(&vvs->tx_lock); 223 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 224 if (ret > credit) 225 ret = credit; 226 vvs->tx_cnt += ret; 227 spin_unlock_bh(&vvs->tx_lock); 228 229 return ret; 230 } 231 EXPORT_SYMBOL_GPL(virtio_transport_get_credit); 232 233 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) 234 { 235 spin_lock_bh(&vvs->tx_lock); 236 vvs->tx_cnt -= credit; 237 spin_unlock_bh(&vvs->tx_lock); 238 } 239 EXPORT_SYMBOL_GPL(virtio_transport_put_credit); 240 241 static int virtio_transport_send_credit_update(struct vsock_sock *vsk, 242 int type, 243 struct virtio_vsock_hdr *hdr) 244 { 245 struct virtio_vsock_pkt_info info = { 246 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, 247 .type = type, 248 .vsk = vsk, 249 }; 250 251 return virtio_transport_send_pkt_info(vsk, &info); 252 } 253 254 static ssize_t 255 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, 256 struct msghdr *msg, 257 size_t len) 258 { 259 struct virtio_vsock_sock *vvs = vsk->trans; 260 struct virtio_vsock_pkt *pkt; 261 size_t bytes, total = 0; 262 int err = -EFAULT; 263 264 spin_lock_bh(&vvs->rx_lock); 265 while (total < len && !list_empty(&vvs->rx_queue)) { 266 pkt = list_first_entry(&vvs->rx_queue, 267 struct virtio_vsock_pkt, list); 268 269 bytes = len - total; 270 if (bytes > pkt->len - pkt->off) 271 bytes = pkt->len - pkt->off; 272 273 /* sk_lock is held by caller so no one else can dequeue. 274 * Unlock rx_lock since memcpy_to_msg() may sleep. 
275 */ 276 spin_unlock_bh(&vvs->rx_lock); 277 278 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); 279 if (err) 280 goto out; 281 282 spin_lock_bh(&vvs->rx_lock); 283 284 total += bytes; 285 pkt->off += bytes; 286 if (pkt->off == pkt->len) { 287 virtio_transport_dec_rx_pkt(vvs, pkt); 288 list_del(&pkt->list); 289 virtio_transport_free_pkt(pkt); 290 } 291 } 292 spin_unlock_bh(&vvs->rx_lock); 293 294 /* Send a credit pkt to peer */ 295 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, 296 NULL); 297 298 return total; 299 300 out: 301 if (total) 302 err = total; 303 return err; 304 } 305 306 ssize_t 307 virtio_transport_stream_dequeue(struct vsock_sock *vsk, 308 struct msghdr *msg, 309 size_t len, int flags) 310 { 311 if (flags & MSG_PEEK) 312 return -EOPNOTSUPP; 313 314 return virtio_transport_stream_do_dequeue(vsk, msg, len); 315 } 316 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); 317 318 int 319 virtio_transport_dgram_dequeue(struct vsock_sock *vsk, 320 struct msghdr *msg, 321 size_t len, int flags) 322 { 323 return -EOPNOTSUPP; 324 } 325 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); 326 327 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) 328 { 329 struct virtio_vsock_sock *vvs = vsk->trans; 330 s64 bytes; 331 332 spin_lock_bh(&vvs->rx_lock); 333 bytes = vvs->rx_bytes; 334 spin_unlock_bh(&vvs->rx_lock); 335 336 return bytes; 337 } 338 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); 339 340 static s64 virtio_transport_has_space(struct vsock_sock *vsk) 341 { 342 struct virtio_vsock_sock *vvs = vsk->trans; 343 s64 bytes; 344 345 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 346 if (bytes < 0) 347 bytes = 0; 348 349 return bytes; 350 } 351 352 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) 353 { 354 struct virtio_vsock_sock *vvs = vsk->trans; 355 s64 bytes; 356 357 spin_lock_bh(&vvs->tx_lock); 358 bytes = virtio_transport_has_space(vsk); 359 spin_unlock_bh(&vvs->tx_lock); 360 361 
return bytes; 362 } 363 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); 364 365 int virtio_transport_do_socket_init(struct vsock_sock *vsk, 366 struct vsock_sock *psk) 367 { 368 struct virtio_vsock_sock *vvs; 369 370 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); 371 if (!vvs) 372 return -ENOMEM; 373 374 vsk->trans = vvs; 375 vvs->vsk = vsk; 376 if (psk) { 377 struct virtio_vsock_sock *ptrans = psk->trans; 378 379 vvs->buf_size = ptrans->buf_size; 380 vvs->buf_size_min = ptrans->buf_size_min; 381 vvs->buf_size_max = ptrans->buf_size_max; 382 vvs->peer_buf_alloc = ptrans->peer_buf_alloc; 383 } else { 384 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; 385 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; 386 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; 387 } 388 389 vvs->buf_alloc = vvs->buf_size; 390 391 spin_lock_init(&vvs->rx_lock); 392 spin_lock_init(&vvs->tx_lock); 393 INIT_LIST_HEAD(&vvs->rx_queue); 394 395 return 0; 396 } 397 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); 398 399 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) 400 { 401 struct virtio_vsock_sock *vvs = vsk->trans; 402 403 return vvs->buf_size; 404 } 405 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); 406 407 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) 408 { 409 struct virtio_vsock_sock *vvs = vsk->trans; 410 411 return vvs->buf_size_min; 412 } 413 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); 414 415 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) 416 { 417 struct virtio_vsock_sock *vvs = vsk->trans; 418 419 return vvs->buf_size_max; 420 } 421 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); 422 423 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) 424 { 425 struct virtio_vsock_sock *vvs = vsk->trans; 426 427 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 428 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 429 if (val < vvs->buf_size_min) 430 vvs->buf_size_min = val; 431 if (val > vvs->buf_size_max) 432 
vvs->buf_size_max = val; 433 vvs->buf_size = val; 434 vvs->buf_alloc = val; 435 } 436 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); 437 438 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) 439 { 440 struct virtio_vsock_sock *vvs = vsk->trans; 441 442 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 443 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 444 if (val > vvs->buf_size) 445 vvs->buf_size = val; 446 vvs->buf_size_min = val; 447 } 448 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); 449 450 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) 451 { 452 struct virtio_vsock_sock *vvs = vsk->trans; 453 454 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 455 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 456 if (val < vvs->buf_size) 457 vvs->buf_size = val; 458 vvs->buf_size_max = val; 459 } 460 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); 461 462 int 463 virtio_transport_notify_poll_in(struct vsock_sock *vsk, 464 size_t target, 465 bool *data_ready_now) 466 { 467 if (vsock_stream_has_data(vsk)) 468 *data_ready_now = true; 469 else 470 *data_ready_now = false; 471 472 return 0; 473 } 474 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); 475 476 int 477 virtio_transport_notify_poll_out(struct vsock_sock *vsk, 478 size_t target, 479 bool *space_avail_now) 480 { 481 s64 free_space; 482 483 free_space = vsock_stream_has_space(vsk); 484 if (free_space > 0) 485 *space_avail_now = true; 486 else if (free_space == 0) 487 *space_avail_now = false; 488 489 return 0; 490 } 491 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); 492 493 int virtio_transport_notify_recv_init(struct vsock_sock *vsk, 494 size_t target, struct vsock_transport_recv_notify_data *data) 495 { 496 return 0; 497 } 498 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); 499 500 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, 501 size_t target, struct vsock_transport_recv_notify_data *data) 502 { 503 return 0; 504 } 505 
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

/* The remaining recv/send notify hooks are no-ops that always return 0. */
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

/* Receive high-water mark: the socket's configured buffer size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

/* Stream connections are permitted to any cid/port. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

/* Datagram sockets are not supported by this transport. */
int
virtio_transport_dgram_bind(struct vsock_sock *vsk,
			    struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

/* Start an active open by sending OP_REQUEST to the socket's remote address. */
int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

/*
 * Send OP_SHUTDOWN, translating RCV_SHUTDOWN/SEND_SHUTDOWN in @mode into the
 * VIRTIO_VSOCK_SHUTDOWN_RCV/VIRTIO_VSOCK_SHUTDOWN_SEND header flags.
 */
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

/*
 * Queue up to @len bytes from @msg as one OP_RW packet. Returns the number of
 * bytes actually sent (possibly less than @len, limited by the packet size
 * cap and peer credit in virtio_transport_send_pkt_info()) or a negative
 * error.
 */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

/* Free the per-socket state allocated by virtio_transport_do_socket_init(). */
void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

/*
 * Send OP_RST on behalf of @vsk. @pkt is the packet being answered (NULL for
 * an unsolicited reset); the RST is marked as a reply when @pkt is given.
 * An incoming RST is never answered with another RST.
 */
static int
virtio_transport_reset(struct vsock_sock *vsk,
		       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 *
 * Builds the RST directly from @pkt's header with source and destination
 * swapped, without consulting any socket state.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	pkt = virtio_transport_alloc_pkt(&info, 0,
					 le64_to_cpu(pkt->hdr.dst_cid),
					 le32_to_cpu(pkt->hdr.dst_port),
					 le64_to_cpu(pkt->hdr.src_cid),
					 le32_to_cpu(pkt->hdr.src_port));
	if (!pkt)
		return -ENOMEM;

	return virtio_transport_get_ops()->send_pkt(pkt);
}

/*
 * Linger support: wait up to @timeout for SOCK_DONE to be set on @sk,
 * giving up early if a signal is pending. A zero @timeout returns at once.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

/*
 * Mark the connection finished: set SOCK_DONE, record a full peer shutdown
 * and move to SS_DISCONNECTING once no receive data remains.
 *
 * If a close timeout was scheduled by virtio_transport_close(), it is torn
 * down here: the socket is removed from the vsock tables and the reference
 * taken when scheduling is dropped. With @cancel_timeout, this only happens
 * if the delayed work is cancelled before it runs; without it (called from
 * the timeout work itself) it happens unconditionally.
 *
 * Callers in this file hold the socket lock.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = SS_DISCONNECTING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}

/*
 * Delayed work scheduled by virtio_transport_close(). If the peer has not
 * completed the close within VSOCK_CLOSE_TIMEOUT, reset the connection and
 * finish closing it locally.
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

/* User context, vsk->sk is locked
 *
 * Begin an orderly close of a stream socket. Returns true when the caller
 * may remove the socket from the vsock tables right away. Returns false when
 * a close timeout has been scheduled instead; that path holds an extra
 * socket reference, dropped by virtio_transport_do_close().
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	/* Nothing to negotiate with the peer unless we are (dis)connected */
	if (!(sk->sk_state == SS_CONNECTED ||
	      sk->sk_state == SS_DISCONNECTING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	/* Tell the peer we are going away, unless shutdown() already did */
	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	/* Peer has not finished yet: keep the socket alive until it sends
	 * RST or the timeout fires.
	 */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

/*
 * Transport release hook. Stream sockets go through
 * virtio_transport_close(); the socket is removed from the vsock tables here
 * unless a close timeout now owns that step.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock(sk);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

/*
 * Handle a packet for a socket in SS_CONNECTING. OP_RESPONSE completes the
 * connection, OP_INVALID is ignored, and OP_RST or any other op resets the
 * socket to SS_UNCONNECTED with sk_err set to ECONNRESET or EPROTO
 * respectively. Does not free @pkt.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = SS_CONNECTED;
		/* NOTE(review): sk->sk_socket is dereferenced without a NULL
		 * check. If a socket can be orphaned by close() while still
		 * reachable from the bound/connected tables, this could oops and
		 * re-insert a dying socket. Confirm against af_vsock's release
		 * ordering.
		 */
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}

/*
 * Handle a packet for a socket in SS_CONNECTED. Consumes @pkt: an OP_RW
 * packet is queued on rx_queue (ownership moves to the socket), every other
 * packet is freed before returning.
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		/* NOTE(review): nothing here checks rx_bytes against
		 * buf_alloc, so a peer that ignores the advertised credit can
		 * grow rx_queue without bound. Confirm whether the transport
		 * drivers bound this.
		 */
		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Peer credit was already refreshed by the caller */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = SS_DISCONNECTING;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}

/* In SS_DISCONNECTING only an RST matters: it completes the close. */
static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

/* Answer the OP_REQUEST in @pkt with OP_RESPONSE addressed to its sender. */
static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket
 *
 * Only OP_REQUEST is accepted. It creates a child socket that is immediately
 * SS_CONNECTED, inserted in the connected table and queued for accept(),
 * then answered with OP_RESPONSE. Other ops, a full accept queue or a failed
 * child allocation are answered with RST. Does not free @pkt.
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = SS_CONNECTED;

	/* The child's addresses are taken from the request header */
	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

/*
 * Record the peer's buf_alloc and fwd_cnt from @pkt's header (under tx_lock)
 * and report whether the peer now has any free buffer space.
 */
static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 *
 * Entry point for every packet received from the transport driver. Takes
 * ownership of @pkt: on every path it is either freed or, for OP_RW on a
 * connected socket, queued on that socket's rx_queue. Non-stream packets and
 * packets matching no connected or bound socket are answered with RST.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	/* Every packet carries the peer's credit info; refresh it first */
	space_available = virtio_transport_space_update(sk, pkt);

	lock_sock(sk);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	/* Dispatch on socket state; only recv_connected frees pkt itself */
	switch (sk->sk_state) {
	case VSOCK_SS_LISTEN:
		virtio_transport_recv_listen(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTING:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTED:
		virtio_transport_recv_connected(sk, pkt);
		break;
	case SS_DISCONNECTING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		virtio_transport_free_pkt(pkt);
		break;
	}
	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

/* Free a packet and its payload buffer (which may be NULL). */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");