// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

/* Copy data queued on the psock ingress list into the user iterator.
 * MSG_PEEK leaves the queue untouched; otherwise the copied bytes are
 * uncharged and fully drained sk_msgs are dequeued and freed.
 */
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

/* Ingress redirect: transfer up to @apply_bytes of scatterlist entries
 * from @msg onto a fresh sk_msg, queue it on @psock's ingress list and
 * wake up any reader.
 */
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

/* Egress transmit of sk_msg pages. With a kTLS TX context installed, go
 * through kernel_sendpage_locked() with MSG_SENDPAGE_NOPOLICY so the BPF
 * policy is not re-run on data that already has a verdict.
 */
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

/* Entry point for msg verdict redirects: push @msg either onto the target
 * socket's ingress queue or out its egress path.
 */
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_STREAM_PARSER
/* Hooked up as ->stream_memory_read() so poll() also reports data that
 * sits on the psock ingress list rather than the receive queue.
 */
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

/* Sleep until data shows up on the psock ingress list or the socket
 * receive queue, or until the timeout expires.
 */
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	/* Handle MSG_ERRQUEUE before taking a psock reference so this early
	 * return cannot leak it.
	 */
	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		/* Drop our reference before handing off to tcp_recvmsg(). */
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

/* Act on the verdict program's decision for @msg: __SK_PASS pushes the
 * data out this socket, __SK_REDIRECT hands it to psock->sk_redir, and
 * __SK_DROP frees it and returns -EACCES. @copied is adjusted so the
 * caller reports the right byte count to user space.
 */
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = sock_map_unhash;
	prot[TCP_BPF_BASE].close = sock_map_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	if (!psock->sk_proto) {
		struct proto *ops = READ_ONCE(sk->sk_prot);

		if (tcp_bpf_assert_proto_ops(ops))
			return ERR_PTR(-EINVAL);

		tcp_bpf_check_v6_needs_rebuild(sk, ops);
	}

	return &tcp_bpf_prots[family][config];
}

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_STREAM_PARSER */