// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}
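
/* Illustrative user-space sketch (assuming libbpf): a sockmap of this type
 * is typically created with a 4-byte index key, and the skb parser/verdict
 * programs are attached with the map fd as the attach target. "parser_fd",
 * "verdict_fd" and the sizes below are placeholders, not part of this file:
 *
 *	int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP, sizeof(__u32),
 *				    sizeof(__u64), 64, 0);
 *
 *	bpf_prog_attach(parser_fd, map_fd, BPF_SK_SKB_STREAM_PARSER, 0);
 *	bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *
 * The attach path lands in sock_map_get_from_fd() above, which resolves the
 * map fd and hands the program to sock_map_prog_update().
 */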

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	sock_owned_by_me(sk);

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	bool skb_progs;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock)
			return -ENOMEM;
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return __sock_map_lookup_elem(map, *(u32 *)key);
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}
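
/* Illustrative user-space sketch (assuming libbpf): when the map was created
 * with an 8-byte value, a syscall-side lookup returns the socket cookie
 * generated above rather than a kernel pointer:
 *
 *	__u32 idx = 0;
 *	__u64 cookie;
 *
 *	if (!bpf_map_lookup_elem(map_fd, &idx, &cookie))
 *		printf("slot %u holds socket cookie %llu\n", idx, cookie);
 *
 * With a 4-byte value the lookup fails with ENOSPC, as enforced by
 * sock_map_lookup_sys().
 */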

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}
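
/* Illustrative user-space sketch (assuming libbpf) of walking the map with
 * the get_next_key contract implemented above; passing NULL as the current
 * key starts the walk at index 0:
 *
 *	__u32 key, next;
 *	__u32 *cur = NULL;
 *
 *	while (!bpf_map_get_next_key(map_fd, cur, &next)) {
 *		key = next;
 *		cur = &key;
 *	}
 */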

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk->sk_state != TCP_LISTEN;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
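
/* Illustrative user-space sketch (assuming libbpf and an established TCP
 * socket "sock_fd"): the value passed to the update is the socket's file
 * descriptor, which sock_map_update_elem() above resolves via
 * sockfd_lookup():
 *
 *	__u32 idx = 0;
 *	__u64 val = sock_fd;
 *
 *	if (bpf_map_update_elem(map_fd, &idx, &val, BPF_ANY))
 *		perror("sockmap update");
 */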

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
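
/* Illustrative BPF program sketch of the redirect helpers registered above,
 * assuming a map named "sock_map" and programs attached as
 * BPF_SK_SKB_STREAM_VERDICT and BPF_SK_MSG_VERDICT respectively; the index
 * 0 is a placeholder:
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_skb_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 idx = 0;
 *
 *		return bpf_sk_redirect_map(skb, &sock_map, idx, BPF_F_INGRESS);
 *	}
 *
 *	SEC("sk_msg")
 *	int prog_msg_verdict(struct sk_msg_md *msg)
 *	{
 *		__u32 idx = 0;
 *
 *		return bpf_msg_redirect_map(msg, &sock_map, idx, 0);
 *	}
 */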

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}
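
/* Illustrative user-space sketch (assuming libbpf): removing a socket from
 * the hash only needs the key; the element and its psock reference are
 * dropped by sock_hash_delete_elem() above:
 *
 *	if (bpf_map_delete_elem(map_fd, &key))
 *		perror("sockhash delete");
 */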

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}
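
/* Illustrative user-space sketch (assuming libbpf): unlike the sockmap, the
 * hash accepts an arbitrary key layout as long as key_size is non-zero and
 * at most MAX_BPF_STACK, as checked above; a 4-tuple style key is a common
 * choice:
 *
 *	struct sock_key {
 *		__u32 sip, dip;
 *		__u16 sport, dport;
 *	};
 *
 *	int hash_fd = bpf_create_map(BPF_MAP_TYPE_SOCKHASH,
 *				     sizeof(struct sock_key), sizeof(__u64),
 *				     1024, 0);
 */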

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	return __sock_hash_lookup_elem(map, key);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
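
/* Illustrative BPF program sketch of the sockops-side helper registered
 * above, assuming a map named "sock_hash" and the struct sock_key layout
 * sketched earlier; established sockets add themselves to the hash:
 *
 *	SEC("sockops")
 *	int prog_sockops(struct bpf_sock_ops *skops)
 *	{
 *		struct sock_key key = {};
 *
 *		if (skops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
 *		    skops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB) {
 *			key.sip = skops->local_ip4;
 *			key.dip = skops->remote_ip4;
 *			key.sport = (__u16)skops->local_port;
 *			key.dport = (__u16)bpf_ntohl(skops->remote_port);
 *			bpf_sock_hash_update(skops, &sock_hash, &key, BPF_ANY);
 *		}
 *		return 0;
 *	}
 */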

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}