// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

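/* Drop the psock's link to @link_raw (a map slot or hash element) and
 * release the reference the map held on the psock.
 */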
static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	sock_owned_by_me(sk);

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	bool skb_progs;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

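/* Attach a psock to @sk without taking references on the map's parser,
 * verdict, or msg programs. Used for sockets that cannot be a redirect
 * target (TCP listeners), which only need the psock and the sockmap
 * proto callbacks installed.
 */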
static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock)
			return -ENOMEM;
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}
	raw_spin_unlock_bh(&stab->lock);

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return __sock_map_lookup_elem(map, *(u32 *)key);
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

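/* Iterate map keys for BPF_MAP_GET_NEXT_KEY: a missing or out-of-range
 * key restarts iteration at index 0, otherwise the following index is
 * returned until the end of the map is reached.
 */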
static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk->sk_state != TCP_LISTEN;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

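/* Update a sockmap slot from syscall context. The value carries the
 * socket's fd (u32 or u64 depending on value_size); the socket must be
 * of a supported type and in an allowed state before it is inserted.
 */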
static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

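/* Walk one RCU-protected bucket chain and return the element whose
 * hash and full key match, or NULL if none is found.
 */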
static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

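/* Insert or replace a sockhash element under the bucket lock. The new
 * element is linked at the head of the chain so that concurrent lookups
 * see it before any element it replaces.
 */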
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

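/* Return the key following @key in iteration order: first the next
 * element in the same bucket chain, then the first element of each
 * subsequent bucket.
 */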
static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	return __sock_hash_lookup_elem(map, key);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

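/* Remove the map entry a link points at, dispatching on the map type
 * that created the link.
 */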
static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}