// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
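
/* A minimal userspace sketch, for illustration only and not part of this
 * file, of how the checks in sock_map_alloc() and sock_map_update_elem()
 * are exercised through the raw bpf(2) syscall: a BPF_MAP_TYPE_SOCKMAP is
 * created with the mandatory 4-byte key and 4-byte value, and an already
 * established TCP socket fd is installed at a given index. The helper
 * names below are hypothetical. Creating the map requires CAP_NET_ADMIN.
 *
 *	#include <string.h>
 *	#include <unistd.h>
 *	#include <sys/syscall.h>
 *	#include <linux/bpf.h>
 *
 *	static int example_sockmap_create(__u32 max_entries)
 *	{
 *		union bpf_attr attr;
 *
 *		memset(&attr, 0, sizeof(attr));
 *		attr.map_type = BPF_MAP_TYPE_SOCKMAP;
 *		attr.key_size = 4;	// enforced by sock_map_alloc()
 *		attr.value_size = 4;	// enforced by sock_map_alloc()
 *		attr.max_entries = max_entries;
 *		return syscall(__NR_bpf, BPF_MAP_CREATE, &attr, sizeof(attr));
 *	}
 *
 *	static int example_sockmap_add(int map_fd, __u32 idx, int sock_fd)
 *	{
 *		union bpf_attr attr;
 *
 *		memset(&attr, 0, sizeof(attr));
 *		attr.map_fd = map_fd;
 *		attr.key = (__u64)(unsigned long)&idx;
 *		attr.value = (__u64)(unsigned long)&sock_fd;
 *		attr.flags = BPF_ANY;
 *		return syscall(__NR_bpf, BPF_MAP_UPDATE_ELEM, &attr, sizeof(attr));
 *	}
 *
 * sock_map_update_elem() below rejects sockets that are not established
 * TCP, so sock_fd must refer to a connected SOCK_STREAM/IPPROTO_TCP socket.
 */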

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk)
			sock_map_unref(sk, psk);
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		*psk = NULL;
	raw_spin_unlock_bh(&stab->lock);
	if (unlikely(!sk))
		return -EINVAL;
	sock_map_unref(sk, psk);
	return 0;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			sock_map_unref(elem->sk, elem);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}
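
/* A hedged BPF-side sketch, illustrative only and not compiled with this
 * file, of the program types that sock_map_prog_update() attaches: an
 * sk_skb stream verdict program that steers every parsed record to the
 * socket stored at index 0 of a SOCKMAP via bpf_sk_redirect_map(). The map
 * name, section names, and function name are assumptions chosen for the
 * example.
 *
 *	#include <linux/bpf.h>
 *	#include <bpf/bpf_helpers.h>
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SOCKMAP);
 *		__uint(max_entries, 2);
 *		__type(key, __u32);
 *		__type(value, __u32);
 *	} sock_map SEC(".maps");
 *
 *	SEC("sk_skb/stream_verdict")
 *	int example_stream_verdict(struct __sk_buff *skb)
 *	{
 *		// Returns SK_PASS with a redirect target set, or SK_DROP
 *		// when index 0 holds no socket.
 *		return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
 *	}
 *
 *	char _license[] SEC("license") = "GPL";
 *
 * Such a program would be attached from userspace with BPF_PROG_ATTACH and
 * attach_type BPF_SK_SKB_STREAM_VERDICT against the map fd, which lands in
 * sock_map_get_from_fd() and then sock_map_prog_update() above.
 */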