// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, u32 which);
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (!stab->sks) {
		kfree(stab);
		return ERR_PTR(-ENOMEM);
	}

	return &stab->map;
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

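/* A rough sketch of how userspace typically reaches the two attach/detach
 * entry points above, via the BPF_PROG_ATTACH/BPF_PROG_DETACH commands with
 * the sockmap fd as the attach target (illustrative only; the fd variable
 * names are placeholders, using the libbpf API):
 *
 *	bpf_prog_attach(verdict_prog_fd, sock_map_fd,
 *			BPF_SK_SKB_STREAM_VERDICT, 0);
 *	...
 *	bpf_prog_detach2(verdict_prog_fd, sock_map_fd,
 *			 BPF_SK_SKB_STREAM_VERDICT);
 */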

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	bool strp_stop = false, verdict_stop = false;
	struct sk_psock_link *link, *tmp;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->saved_data_ready && stab->progs.stream_parser)
				strp_stop = true;
			if (psock->saved_data_ready && stab->progs.stream_verdict)
				verdict_stop = true;
			if (psock->saved_data_ready && stab->progs.skb_verdict)
				verdict_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop || verdict_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		if (strp_stop)
			sk_psock_stop_strp(sk, psock);
		if (verdict_stop)
			sk_psock_stop_verdict(sk, psock);

		if (psock->psock_update_sk_prot)
			psock->psock_update_sk_prot(sk, psock, false);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	if (!sk->sk_prot->psock_update_sk_prot)
		return -EINVAL;
	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog *stream_verdict = NULL;
	struct bpf_prog *stream_parser = NULL;
	struct bpf_prog *skb_verdict = NULL;
	struct bpf_prog *msg_parser = NULL;
	struct sk_psock *psock;
	int ret;

	stream_verdict = READ_ONCE(progs->stream_verdict);
	if (stream_verdict) {
		stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
		if (IS_ERR(stream_verdict))
			return PTR_ERR(stream_verdict);
	}

	stream_parser = READ_ONCE(progs->stream_parser);
	if (stream_parser) {
		stream_parser = bpf_prog_inc_not_zero(stream_parser);
		if (IS_ERR(stream_parser)) {
			ret = PTR_ERR(stream_parser);
			goto out_put_stream_verdict;
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out_put_stream_parser;
		}
	}

	skb_verdict = READ_ONCE(progs->skb_verdict);
	if (skb_verdict) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict)) {
			ret = PTR_ERR(skb_verdict);
			goto out_put_msg_parser;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (stream_parser && READ_ONCE(psock->progs.stream_parser)) ||
		    (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
		    (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
		    (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
		    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (stream_parser)
		psock_set_prog(&psock->progs.stream_parser, stream_parser);
	if (stream_verdict)
		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
	if (skb_verdict)
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (stream_parser && stream_verdict && !psock->saved_data_ready) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret)
			goto out_unlock_drop;
		sk_psock_start_strp(sk, psock);
	} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
		sk_psock_start_verdict(sk, psock);
	} else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
		sk_psock_start_verdict(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_unlock_drop:
	write_unlock_bh(&sk->sk_callback_lock);
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (skb_verdict)
		bpf_prog_put(skb_verdict);
out_put_msg_parser:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out_put_stream_parser:
	if (stream_parser)
		bpf_prog_put(stream_parser);
out_put_stream_verdict:
	if (stream_verdict)
		bpf_prog_put(stream_verdict);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return sk->sk_state != TCP_LISTEN;
	else
		return sk->sk_state == TCP_ESTABLISHED;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return !!sk->sk_prot->psock_update_sk_prot;
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	return true;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);

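/* The syscall-side update below is what backs map updates from userspace,
 * where the map value is a socket file descriptor. A rough usage sketch
 * with libbpf (illustrative only; variable names are placeholders):
 *
 *	int sock_fd = accept(listen_fd, NULL, NULL);
 *	__u32 key = 0;
 *
 *	bpf_map_update_elem(sock_map_fd, &key, &sock_fd, BPF_ANY);
 */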
int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
			     u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	sockfd_put(sock);
	return ret;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	struct sock *sk = (struct sock *)value;
	int ret;

	if (unlikely(!sk || !sk_fullsock(sk)))
		return -EINVAL;

	if (!sock_map_sk_is_suitable(sk))
		return -EOPNOTSUPP;

	local_bh_disable();
	bh_lock_sock(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	bh_unlock_sock(sk);
	local_bh_enable();
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

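/* A minimal BPF-side consumer of the redirect helpers below might look
 * roughly like this (illustrative sketch; the map and program names are
 * placeholders):
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_stream_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 idx = 0;
 *
 *		return bpf_sk_redirect_map(skb, &sock_map, idx, 0);
 *	}
 */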
BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_map_seq_info {
	struct bpf_map *map;
	struct sock *sk;
	u32 index;
};

struct bpf_iter__sockmap {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(void *, key);
	__bpf_md_ptr(struct sock *, sk);
};

DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
		     struct bpf_map *map, void *key,
		     struct sock *sk)

static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
	if (unlikely(info->index >= info->map->max_entries))
		return NULL;

	info->sk = __sock_map_lookup_elem(info->map, info->index);

	/* can't return sk directly, since that might be NULL */
	return info;
}

static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_map_seq_stop */
	rcu_read_lock();
	return sock_map_seq_lookup_elem(info);
}

static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	++*pos;
	++info->index;

	return sock_map_seq_lookup_elem(info);
}

static int sock_map_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !v);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (v) {
		ctx.key = &info->index;
		ctx.sk = info->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_map_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_map_seq_show(seq, NULL);

	/* pairs with sock_map_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_map_seq_ops = {
	.start	= sock_map_seq_start,
	.next	= sock_map_seq_next,
	.stop	= sock_map_seq_stop,
	.show	= sock_map_seq_show,
};

static int sock_map_init_seq_private(void *priv_data,
				     struct bpf_iter_aux_info *aux)
{
	struct sock_map_seq_info *info = priv_data;

	info->map = aux->map;
	return 0;
}

static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
	.seq_ops		= &sock_map_seq_ops,
	.init_seq_private	= sock_map_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_map_seq_info),
};

static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_stab",
	.map_btf_id		= &sock_map_btf_id,
	.iter_seq_info		= &sock_map_iter_seq_info,
};

struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
				   GFP_ATOMIC | __GFP_NOWARN,
				   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

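/* Note: buckets_num is rounded up to a power of two in sock_hash_alloc()
 * above, which is what lets sock_hash_select_bucket() map a hash to a
 * bucket with a cheap mask (hash & (buckets_num - 1)) instead of a modulo.
 */
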
static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

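/* bpf_sock_hash_update() above is typically called from a sockops program
 * to add established sockets to a sockhash. A rough sketch (illustrative
 * only; the map name and key derivation are placeholders):
 *
 *	SEC("sockops")
 *	int prog_sockops(struct bpf_sock_ops *skops)
 *	{
 *		__u32 key = skops->local_port;
 *
 *		if (skops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
 *		    skops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB)
 *			bpf_sock_hash_update(skops, &sock_hash, &key, BPF_ANY);
 *		return 1;
 *	}
 */
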
BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_hash_seq_info {
	struct bpf_map *map;
	struct bpf_shtab *htab;
	u32 bucket_id;
};

static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
				     struct bpf_shtab_elem *prev_elem)
{
	const struct bpf_shtab *htab = info->htab;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;

	/* try to find next elem in the same bucket */
	if (prev_elem) {
		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;

		/* no more elements, continue in the next bucket */
		info->bucket_id++;
	}

	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
		bucket = &htab->buckets[info->bucket_id];
		node = rcu_dereference(hlist_first_rcu(&bucket->head));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;
	}

	return NULL;
}

static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_hash_seq_stop */
	rcu_read_lock();
	return sock_hash_seq_find_next(info, NULL);
}

static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	++*pos;
	return sock_hash_seq_find_next(info, v);
}

static int sock_hash_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_shtab_elem *elem = v;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !elem);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (elem) {
		ctx.key = elem->key;
		ctx.sk = elem->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_hash_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_hash_seq_show(seq, NULL);

	/* pairs with sock_hash_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_hash_seq_ops = {
	.start	= sock_hash_seq_start,
	.next	= sock_hash_seq_next,
	.stop	= sock_hash_seq_stop,
	.show	= sock_hash_seq_show,
};

static int sock_hash_init_seq_private(void *priv_data,
				      struct bpf_iter_aux_info *aux)
{
	struct sock_hash_seq_info *info = priv_data;

	info->map = aux->map;
	info->htab = container_of(aux->map, struct bpf_shtab, map);
	return 0;
}

static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
	.seq_ops		= &sock_hash_seq_ops,
	.init_seq_private	= sock_hash_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
};

static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_shtab",
	.map_btf_id		= &sock_hash_map_btf_id,
	.iter_seq_info		= &sock_hash_iter_seq_info,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}

static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->stream_parser;
		break;
#endif
	case BPF_SK_SKB_STREAM_VERDICT:
		if (progs->skb_verdict)
			return -EBUSY;
		pprog = &progs->stream_verdict;
		break;
	case BPF_SK_SKB_VERDICT:
		if (progs->stream_verdict)
			return -EBUSY;
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}
EXPORT_SYMBOL_GPL(sock_map_unhash);

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock_get(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	sk_psock_stop(psock, true);
	sk_psock_put(sk, psock);
	release_sock(sk);
	saved_close(sk, timeout);
}
EXPORT_SYMBOL_GPL(sock_map_close);

static int sock_map_iter_attach_target(struct bpf_prog *prog,
				       union bpf_iter_link_info *linfo,
				       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->key_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}

static struct bpf_iter_reg sock_map_iter_reg = {
	.target			= "sockmap",
	.attach_target		= sock_map_iter_attach_target,
	.detach_target		= sock_map_iter_detach_target,
	.show_fdinfo		= bpf_iter_map_show_fdinfo,
	.fill_link_info		= bpf_iter_map_fill_link_info,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__sockmap, key),
		  PTR_TO_RDONLY_BUF_OR_NULL },
		{ offsetof(struct bpf_iter__sockmap, sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
};

static int __init bpf_sockmap_iter_init(void)
{
	sock_map_iter_reg.ctx_arg_info[1].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);
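
/* An iterator program attached to the "sockmap" target registered above
 * receives a struct bpf_iter__sockmap context per element. A rough sketch
 * (illustrative only; program and map names are placeholders):
 *
 *	SEC("iter/sockmap")
 *	int dump_sockmap(struct bpf_iter__sockmap *ctx)
 *	{
 *		struct sock *sk = ctx->sk;
 *		__u32 *key = ctx->key;
 *
 *		if (!key || !sk)
 *			return 0;
 *		bpf_printk("sockmap key %u", *key);
 *		return 0;
 *	}
 */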