// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_stab;
	}

	stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
	err = bpf_map_precharge_memlock(stab->map.pages);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

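/* Remove the backlink that ties sk's psock to the given map entry
 * (link_raw) and drop the reference that entry held on the psock.
 */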
static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get(sk);
	if (psock) {
		if (!sk_has_psock(sk)) {
			ret = -EBUSY;
			goto out_progs;
		}
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk)
			sock_map_unref(sk, psk);
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

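/* Common sockmap slot deletion: clear the slot under stab->lock and
 * release the reference it held on the stored socket. A non-NULL
 * sk_test only allows the slot to be cleared while it still holds
 * that socket.
 */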
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		*psk = NULL;
	raw_spin_unlock_bh(&stab->lock);
	if (unlikely(!sk))
		return -EINVAL;
	sock_map_unref(sk, psk);
	return 0;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

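/* BPF helper backing bpf_sock_map_update(): called from sock_ops
 * programs to insert the current socket into a sockmap. Only TCP
 * sockets at the active/passive established callbacks are accepted,
 * see sock_map_sk_is_suitable() and sock_map_op_okay().
 */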
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

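/* Look up the socket stored under the given key in a sockhash map.
 * Caller must hold the RCU read lock; returns NULL when no matching
 * element exists.
 */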
static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

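/* Insert or update a sockhash element: attach the map's programs to the
 * socket's psock via sock_map_link(), then link the psock back to the
 * new element so the entry can be torn down when the socket goes away.
 */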
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

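/* Allocate a BPF_MAP_TYPE_SOCKHASH map: validate the attributes, round
 * the bucket count up to a power of two (so sock_hash_select_bucket()
 * can mask the hash) and initialize the per-bucket locks.
 */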
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			sock_map_unref(elem->sk, elem);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

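/* BPF helper backing bpf_msg_redirect_hash(): used by sk_msg programs
 * to pick a redirect target socket from a sockhash map. BPF_F_INGRESS
 * in flags selects the target socket's ingress path.
 */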
BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}