// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes) {
				if (sge->length)
					sk_msg_iter_var_prev(i);
				break;
			}
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

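/* Transfer data from @msg to @sk on behalf of a BPF verdict program,
 * applying at most @bytes when @bytes is non-zero. With @ingress set the
 * data is charged to @sk and queued on its psock ingress list via
 * bpf_tcp_ingress(), marking the socket readable; otherwise it is sent
 * out on @sk through tcp_bpf_push_locked() with uncharge == false.
 * Returns -EPIPE when @sk has no psock attached.
 */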
int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress,
			  struct sk_msg *msg, u32 bytes, int flags)
{
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int flags,
				  int *addr_len)
{
	struct sk_psock *psock;
	int copied;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);

	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}

static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

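/* Apply the cached or freshly-run BPF msg verdict to @msg: __SK_PASS
 * pushes the data out on @sk, __SK_REDIRECT hands it to psock->sk_redir
 * (dropping the socket lock around the transfer), and __SK_DROP frees the
 * data and returns -EACCES. If the program asked for corking, the message
 * is stashed in psock->cork until cork_bytes have been collected. *copied
 * is adjusted so the caller reports the right byte count to user space.
 */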
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg), redir_ingress;
	struct sock *sk_redir;
	u32 tosend, origsize, sent, delta = 0;
	u32 eval;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;
	eval = __SK_NONE;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		redir_ingress = psock->redir_ingress;
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);

		origsize = msg->sg.size;
		ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
					    msg, tosend, flags);
		sent = origsize - msg->sg.size;

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, tosend - sent);
			goto more_data;
		}
	}
	return ret;
}

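/* The sendmsg()/sendpage() replacements below are only installed when a
 * BPF msg_parser program is attached to the socket through a sockmap or
 * sockhash. Illustrative only (not part of this file): a minimal sk_msg
 * program that passes every message and so exercises the __SK_PASS branch
 * of tcp_bpf_send_verdict() could look like
 *
 *	SEC("sk_msg")
 *	int pass_all(struct sk_msg_md *msg)
 *	{
 *		return SK_PASS;
 *	}
 *
 * attached from user space with something along the lines of
 * bpf_prog_attach(prog_fd, sockmap_fd, BPF_SK_MSG_VERDICT, 0).
 */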
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

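/* sendpage() replacement: append the page fragment to the pending sk_msg
 * (or to the cork buffer when corking is active) and run the verdict
 * program on it via tcp_bpf_send_verdict(). Mirrors the cork handling in
 * tcp_bpf_sendmsg() above.
 */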
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].destroy = sock_map_destroy;
	prot[TCP_BPF_BASE].close = sock_map_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;

	prot[TCP_BPF_RX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser;
}

static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to ensure we stop using the sock_map
			 * unhash routine because the associated psock is being
			 * removed. So use the original unhash handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			sock_replace_proto(sk, psock->sk_proto);
		}
		return 0;
	}

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	sock_replace_proto(sk, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */