// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

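/* Attach the map's msg/skb programs to sk's psock, allocating the psock
 * and installing the BPF proto callbacks on first use. Fails with -EBUSY
 * if the socket already carries a conflicting program of the same kind.
 */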
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk)
			sock_map_unref(sk, psk);
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

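/* Clear one slot and drop the map's reference on the socket it held.
 * A non-NULL sk_test removes the entry only if it still points at that
 * exact socket, which is how the socket-side unlink path avoids racing
 * with concurrent map updates.
 */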
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

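/* Syscall-side update: the value is the file descriptor of a TCP stream
 * socket that must already be in ESTABLISHED state.
 */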
static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

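/* SOCKHASH lookup: same psock plumbing as the array map above, but the
 * element is found by hashing an arbitrary key of map->key_size bytes.
 */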
static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

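/* Insert or replace an element under the bucket lock. Unlike the array
 * map, the key is copied into the element itself and the element count
 * is tracked explicitly against max_entries.
 */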
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

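/* The bucket count is max_entries rounded up to a power of two; keys are
 * stored inline in each element, so key_size is capped at MAX_BPF_STACK.
 */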
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			sock_map_unref(elem->sk, elem);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

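/* As seen from the BPF side, a stream verdict program returns the result
 * of the redirect helper directly. A minimal usage sketch, not part of
 * this file; the map name and key choice are illustrative only:
 *
 *	SEC("sk_skb/stream_verdict")
 *	int stream_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 key = 0;
 *
 *		return bpf_sk_redirect_hash(skb, &sock_hash, &key, 0);
 *	}
 */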
BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}