// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>

static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len)
{
	struct iov_iter *iter = &msg->msg_iter;
	int i, ret, copied = 0;

	while (copied != len) {
		struct scatterlist *sge;
		struct sk_msg *msg_rx;

		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			sge->offset += copy;
			sge->length -= copy;
			sk_mem_uncharge(sk, copy);
			if (!sge->length) {
				i++;
				if (i == MAX_SKB_FRAGS)
					i = 0;
				if (!msg_rx->skb)
					put_page(page);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk->sk_data_ready(sk);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		ret = do_tcp_sendpages(sk, page, off, size, flags);
		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend;
	int ret;

more_data:
	if (psock->eval == __SK_NONE)
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= tosend;
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}