// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>

static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags,
					   addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}
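/* Ingress redirect path: charge @sk for the data, move the scatterlist
 * entries of @msg into a freshly allocated sk_msg, and queue it on the
 * psock ingress list so a subsequent recvmsg() on @sk can consume it.
 */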
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk->sk_data_ready(sk);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		ret = do_tcp_sendpages(sk, page, off, size, flags);
		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}
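/* Redirect egress data from one socket to another: if BPF_F_INGRESS was
 * set on @msg it is queued to the ingress side of @sk, otherwise it is
 * pushed out through do_tcp_sendpages().
 */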
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
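/* Run the msg verdict program on @msg and act on the result: __SK_PASS
 * pushes the data into the TCP stack, __SK_REDIRECT hands it to
 * tcp_bpf_sendmsg_redir(), and __SK_DROP frees it and returns -EACCES.
 * Corked data is parked in psock->cork until cork_bytes are available.
 */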
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		if (msg->sg.size < delta)
			delta -= msg->sg.size;
		else
			delta = 0;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}
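/* sendmsg hook installed for the TCP_BPF_TX configuration (a msg parser
 * program is attached): user data is copied into an sk_msg and run
 * through tcp_bpf_send_verdict() before it can reach the TCP stack,
 * another socket, or get dropped.
 */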
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
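/* Drop all psock state (cork buffer, ingress queue, map links) before the
 * socket is handed back to its saved unhash/close callbacks.
 */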
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}