// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>

static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

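/* recvmsg() replacement installed for sockets with a psock attached:
 * data queued on the psock ingress list by a BPF verdict is copied to
 * user space first; if only regular skbs are pending on
 * sk_receive_queue, the call falls back to tcp_recvmsg().
 */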
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk->sk_data_ready(sk);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		ret = do_tcp_sendpages(sk, page, off, size, flags);
		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

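/* Redirect target for SK_MSG programs: depending on sk_msg_to_ingress(),
 * the data is either queued on the destination socket's psock ingress
 * list via bpf_tcp_ingress() or pushed out its egress path through
 * do_tcp_sendpages().
 */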
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend;
	int ret;

more_data:
	if (psock->eval == __SK_NONE)
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= tosend;
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

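/* sendmsg() replacement installed when a msg_parser program is attached:
 * user data is copied into an sk_msg and passed to tcp_bpf_send_verdict(),
 * which applies the BPF verdict (pass, redirect or drop), honouring any
 * cork/apply byte state on the psock.
 */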
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

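/* Teardown helpers: drop any corked data, purge the psock ingress queue
 * and release all psock links before handing control back to the saved
 * unhash/close callbacks of the original protocol.
 */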
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

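/* Select the proto variant by address family and by whether a msg_parser
 * program is attached (TX hooks needed), then install it on the socket.
 */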
static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}