// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

static struct proto pvcalls_proto = {
        .name     = "PVCalls",
        .owner    = THIS_MODULE,
        .obj_size = sizeof(struct sock),
};

struct pvcalls_bedata {
        struct xen_pvcalls_front_ring ring;
        grant_ref_t ref;
        int irq;

        struct list_head socket_mappings;
        spinlock_t socket_lock;

        wait_queue_head_t inflight_req;
        struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {                       \
        atomic_inc(&pvcalls_refcount);          \
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {                        \
        atomic_dec(&pvcalls_refcount);          \
}

struct sock_mapping {
        bool active_socket;
        struct list_head list;
        struct socket *sock;
        atomic_t refcount;
        union {
                struct {
                        int irq;
                        grant_ref_t ref;
                        struct pvcalls_data_intf *ring;
                        struct pvcalls_data data;
                        struct mutex in_mutex;
                        struct mutex out_mutex;

                        wait_queue_head_t inflight_conn_req;
                } active;
                struct {
                /*
                 * Socket status, needs to be 64-bit aligned due to the
                 * test_and_* functions which have this requirement on arm64.
                 */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
                        uint8_t status __attribute__((aligned(8)));
                /*
                 * Internal state-machine flags.
                 * Only one accept operation can be inflight for a socket.
                 * Only one poll operation can be inflight for a given socket.
                 * flags needs to be 64-bit aligned due to the test_and_*
                 * functions which have this requirement on arm64.
                 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
                        uint8_t flags __attribute__((aligned(8)));
                        uint32_t inflight_req_id;
                        struct sock_mapping *accept_map;
                        wait_queue_head_t inflight_accept_req;
                } passive;
        };
};
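
/*
 * Rough lifecycle of a sock_mapping (an illustrative sketch, not an
 * exhaustive state machine):
 *
 *      pvcalls_front_socket()    map allocated, linked on socket_mappings
 *                                and stashed in sock->sk->sk_send_head
 *      pvcalls_front_connect()   active_socket = true; per-socket data
 *         or _accept()           ring, grant refs and event channel set up
 *      pvcalls_front_bind()      passive side; status moves
 *         then _listen()         UNINITALIZED -> BIND -> LISTEN
 *      pvcalls_front_release()   map unlinked and freed once map->refcount
 *                                shows no other caller is inside the driver
 */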

static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
        struct sock_mapping *map;

        if (!pvcalls_front_dev ||
            dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
                return ERR_PTR(-ENOTCONN);

        map = (struct sock_mapping *)sock->sk->sk_send_head;
        if (map == NULL)
                return ERR_PTR(-ENOTSOCK);

        pvcalls_enter();
        atomic_inc(&map->refcount);
        return map;
}

static inline void pvcalls_exit_sock(struct socket *sock)
{
        struct sock_mapping *map;

        map = (struct sock_mapping *)sock->sk->sk_send_head;
        atomic_dec(&map->refcount);
        pvcalls_exit();
}

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
        *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
        if (RING_FULL(&bedata->ring) ||
            bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
                return -EAGAIN;
        return 0;
}

static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->active.ring;
        RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        error = intf->out_error;
        if (error == -ENOTCONN)
                return false;
        if (error != 0)
                return true;

        cons = intf->out_cons;
        prod = intf->out_prod;
        return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->active.ring;
        RING_IDX cons, prod;
        int32_t error;

        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        return (error != 0 ||
                pvcalls_queued(prod, cons,
                               XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
        struct xenbus_device *dev = dev_id;
        struct pvcalls_bedata *bedata;
        struct xen_pvcalls_response *rsp;
        uint8_t *src, *dst;
        int req_id = 0, more = 0, done = 0;

        if (dev == NULL)
                return IRQ_HANDLED;

        pvcalls_enter();
        bedata = dev_get_drvdata(&dev->dev);
        if (bedata == NULL) {
                pvcalls_exit();
                return IRQ_HANDLED;
        }

again:
        while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
                rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

                req_id = rsp->req_id;
                if (rsp->cmd == PVCALLS_POLL) {
                        struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
                                                   rsp->u.poll.id;

                        clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                                  (void *)&map->passive.flags);
                        /*
                         * clear INFLIGHT, then set RET. It pairs with
                         * the checks at the beginning of
                         * pvcalls_front_poll_passive.
                         */
                        smp_wmb();
                        set_bit(PVCALLS_FLAG_POLL_RET,
                                (void *)&map->passive.flags);
                } else {
                        dst = (uint8_t *)&bedata->rsp[req_id] +
                              sizeof(rsp->req_id);
                        src = (uint8_t *)rsp + sizeof(rsp->req_id);
                        memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
                        /*
                         * First copy the rest of the data, then req_id. It is
                         * paired with the barrier when accessing bedata->rsp.
                         */
                        smp_wmb();
                        bedata->rsp[req_id].req_id = req_id;
                }

                done = 1;
                bedata->ring.rsp_cons++;
        }

        RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
        if (more)
                goto again;
        if (done)
                wake_up(&bedata->inflight_req);
        pvcalls_exit();
        return IRQ_HANDLED;
}
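
/*
 * The response path above pairs with a matching pattern on the request side:
 * the handler copies everything but req_id first, issues smp_wmb(), and only
 * then writes req_id, so a waiter can safely do (the sketch below mirrors
 * what the callers in this file actually do):
 *
 *      wait_event(bedata->inflight_req,
 *                 READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
 *      smp_rmb();      // read req_id, then the content
 *      ret = bedata->rsp[req_id].ret;
 *      bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
 */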

static void free_active_ring(struct sock_mapping *map);

static void pvcalls_front_destroy_active(struct pvcalls_bedata *bedata,
                                         struct sock_mapping *map)
{
        int i;

        unbind_from_irqhandler(map->active.irq, map);

        if (bedata) {
                spin_lock(&bedata->socket_lock);
                if (!list_empty(&map->list))
                        list_del_init(&map->list);
                spin_unlock(&bedata->socket_lock);
        }

        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
                gnttab_end_foreign_access(map->active.ring->ref[i], NULL);
        gnttab_end_foreign_access(map->active.ref, NULL);
        free_active_ring(map);
}

static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
                                   struct sock_mapping *map)
{
        pvcalls_front_destroy_active(bedata, map);

        kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
        struct sock_mapping *map = sock_map;

        if (map == NULL)
                return IRQ_HANDLED;

        wake_up_interruptible(&map->active.inflight_conn_req);

        return IRQ_HANDLED;
}

int pvcalls_front_socket(struct socket *sock)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        /*
         * PVCalls only supports domain AF_INET,
         * type SOCK_STREAM and protocol 0 sockets for now.
         *
         * Check socket type here, AF_INET and protocol checks are done
         * by the caller.
         */
        if (sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        pvcalls_enter();
        if (!pvcalls_front_dev) {
                pvcalls_exit();
                return -EACCES;
        }
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        map = kzalloc(sizeof(*map), GFP_KERNEL);
        if (map == NULL) {
                pvcalls_exit();
                return -ENOMEM;
        }

        spin_lock(&bedata->socket_lock);

        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                kfree(map);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit();
                return ret;
        }

        /*
         * sock->sk->sk_send_head is not used for ip sockets: reuse the
         * field to store a pointer to the struct sock_mapping
         * corresponding to the socket. This way, we can easily get the
         * struct sock_mapping from the struct socket.
         */
        sock->sk->sk_send_head = (void *)map;
        list_add_tail(&map->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_SOCKET;
        req->u.socket.id = (uintptr_t) map;
        req->u.socket.domain = AF_INET;
        req->u.socket.type = SOCK_STREAM;
        req->u.socket.protocol = IPPROTO_IP;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        pvcalls_exit();
        return ret;
}
EXPORT_SYMBOL_GPL(pvcalls_front_socket);

static void free_active_ring(struct sock_mapping *map)
{
        if (!map->active.ring)
                return;

        free_pages_exact(map->active.data.in,
                         PAGE_SIZE << map->active.ring->ring_order);
        free_page((unsigned long)map->active.ring);
}

static int alloc_active_ring(struct sock_mapping *map)
{
        void *bytes;

        map->active.ring = (struct pvcalls_data_intf *)
                get_zeroed_page(GFP_KERNEL);
        if (!map->active.ring)
                goto out;

        map->active.ring->ring_order = PVCALLS_RING_ORDER;
        bytes = alloc_pages_exact(PAGE_SIZE << PVCALLS_RING_ORDER,
                                  GFP_KERNEL | __GFP_ZERO);
        if (!bytes)
                goto out;

        map->active.data.in = bytes;
        map->active.data.out = bytes +
                XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

        return 0;

out:
        free_active_ring(map);
        return -ENOMEM;
}

static int create_active(struct sock_mapping *map, evtchn_port_t *evtchn)
{
        void *bytes;
        int ret, irq = -1, i;

        *evtchn = 0;
        init_waitqueue_head(&map->active.inflight_conn_req);

        bytes = map->active.data.in;
        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
                map->active.ring->ref[i] = gnttab_grant_foreign_access(
                        pvcalls_front_dev->otherend_id,
                        pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

        map->active.ref = gnttab_grant_foreign_access(
                pvcalls_front_dev->otherend_id,
                pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

        ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
        if (ret)
                goto out_error;
        irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
                                        0, "pvcalls-frontend", map);
        if (irq < 0) {
                ret = irq;
                goto out_error;
        }

        map->active.irq = irq;
        map->active_socket = true;
        mutex_init(&map->active.in_mutex);
        mutex_init(&map->active.out_mutex);

        return 0;

out_error:
        if (*evtchn > 0)
                xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
        return ret;
}

int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
                          int addr_len, int flags)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;
        evtchn_port_t evtchn;

        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
        ret = alloc_active_ring(map);
        if (ret < 0) {
                pvcalls_exit_sock(sock);
                return ret;
        }
        ret = create_active(map, &evtchn);
        if (ret < 0) {
                free_active_ring(map);
                pvcalls_exit_sock(sock);
                return ret;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_front_destroy_active(NULL, map);
                pvcalls_exit_sock(sock);
                return ret;
        }

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_CONNECT;
        req->u.connect.id = (uintptr_t)map;
        req->u.connect.len = addr_len;
        req->u.connect.flags = flags;
        req->u.connect.ref = map->active.ref;
        req->u.connect.evtchn = evtchn;
        memcpy(req->u.connect.addr, addr, sizeof(*addr));

        map->sock = sock;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);

        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        pvcalls_exit_sock(sock);
        return ret;
}
EXPORT_SYMBOL_GPL(pvcalls_front_connect);
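
/*
 * Shared memory of an active socket, as assembled by alloc_active_ring() and
 * create_active() (an illustrative sketch; PVCALLS_RING_ORDER is
 * XENBUS_MAX_RING_GRANT_ORDER):
 *
 *      map->active.ring        one zeroed page holding the pvcalls_data_intf,
 *                              granted to the backend via map->active.ref
 *      ring->ref[i]            1 << PVCALLS_RING_ORDER grants covering the
 *                              data area
 *      map->active.data.in     first XEN_FLEX_RING_SIZE(order) bytes
 *      map->active.data.out    second XEN_FLEX_RING_SIZE(order) bytes
 *
 * For example, with order 2 and 4K pages the data area is four pages: 8K of
 * receive buffer followed by 8K of transmit buffer.
 */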

static int __write_ring(struct pvcalls_data_intf *intf,
                        struct pvcalls_data *data,
                        struct iov_iter *msg_iter,
                        int len)
{
        RING_IDX cons, prod, size, masked_prod, masked_cons;
        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        error = intf->out_error;
        if (error < 0)
                return error;
        cons = intf->out_cons;
        prod = intf->out_prod;
        /* read indexes before continuing */
        virt_mb();

        size = pvcalls_queued(prod, cons, array_size);
        if (size > array_size)
                return -EINVAL;
        if (size == array_size)
                return 0;
        if (len > array_size - size)
                len = array_size - size;

        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        if (masked_prod < masked_cons) {
                len = copy_from_iter(data->out + masked_prod, len, msg_iter);
        } else {
                if (len > array_size - masked_prod) {
                        int ret = copy_from_iter(data->out + masked_prod,
                                       array_size - masked_prod, msg_iter);
                        if (ret != array_size - masked_prod) {
                                len = ret;
                                goto out;
                        }
                        len = ret + copy_from_iter(data->out, len - ret, msg_iter);
                } else {
                        len = copy_from_iter(data->out + masked_prod, len, msg_iter);
                }
        }
out:
        /* write to ring before updating pointer */
        virt_wmb();
        intf->out_prod += len;

        return len;
}

int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
                          size_t len)
{
        struct sock_mapping *map;
        int sent, tot_sent = 0;
        int count = 0, flags;

        flags = msg->msg_flags;
        if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        mutex_lock(&map->active.out_mutex);
        if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
                mutex_unlock(&map->active.out_mutex);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }
        if (len > INT_MAX)
                len = INT_MAX;

again:
        count++;
        sent = __write_ring(map->active.ring,
                            &map->active.data, &msg->msg_iter,
                            len);
        if (sent > 0) {
                len -= sent;
                tot_sent += sent;
                notify_remote_via_irq(map->active.irq);
        }
        if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
                goto again;
        if (sent < 0)
                tot_sent = sent;

        mutex_unlock(&map->active.out_mutex);
        pvcalls_exit_sock(sock);
        return tot_sent;
}
EXPORT_SYMBOL_GPL(pvcalls_front_sendmsg);
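
/*
 * Worked example of the index arithmetic in __write_ring() (numbers are
 * illustrative only): with array_size = 4096, out_cons = 4000 and
 * out_prod = 4100, pvcalls_queued() yields 100 bytes in flight, and
 * pvcalls_mask() gives masked_prod = 4 and masked_cons = 4000.  Since
 * masked_prod < masked_cons, the free region is contiguous and new data is
 * copied in a single chunk starting at data->out + 4.  In the opposite case,
 * a write that would run past the end of the buffer is split in two: a tail
 * copy up to array_size, then a second copy starting at offset 0.
 */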

static int __read_ring(struct pvcalls_data_intf *intf,
                       struct pvcalls_data *data,
                       struct iov_iter *msg_iter,
                       size_t len, int flags)
{
        RING_IDX cons, prod, size, masked_prod, masked_cons;
        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        /* get pointers before reading from the ring */
        virt_rmb();

        size = pvcalls_queued(prod, cons, array_size);
        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        if (size == 0)
                return error ?: size;

        if (len > size)
                len = size;

        if (masked_prod > masked_cons) {
                len = copy_to_iter(data->in + masked_cons, len, msg_iter);
        } else {
                if (len > (array_size - masked_cons)) {
                        int ret = copy_to_iter(data->in + masked_cons,
                                     array_size - masked_cons, msg_iter);
                        if (ret != array_size - masked_cons) {
                                len = ret;
                                goto out;
                        }
                        len = ret + copy_to_iter(data->in, len - ret, msg_iter);
                } else {
                        len = copy_to_iter(data->in + masked_cons, len, msg_iter);
                }
        }
out:
        /* read data from the ring before increasing the index */
        virt_mb();
        if (!(flags & MSG_PEEK))
                intf->in_cons += len;

        return len;
}

int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
                          int flags)
{
        int ret;
        struct sock_mapping *map;

        if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        mutex_lock(&map->active.in_mutex);
        if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
                len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

        while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
                wait_event_interruptible(map->active.inflight_conn_req,
                                         pvcalls_front_read_todo(map));
        }
        ret = __read_ring(map->active.ring, &map->active.data,
                          &msg->msg_iter, len, flags);

        if (ret > 0)
                notify_remote_via_irq(map->active.irq);
        if (ret == 0)
                ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
        if (ret == -ENOTCONN)
                ret = 0;

        mutex_unlock(&map->active.in_mutex);
        pvcalls_exit_sock(sock);
        return ret;
}
EXPORT_SYMBOL_GPL(pvcalls_front_recvmsg);
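
/*
 * Note that a single recvmsg call returns at most
 * XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER) bytes: a short read is not an
 * error, callers simply invoke recvmsg again for the rest.  A blocking read
 * sleeps on inflight_conn_req and is woken by pvcalls_front_conn_handler()
 * once the backend has queued data or posted an in_error (a backend
 * -ENOTCONN is reported here as an orderly end-of-stream, i.e. 0).
 */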

int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        map->sock = sock;
        req->cmd = PVCALLS_BIND;
        req->u.bind.id = (uintptr_t)map;
        memcpy(req->u.bind.addr, addr, sizeof(*addr));
        req->u.bind.len = addr_len;

        init_waitqueue_head(&map->passive.inflight_accept_req);

        map->active_socket = false;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        map->passive.status = PVCALLS_STATUS_BIND;
        pvcalls_exit_sock(sock);
        return 0;
}
EXPORT_SYMBOL_GPL(pvcalls_front_bind);

int pvcalls_front_listen(struct socket *sock, int backlog)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_BIND) {
                pvcalls_exit_sock(sock);
                return -EOPNOTSUPP;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_LISTEN;
        req->u.listen.id = (uintptr_t) map;
        req->u.listen.backlog = backlog;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        map->passive.status = PVCALLS_STATUS_LISTEN;
        pvcalls_exit_sock(sock);
        return ret;
}
EXPORT_SYMBOL_GPL(pvcalls_front_listen);
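
/*
 * Passive socket state machine enforced above (a sketch):
 *
 *      PVCALLS_STATUS_UNINITALIZED --bind()--> PVCALLS_STATUS_BIND
 *      PVCALLS_STATUS_BIND --listen()--> PVCALLS_STATUS_LISTEN
 *
 * pvcalls_front_listen() rejects anything but BIND with -EOPNOTSUPP, and
 * pvcalls_front_accept() below rejects anything but LISTEN with -EINVAL.
 */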

int pvcalls_front_accept(struct socket *sock, struct socket *newsock,
                         struct proto_accept_arg *arg)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct sock_mapping *map2 = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret, nonblock;
        evtchn_port_t evtchn;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_LISTEN) {
                pvcalls_exit_sock(sock);
                return -EINVAL;
        }

        nonblock = arg->flags & SOCK_NONBLOCK;
        /*
         * The backend only supports one inflight accept request; it
         * returns errors for any others.
         */
        if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                             (void *)&map->passive.flags)) {
                req_id = READ_ONCE(map->passive.inflight_req_id);
                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
                        map2 = map->passive.accept_map;
                        goto received;
                }
                if (nonblock) {
                        pvcalls_exit_sock(sock);
                        return -EAGAIN;
                }
                if (wait_event_interruptible(map->passive.inflight_accept_req,
                        !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                                          (void *)&map->passive.flags))) {
                        pvcalls_exit_sock(sock);
                        return -EINTR;
                }
        }

        map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
        if (map2 == NULL) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        ret = alloc_active_ring(map2);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                kfree(map2);
                pvcalls_exit_sock(sock);
                return ret;
        }
        ret = create_active(map2, &evtchn);
        if (ret < 0) {
                free_active_ring(map2);
                kfree(map2);
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_exit_sock(sock);
                return ret;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                pvcalls_front_free_map(bedata, map2);
                pvcalls_exit_sock(sock);
                return ret;
        }

        list_add_tail(&map2->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_ACCEPT;
        req->u.accept.id = (uintptr_t) map;
        req->u.accept.ref = map2->active.ref;
        req->u.accept.id_new = (uintptr_t) map2;
        req->u.accept.evtchn = evtchn;
        map->passive.accept_map = map2;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);
        /* We could check if we have received a response before returning. */
        if (nonblock) {
                WRITE_ONCE(map->passive.inflight_req_id, req_id);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }

        if (wait_event_interruptible(bedata->inflight_req,
                READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
                pvcalls_exit_sock(sock);
                return -EINTR;
        }
        /* read req_id, then the content */
        smp_rmb();

received:
        map2->sock = newsock;
        newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL, &pvcalls_proto, false);
        if (!newsock->sk) {
                bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
                map->passive.inflight_req_id = PVCALLS_INVALID_ID;
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_front_free_map(bedata, map2);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        newsock->sk->sk_send_head = (void *)map2;

        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        map->passive.inflight_req_id = PVCALLS_INVALID_ID;

        clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
        wake_up(&map->passive.inflight_accept_req);

        pvcalls_exit_sock(sock);
        return ret;
}
EXPORT_SYMBOL_GPL(pvcalls_front_accept);
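
/*
 * Sketch of the single-inflight accept protocol used above:
 * PVCALLS_FLAG_ACCEPT_INFLIGHT acts as the lock,
 * map->passive.inflight_req_id records a request a nonblocking caller left
 * behind, and map->passive.accept_map is the half-initialised mapping that
 * request refers to.  A later call resumes roughly as (pseudo-code):
 *
 *      if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, ...)) {
 *              req_id = READ_ONCE(map->passive.inflight_req_id);
 *              if (the response for req_id has already arrived)
 *                      goto received;  (reusing passive.accept_map)
 *              if (nonblock)
 *                      return -EAGAIN;
 *              wait until the bit can be taken, then start afresh;
 *      }
 */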

static __poll_t pvcalls_front_poll_passive(struct file *file,
                                           struct pvcalls_bedata *bedata,
                                           struct sock_mapping *map,
                                           poll_table *wait)
{
        int notify, req_id, ret;
        struct xen_pvcalls_request *req;

        if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                     (void *)&map->passive.flags)) {
                uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
                        return EPOLLIN | EPOLLRDNORM;

                poll_wait(file, &map->passive.inflight_accept_req, wait);
                return 0;
        }

        if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
                               (void *)&map->passive.flags))
                return EPOLLIN | EPOLLRDNORM;

        /*
         * First check RET, then INFLIGHT. No barriers necessary to
         * ensure execution ordering because of the conditional
         * instructions creating control dependencies.
         */

        if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                             (void *)&map->passive.flags)) {
                poll_wait(file, &bedata->inflight_req, wait);
                return 0;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_POLL;
        req->u.poll.id = (uintptr_t) map;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        poll_wait(file, &bedata->inflight_req, wait);
        return 0;
}
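
/*
 * Ordering contract between pvcalls_front_poll_passive() and the event
 * handler, shown side by side (a sketch of the code above):
 *
 *      event handler                   poll_passive
 *      -------------                   ------------
 *      clear_bit(POLL_INFLIGHT)        test_and_clear_bit(POLL_RET)
 *      smp_wmb()                       (control dependency)
 *      set_bit(POLL_RET)               test_and_set_bit(POLL_INFLIGHT)
 *
 * Checking RET before INFLIGHT guarantees that a completed poll is reported
 * before a new request is considered in flight.
 */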

static __poll_t pvcalls_front_poll_active(struct file *file,
                                          struct pvcalls_bedata *bedata,
                                          struct sock_mapping *map,
                                          poll_table *wait)
{
        __poll_t mask = 0;
        int32_t in_error, out_error;
        struct pvcalls_data_intf *intf = map->active.ring;

        out_error = intf->out_error;
        in_error = intf->in_error;

        poll_wait(file, &map->active.inflight_conn_req, wait);
        if (pvcalls_front_write_todo(map))
                mask |= EPOLLOUT | EPOLLWRNORM;
        if (pvcalls_front_read_todo(map))
                mask |= EPOLLIN | EPOLLRDNORM;
        if (in_error != 0 || out_error != 0)
                mask |= EPOLLERR;

        return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
                            poll_table *wait)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        __poll_t ret;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return EPOLLNVAL;
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->active_socket)
                ret = pvcalls_front_poll_active(file, bedata, map, wait);
        else
                ret = pvcalls_front_poll_passive(file, bedata, map, wait);
        pvcalls_exit_sock(sock);
        return ret;
}
EXPORT_SYMBOL_GPL(pvcalls_front_poll);

int pvcalls_front_release(struct socket *sock)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        int req_id, notify, ret;
        struct xen_pvcalls_request *req;

        if (sock->sk == NULL)
                return 0;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map)) {
                if (PTR_ERR(map) == -ENOTCONN)
                        return -EIO;
                else
                        return 0;
        }
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        sock->sk->sk_send_head = NULL;

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_RELEASE;
        req->u.release.id = (uintptr_t)map;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        if (map->active_socket) {
                /*
                 * Set in_error and wake up inflight_conn_req to force
                 * recvmsg waiters to exit.
                 */
                map->active.ring->in_error = -EBADF;
                wake_up_interruptible(&map->active.inflight_conn_req);

                /*
                 * We need to make sure that sendmsg/recvmsg on this socket
                 * have not started before we've cleared sk_send_head here.
                 * The easiest way to guarantee this is to check that no
                 * pvcalls operation (other than ours) is still in progress
                 * on this socket.
                 */
                while (atomic_read(&map->refcount) > 1)
                        cpu_relax();

                pvcalls_front_free_map(bedata, map);
        } else {
                wake_up(&bedata->inflight_req);
                wake_up(&map->passive.inflight_accept_req);

                while (atomic_read(&map->refcount) > 1)
                        cpu_relax();

                spin_lock(&bedata->socket_lock);
                list_del(&map->list);
                spin_unlock(&bedata->socket_lock);
                if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID &&
                    READ_ONCE(map->passive.inflight_req_id) != 0) {
                        pvcalls_front_free_map(bedata,
                                               map->passive.accept_map);
                }
                kfree(map);
        }
        WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

        pvcalls_exit();
        return 0;
}
EXPORT_SYMBOL_GPL(pvcalls_front_release);
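
/*
 * Teardown relies on the two reference counts: pvcalls_enter()/pvcalls_exit()
 * track callers anywhere in the driver, while map->refcount tracks callers
 * holding this particular mapping.  Once sk_send_head is cleared no new
 * caller can look the map up, so busy-waiting for
 *
 *      atomic_read(&map->refcount) > 1
 *
 * to become false guarantees that no sendmsg/recvmsg/poll is still using the
 * shared rings when they are freed.
 */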

static const struct xenbus_device_id pvcalls_front_ids[] = {
        { "pvcalls" },
        { "" }
};

static void pvcalls_front_remove(struct xenbus_device *dev)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL, *n;

        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
        dev_set_drvdata(&dev->dev, NULL);
        pvcalls_front_dev = NULL;
        if (bedata->irq >= 0)
                unbind_from_irqhandler(bedata->irq, dev);

        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
                map->sock->sk->sk_send_head = NULL;
                if (map->active_socket) {
                        map->active.ring->in_error = -EBADF;
                        wake_up_interruptible(&map->active.inflight_conn_req);
                }
        }

        smp_mb();
        while (atomic_read(&pvcalls_refcount) > 0)
                cpu_relax();
        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
                if (map->active_socket) {
                        /* No need to lock, refcount is 0 */
                        pvcalls_front_free_map(bedata, map);
                } else {
                        list_del(&map->list);
                        kfree(map);
                }
        }
        if (bedata->ref != -1)
                gnttab_end_foreign_access(bedata->ref, NULL);
        kfree(bedata->ring.sring);
        kfree(bedata);
        xenbus_switch_state(dev, XenbusStateClosed);
}

static int pvcalls_front_probe(struct xenbus_device *dev,
                               const struct xenbus_device_id *id)
{
        int ret = -ENOMEM, i;
        evtchn_port_t evtchn;
        unsigned int max_page_order, function_calls, len;
        char *versions;
        grant_ref_t gref_head = 0;
        struct xenbus_transaction xbt;
        struct pvcalls_bedata *bedata = NULL;
        struct xen_pvcalls_sring *sring;

        if (pvcalls_front_dev != NULL) {
                dev_err(&dev->dev, "only one PV Calls connection supported\n");
                return -EINVAL;
        }

        versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
        if (IS_ERR(versions))
                return PTR_ERR(versions);
        if (!len)
                return -EINVAL;
        if (strcmp(versions, "1")) {
                kfree(versions);
                return -EINVAL;
        }
        kfree(versions);
        max_page_order = xenbus_read_unsigned(dev->otherend,
                                              "max-page-order", 0);
        if (max_page_order < PVCALLS_RING_ORDER)
                return -ENODEV;
        function_calls = xenbus_read_unsigned(dev->otherend,
                                              "function-calls", 0);
        /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
        if (function_calls != 1)
                return -ENODEV;
        pr_info("%s max-page-order is %u\n", __func__, max_page_order);

        bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
        if (!bedata)
                return -ENOMEM;

        dev_set_drvdata(&dev->dev, bedata);
        pvcalls_front_dev = dev;
        init_waitqueue_head(&bedata->inflight_req);
        INIT_LIST_HEAD(&bedata->socket_mappings);
        spin_lock_init(&bedata->socket_lock);
        bedata->irq = -1;
        bedata->ref = -1;

        for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
                bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

        sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
                                                             __GFP_ZERO);
        if (!sring)
                goto error;
        SHARED_RING_INIT(sring);
        FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

        ret = xenbus_alloc_evtchn(dev, &evtchn);
        if (ret)
                goto error;

        bedata->irq = bind_evtchn_to_irqhandler(evtchn,
                                                pvcalls_front_event_handler,
                                                0, "pvcalls-frontend", dev);
        if (bedata->irq < 0) {
                ret = bedata->irq;
                goto error;
        }

        ret = gnttab_alloc_grant_references(1, &gref_head);
        if (ret < 0)
                goto error;
        ret = gnttab_claim_grant_reference(&gref_head);
        if (ret < 0)
                goto error;
        bedata->ref = ret;
        gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
                                        virt_to_gfn((void *)sring), 0);

again:
        ret = xenbus_transaction_start(&xbt);
        if (ret) {
                xenbus_dev_fatal(dev, ret, "starting transaction");
                goto error;
        }
        ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
        if (ret)
                goto error_xenbus;
        ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
        if (ret)
                goto error_xenbus;
        ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
                            evtchn);
        if (ret)
                goto error_xenbus;
        ret = xenbus_transaction_end(xbt, 0);
        if (ret) {
                if (ret == -EAGAIN)
                        goto again;
                xenbus_dev_fatal(dev, ret, "completing transaction");
                goto error;
        }
        xenbus_switch_state(dev, XenbusStateInitialised);

        return 0;

error_xenbus:
        xenbus_transaction_end(xbt, 1);
        xenbus_dev_fatal(dev, ret, "writing xenstore");
error:
        pvcalls_front_remove(dev);
        return ret;
}
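
/*
 * Sketch of the xenstore handshake performed by pvcalls_front_probe()
 * (values shown are examples):
 *
 *   read from the backend (dev->otherend):
 *      versions       = "1"
 *      max-page-order >= PVCALLS_RING_ORDER
 *      function-calls = "1"
 *
 *   written to the frontend (dev->nodename) in one transaction:
 *      version  = "1"
 *      ring-ref = grant reference of the command ring page
 *      port     = event channel port number
 *
 * The frontend then switches to XenbusStateInitialised; Connected is entered
 * from pvcalls_front_changed() below once the backend connects.
 */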

static void pvcalls_front_changed(struct xenbus_device *dev,
                                  enum xenbus_state backend_state)
{
        switch (backend_state) {
        case XenbusStateReconfiguring:
        case XenbusStateReconfigured:
        case XenbusStateInitialising:
        case XenbusStateInitialised:
        case XenbusStateUnknown:
                break;

        case XenbusStateInitWait:
                break;

        case XenbusStateConnected:
                xenbus_switch_state(dev, XenbusStateConnected);
                break;

        case XenbusStateClosed:
                if (dev->state == XenbusStateClosed)
                        break;
                /* Missed the backend's CLOSING state */
                fallthrough;
        case XenbusStateClosing:
                xenbus_frontend_closed(dev);
                break;
        }
}

static struct xenbus_driver pvcalls_front_driver = {
        .ids = pvcalls_front_ids,
        .probe = pvcalls_front_probe,
        .remove = pvcalls_front_remove,
        .otherend_changed = pvcalls_front_changed,
        .not_essential = true,
};

static int __init pvcalls_frontend_init(void)
{
        if (!xen_domain())
                return -ENODEV;

        pr_info("Initialising Xen pvcalls frontend driver\n");

        return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");