// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * INET		An implementation of the TCP/IP protocol suite for the LINUX
 *		operating system.  INET is implemented using the BSD Socket
 *		interface as the means of communication with the user level.
 *
 *		Generic INET transport hashtables
 *
 * Authors:	Lotsa people, from code originally in tcp
 */

#include <linux/module.h>
#include <linux/random.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/wait.h>
#include <linux/vmalloc.h>
#include <linux/memblock.h>

#include <net/addrconf.h>
#include <net/inet_connection_sock.h>
#include <net/inet_hashtables.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/inet6_hashtables.h>
#endif
#include <net/secure_seq.h>
#include <net/ip.h>
#include <net/tcp.h>
#include <net/sock_reuseport.h>

static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
			const __u16 lport, const __be32 faddr,
			const __be16 fport)
{
	static u32 inet_ehash_secret __read_mostly;

	net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret));

	return __inet_ehashfn(laddr, lport, faddr, fport,
			      inet_ehash_secret + net_hash_mix(net));
}

/* This function handles inet_sock, but also timewait and request sockets
 * for IPv4/IPv6.
 */
static u32 sk_ehashfn(const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6 &&
	    !ipv6_addr_v4mapped(&sk->sk_v6_daddr))
		return inet6_ehashfn(sock_net(sk),
				     &sk->sk_v6_rcv_saddr, sk->sk_num,
				     &sk->sk_v6_daddr, sk->sk_dport);
#endif
	return inet_ehashfn(sock_net(sk),
			    sk->sk_rcv_saddr, sk->sk_num,
			    sk->sk_daddr, sk->sk_dport);
}
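
/* The ehash functions above mix a boot-time random secret
 * (net_get_random_once()) with net_hash_mix(net), so bucket placement
 * differs across boots and across network namespaces and is hard for a
 * remote peer to predict.
 */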

/*
 * Allocate and initialize a new local port bind bucket.
 * The bindhash mutex for snum's hash chain must be held here.
 */
struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
						 struct net *net,
						 struct inet_bind_hashbucket *head,
						 const unsigned short snum,
						 int l3mdev)
{
	struct inet_bind_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);

	if (tb) {
		write_pnet(&tb->ib_net, net);
		tb->l3mdev = l3mdev;
		tb->port = snum;
		tb->fastreuse = 0;
		tb->fastreuseport = 0;
		INIT_HLIST_HEAD(&tb->owners);
		hlist_add_head(&tb->node, &head->chain);
	}
	return tb;
}

/*
 * Caller must hold hashbucket lock for this tb with local BH disabled
 */
void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb)
{
	if (hlist_empty(&tb->owners)) {
		__hlist_del(&tb->node);
		kmem_cache_free(cachep, tb);
	}
}

bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net,
			    unsigned short port, int l3mdev)
{
	return net_eq(ib_net(tb), net) && tb->port == port &&
		tb->l3mdev == l3mdev;
}

static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb,
				   struct net *net,
				   struct inet_bind_hashbucket *head,
				   unsigned short port, int l3mdev,
				   const struct sock *sk)
{
	write_pnet(&tb->ib_net, net);
	tb->l3mdev = l3mdev;
	tb->port = port;
#if IS_ENABLED(CONFIG_IPV6)
	tb->family = sk->sk_family;
	if (sk->sk_family == AF_INET6)
		tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
	else
#endif
		tb->rcv_saddr = sk->sk_rcv_saddr;
	INIT_HLIST_HEAD(&tb->owners);
	INIT_HLIST_HEAD(&tb->deathrow);
	hlist_add_head(&tb->node, &head->chain);
}

struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
						   struct net *net,
						   struct inet_bind_hashbucket *head,
						   unsigned short port,
						   int l3mdev,
						   const struct sock *sk)
{
	struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);

	if (tb)
		inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk);

	return tb;
}

/* Caller must hold hashbucket lock for this tb with local BH disabled */
void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
{
	if (hlist_empty(&tb->owners) && hlist_empty(&tb->deathrow)) {
		__hlist_del(&tb->node);
		kmem_cache_free(cachep, tb);
	}
}

static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2,
					 const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family != tb2->family)
		return false;

	if (sk->sk_family == AF_INET6)
		return ipv6_addr_equal(&tb2->v6_rcv_saddr,
				       &sk->sk_v6_rcv_saddr);
#endif
	return tb2->rcv_saddr == sk->sk_rcv_saddr;
}

void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
		    struct inet_bind2_bucket *tb2, unsigned short port)
{
	inet_sk(sk)->inet_num = port;
	sk_add_bind_node(sk, &tb->owners);
	inet_csk(sk)->icsk_bind_hash = tb;
	sk_add_bind2_node(sk, &tb2->owners);
	inet_csk(sk)->icsk_bind2_hash = tb2;
}
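
/* Port binding uses two tables: bhash, keyed by (net, port), holds
 * inet_bind_bucket entries; bhash2, keyed by (net, port, local address),
 * holds inet_bind2_bucket entries.  A bound socket keeps a node on both
 * tb->owners and tb2->owners, and icsk_bind_hash/icsk_bind2_hash point
 * back at the buckets so the port can be released later.
 */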

/*
 * Get rid of any references to a local port held by the given sock.
 */
static void __inet_put_port(struct sock *sk)
{
	struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
	struct inet_bind_hashbucket *head, *head2;
	struct net *net = sock_net(sk);
	struct inet_bind_bucket *tb;
	int bhash;

	bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size);
	head = &hashinfo->bhash[bhash];
	head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num);

	spin_lock(&head->lock);
	tb = inet_csk(sk)->icsk_bind_hash;
	__sk_del_bind_node(sk);
	inet_csk(sk)->icsk_bind_hash = NULL;
	inet_sk(sk)->inet_num = 0;
	inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);

	spin_lock(&head2->lock);
	if (inet_csk(sk)->icsk_bind2_hash) {
		struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash;

		__sk_del_bind2_node(sk);
		inet_csk(sk)->icsk_bind2_hash = NULL;
		inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
	}
	spin_unlock(&head2->lock);

	spin_unlock(&head->lock);
}

void inet_put_port(struct sock *sk)
{
	local_bh_disable();
	__inet_put_port(sk);
	local_bh_enable();
}
EXPORT_SYMBOL(inet_put_port);

int __inet_inherit_port(const struct sock *sk, struct sock *child)
{
	struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk);
	unsigned short port = inet_sk(child)->inet_num;
	struct inet_bind_hashbucket *head, *head2;
	bool created_inet_bind_bucket = false;
	struct net *net = sock_net(sk);
	bool update_fastreuse = false;
	struct inet_bind2_bucket *tb2;
	struct inet_bind_bucket *tb;
	int bhash, l3mdev;

	bhash = inet_bhashfn(net, port, table->bhash_size);
	head = &table->bhash[bhash];
	head2 = inet_bhashfn_portaddr(table, child, net, port);

	spin_lock(&head->lock);
	spin_lock(&head2->lock);
	tb = inet_csk(sk)->icsk_bind_hash;
	tb2 = inet_csk(sk)->icsk_bind2_hash;
	if (unlikely(!tb || !tb2)) {
		spin_unlock(&head2->lock);
		spin_unlock(&head->lock);
		return -ENOENT;
	}
	if (tb->port != port) {
		l3mdev = inet_sk_bound_l3mdev(sk);

		/* NOTE: using tproxy and redirecting skbs to a proxy
		 * on a different listener port breaks the assumption
		 * that the listener socket's icsk_bind_hash is the same
		 * as that of the child socket. We have to look up or
		 * create a new bind bucket for the child here.
		 */
		inet_bind_bucket_for_each(tb, &head->chain) {
			if (inet_bind_bucket_match(tb, net, port, l3mdev))
				break;
		}
		if (!tb) {
			tb = inet_bind_bucket_create(table->bind_bucket_cachep,
						     net, head, port, l3mdev);
			if (!tb) {
				spin_unlock(&head2->lock);
				spin_unlock(&head->lock);
				return -ENOMEM;
			}
			created_inet_bind_bucket = true;
		}
		update_fastreuse = true;

		goto bhash2_find;
	} else if (!inet_bind2_bucket_addr_match(tb2, child)) {
		l3mdev = inet_sk_bound_l3mdev(sk);

bhash2_find:
		tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child);
		if (!tb2) {
			tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
						       net, head2, port,
						       l3mdev, child);
			if (!tb2)
				goto error;
		}
	}
	if (update_fastreuse)
		inet_csk_update_fastreuse(tb, child);
	inet_bind_hash(child, tb, tb2, port);
	spin_unlock(&head2->lock);
	spin_unlock(&head->lock);

	return 0;

error:
	if (created_inet_bind_bucket)
		inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
	spin_unlock(&head2->lock);
	spin_unlock(&head->lock);
	return -ENOMEM;
}
EXPORT_SYMBOL_GPL(__inet_inherit_port);

static struct inet_listen_hashbucket *
inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk)
{
	u32 hash;

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		hash = ipv6_portaddr_hash(sock_net(sk),
					  &sk->sk_v6_rcv_saddr,
					  inet_sk(sk)->inet_num);
	else
#endif
		hash = ipv4_portaddr_hash(sock_net(sk),
					  inet_sk(sk)->inet_rcv_saddr,
					  inet_sk(sk)->inet_num);
	return inet_lhash2_bucket(h, hash);
}

static inline int compute_score(struct sock *sk, struct net *net,
				const unsigned short hnum, const __be32 daddr,
				const int dif, const int sdif)
{
	int score = -1;

	if (net_eq(sock_net(sk), net) && sk->sk_num == hnum &&
	    !ipv6_only_sock(sk)) {
		if (sk->sk_rcv_saddr != daddr)
			return -1;

		if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
			return -1;
		score = sk->sk_bound_dev_if ? 2 : 1;

		if (sk->sk_family == PF_INET)
			score++;
		if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
			score++;
	}
	return score;
}

static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
					    struct sk_buff *skb, int doff,
					    __be32 saddr, __be16 sport,
					    __be32 daddr, unsigned short hnum)
{
	struct sock *reuse_sk = NULL;
	u32 phash;

	if (sk->sk_reuseport) {
		phash = inet_ehashfn(net, daddr, hnum, saddr, sport);
		reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
	}
	return reuse_sk;
}
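
/* Listener lookup order, as implemented below: a BPF sk_lookup program
 * (if attached) gets first pick, then the lhash2 bucket for the exact
 * (daddr, port) pair is scanned, and finally the (INADDR_ANY, port)
 * bucket.  Within a bucket the socket with the highest compute_score()
 * wins; for SO_REUSEPORT groups, reuseport_select_sock() picks one
 * member based on the flow hash instead.
 */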

/*
 * There are some nice properties to exploit here. The BSD API
 * does not allow a listening sock to specify the remote port nor the
 * remote address for the connection. So always assume those are both
 * wildcarded during the search since they can never be otherwise.
 */

/* called with rcu_read_lock() : No refcount taken on the socket */
static struct sock *inet_lhash2_lookup(struct net *net,
				struct inet_listen_hashbucket *ilb2,
				struct sk_buff *skb, int doff,
				const __be32 saddr, __be16 sport,
				const __be32 daddr, const unsigned short hnum,
				const int dif, const int sdif)
{
	struct sock *sk, *result = NULL;
	struct hlist_nulls_node *node;
	int score, hiscore = 0;

	sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
		score = compute_score(sk, net, hnum, daddr, dif, sdif);
		if (score > hiscore) {
			result = lookup_reuseport(net, sk, skb, doff,
						  saddr, sport, daddr, hnum);
			if (result)
				return result;

			result = sk;
			hiscore = score;
		}
	}

	return result;
}

static inline struct sock *inet_lookup_run_bpf(struct net *net,
					       struct inet_hashinfo *hashinfo,
					       struct sk_buff *skb, int doff,
					       __be32 saddr, __be16 sport,
					       __be32 daddr, u16 hnum, const int dif)
{
	struct sock *sk, *reuse_sk;
	bool no_reuseport;

	if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
		return NULL; /* only TCP is supported */

	no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport,
					    daddr, hnum, dif, &sk);
	if (no_reuseport || IS_ERR_OR_NULL(sk))
		return sk;

	reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
	if (reuse_sk)
		sk = reuse_sk;
	return sk;
}

struct sock *__inet_lookup_listener(struct net *net,
				    struct inet_hashinfo *hashinfo,
				    struct sk_buff *skb, int doff,
				    const __be32 saddr, __be16 sport,
				    const __be32 daddr, const unsigned short hnum,
				    const int dif, const int sdif)
{
	struct inet_listen_hashbucket *ilb2;
	struct sock *result = NULL;
	unsigned int hash2;

	/* Lookup redirect from BPF */
	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
		result = inet_lookup_run_bpf(net, hashinfo, skb, doff,
					     saddr, sport, daddr, hnum, dif);
		if (result)
			goto done;
	}

	hash2 = ipv4_portaddr_hash(net, daddr, hnum);
	ilb2 = inet_lhash2_bucket(hashinfo, hash2);

	result = inet_lhash2_lookup(net, ilb2, skb, doff,
				    saddr, sport, daddr, hnum,
				    dif, sdif);
	if (result)
		goto done;

	/* Lookup lhash2 with INADDR_ANY */
	hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
	ilb2 = inet_lhash2_bucket(hashinfo, hash2);

	result = inet_lhash2_lookup(net, ilb2, skb, doff,
				    saddr, sport, htonl(INADDR_ANY), hnum,
				    dif, sdif);
done:
	if (IS_ERR(result))
		return NULL;
	return result;
}
EXPORT_SYMBOL_GPL(__inet_lookup_listener);

/* All sockets share common refcount, but have different destructors */
void sock_gen_put(struct sock *sk)
{
	if (!refcount_dec_and_test(&sk->sk_refcnt))
		return;

	if (sk->sk_state == TCP_TIME_WAIT)
		inet_twsk_free(inet_twsk(sk));
	else if (sk->sk_state == TCP_NEW_SYN_RECV)
		reqsk_free(inet_reqsk(sk));
	else
		sk_free(sk);
}
EXPORT_SYMBOL_GPL(sock_gen_put);

void sock_edemux(struct sk_buff *skb)
{
	sock_gen_put(skb->sk);
}
EXPORT_SYMBOL(sock_edemux);
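
/* The established-table lookup below runs under rcu_read_lock() and takes
 * no lock on the bucket: it compares the precomputed sk_hash, then the
 * full 4-tuple, takes a reference with refcount_inc_not_zero() and
 * re-checks the match afterwards in case the socket was recycled
 * concurrently.  An unexpected nulls value at the end of the chain means
 * the entry migrated to another chain and the walk must restart.
 */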

struct sock *__inet_lookup_established(struct net *net,
				  struct inet_hashinfo *hashinfo,
				  const __be32 saddr, const __be16 sport,
				  const __be32 daddr, const u16 hnum,
				  const int dif, const int sdif)
{
	INET_ADDR_COOKIE(acookie, saddr, daddr);
	const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
	struct sock *sk;
	const struct hlist_nulls_node *node;
	/* Optimize here for direct hit, only listening connections can
	 * have wildcards anyway.
	 */
	unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport);
	unsigned int slot = hash & hashinfo->ehash_mask;
	struct inet_ehash_bucket *head = &hashinfo->ehash[slot];

begin:
	sk_nulls_for_each_rcu(sk, node, &head->chain) {
		if (sk->sk_hash != hash)
			continue;
		if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) {
			if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
				goto out;
			if (unlikely(!inet_match(net, sk, acookie,
						 ports, dif, sdif))) {
				sock_gen_put(sk);
				goto begin;
			}
			goto found;
		}
	}
	/*
	 * if the nulls value we got at the end of this lookup is
	 * not the expected one, we must restart lookup.
	 * We probably met an item that was moved to another chain.
	 */
	if (get_nulls_value(node) != slot)
		goto begin;
out:
	sk = NULL;
found:
	return sk;
}
EXPORT_SYMBOL_GPL(__inet_lookup_established);

/* called with local bh disabled */
static int __inet_check_established(struct inet_timewait_death_row *death_row,
				    struct sock *sk, __u16 lport,
				    struct inet_timewait_sock **twp)
{
	struct inet_hashinfo *hinfo = death_row->hashinfo;
	struct inet_sock *inet = inet_sk(sk);
	__be32 daddr = inet->inet_rcv_saddr;
	__be32 saddr = inet->inet_daddr;
	int dif = sk->sk_bound_dev_if;
	struct net *net = sock_net(sk);
	int sdif = l3mdev_master_ifindex_by_index(net, dif);
	INET_ADDR_COOKIE(acookie, saddr, daddr);
	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
	unsigned int hash = inet_ehashfn(net, daddr, lport,
					 saddr, inet->inet_dport);
	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
	spinlock_t *lock = inet_ehash_lockp(hinfo, hash);
	struct sock *sk2;
	const struct hlist_nulls_node *node;
	struct inet_timewait_sock *tw = NULL;

	spin_lock(lock);

	sk_nulls_for_each(sk2, node, &head->chain) {
		if (sk2->sk_hash != hash)
			continue;

		if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) {
			if (sk2->sk_state == TCP_TIME_WAIT) {
				tw = inet_twsk(sk2);
				if (twsk_unique(sk, sk2, twp))
					break;
			}
			goto not_unique;
		}
	}

	/* Must record num and sport now. Otherwise we will see
	 * a socket with a funny identity in the hash table.
	 */
	inet->inet_num = lport;
	inet->inet_sport = htons(lport);
	sk->sk_hash = hash;
	WARN_ON(!sk_unhashed(sk));
	__sk_nulls_add_node_rcu(sk, &head->chain);
	if (tw) {
		sk_nulls_del_node_init_rcu((struct sock *)tw);
		__NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED);
	}
	spin_unlock(lock);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);

	if (twp) {
		*twp = tw;
	} else if (tw) {
		/* Silly. Should hash-dance instead... */
		inet_twsk_deschedule_put(tw);
	}
	return 0;

not_unique:
	spin_unlock(lock);
	return -EADDRNOTAVAIL;
}

static u64 inet_sk_port_offset(const struct sock *sk)
{
	const struct inet_sock *inet = inet_sk(sk);

	return secure_ipv4_port_ephemeral(inet->inet_rcv_saddr,
					  inet->inet_daddr,
					  inet->inet_dport);
}

/* Searches for an existing socket in the ehash bucket list.
 * Returns true if found, false otherwise.
 */
static bool inet_ehash_lookup_by_sk(struct sock *sk,
				    struct hlist_nulls_head *list)
{
	const __portpair ports = INET_COMBINED_PORTS(sk->sk_dport, sk->sk_num);
	const int sdif = sk->sk_bound_dev_if;
	const int dif = sk->sk_bound_dev_if;
	const struct hlist_nulls_node *node;
	struct net *net = sock_net(sk);
	struct sock *esk;

	INET_ADDR_COOKIE(acookie, sk->sk_daddr, sk->sk_rcv_saddr);

	sk_nulls_for_each_rcu(esk, node, list) {
		if (esk->sk_hash != sk->sk_hash)
			continue;
		if (sk->sk_family == AF_INET) {
			if (unlikely(inet_match(net, esk, acookie,
						ports, dif, sdif))) {
				return true;
			}
		}
#if IS_ENABLED(CONFIG_IPV6)
		else if (sk->sk_family == AF_INET6) {
			if (unlikely(inet6_match(net, esk,
						 &sk->sk_v6_daddr,
						 &sk->sk_v6_rcv_saddr,
						 ports, dif, sdif))) {
				return true;
			}
		}
#endif
	}
	return false;
}

/* Insert a socket into ehash, and eventually remove another one
 * (the other one can be a SYN_RECV or TIMEWAIT socket).
 * If an identical socket is already present, sk is not inserted
 * and *found_dup_sk is set to true.
 */
bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk)
{
	struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
	struct inet_ehash_bucket *head;
	struct hlist_nulls_head *list;
	spinlock_t *lock;
	bool ret = true;

	WARN_ON_ONCE(!sk_unhashed(sk));

	sk->sk_hash = sk_ehashfn(sk);
	head = inet_ehash_bucket(hashinfo, sk->sk_hash);
	list = &head->chain;
	lock = inet_ehash_lockp(hashinfo, sk->sk_hash);

	spin_lock(lock);
	if (osk) {
		WARN_ON_ONCE(sk->sk_hash != osk->sk_hash);
		ret = sk_hashed(osk);
		if (ret) {
			/* Before deleting the node, we insert a new one to make
			 * sure that the look-up-sk process would not miss either
			 * of them and that at least one node would exist in ehash
			 * table all the time. Otherwise there's a tiny chance
			 * that lookup process could find nothing in ehash table.
			 */
			__sk_nulls_add_node_tail_rcu(sk, list);
			sk_nulls_del_node_init_rcu(osk);
		}
		goto unlock;
	}
	if (found_dup_sk) {
		*found_dup_sk = inet_ehash_lookup_by_sk(sk, list);
		if (*found_dup_sk)
			ret = false;
	}

	if (ret)
		__sk_nulls_add_node_rcu(sk, list);

unlock:
	spin_unlock(lock);

	return ret;
}

bool inet_ehash_nolisten(struct sock *sk, struct sock *osk, bool *found_dup_sk)
{
	bool ok = inet_ehash_insert(sk, osk, found_dup_sk);

	if (ok) {
		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
	} else {
		this_cpu_inc(*sk->sk_prot->orphan_count);
		inet_sk_set_state(sk, TCP_CLOSE);
		sock_set_flag(sk, SOCK_DEAD);
		inet_csk_destroy_sock(sk);
	}
	return ok;
}
EXPORT_SYMBOL_GPL(inet_ehash_nolisten);

static int inet_reuseport_add_sock(struct sock *sk,
				   struct inet_listen_hashbucket *ilb)
{
	struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash;
	const struct hlist_nulls_node *node;
	struct sock *sk2;
	kuid_t uid = sock_i_uid(sk);

	sk_nulls_for_each_rcu(sk2, node, &ilb->nulls_head) {
		if (sk2 != sk &&
		    sk2->sk_family == sk->sk_family &&
		    ipv6_only_sock(sk2) == ipv6_only_sock(sk) &&
		    sk2->sk_bound_dev_if == sk->sk_bound_dev_if &&
		    inet_csk(sk2)->icsk_bind_hash == tb &&
		    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
		    inet_rcv_saddr_equal(sk, sk2, false))
			return reuseport_add_sock(sk, sk2,
						  inet_rcv_saddr_any(sk));
	}

	return reuseport_alloc(sk, inet_rcv_saddr_any(sk));
}

int __inet_hash(struct sock *sk, struct sock *osk)
{
	struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
	struct inet_listen_hashbucket *ilb2;
	int err = 0;

	if (sk->sk_state != TCP_LISTEN) {
		local_bh_disable();
		inet_ehash_nolisten(sk, osk, NULL);
		local_bh_enable();
		return 0;
	}
	WARN_ON(!sk_unhashed(sk));
	ilb2 = inet_lhash2_bucket_sk(hashinfo, sk);

	spin_lock(&ilb2->lock);
	if (sk->sk_reuseport) {
		err = inet_reuseport_add_sock(sk, ilb2);
		if (err)
			goto unlock;
	}
	if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport &&
	    sk->sk_family == AF_INET6)
		__sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head);
	else
		__sk_nulls_add_node_rcu(sk, &ilb2->nulls_head);
	sock_set_flag(sk, SOCK_RCU_FREE);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
unlock:
	spin_unlock(&ilb2->lock);

	return err;
}
EXPORT_SYMBOL(__inet_hash);

int inet_hash(struct sock *sk)
{
	int err = 0;

	if (sk->sk_state != TCP_CLOSE)
		err = __inet_hash(sk, NULL);

	return err;
}
EXPORT_SYMBOL_GPL(inet_hash);

void inet_unhash(struct sock *sk)
{
	struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);

	if (sk_unhashed(sk))
		return;

	if (sk->sk_state == TCP_LISTEN) {
		struct inet_listen_hashbucket *ilb2;

		ilb2 = inet_lhash2_bucket_sk(hashinfo, sk);
		/* Don't disable bottom halves while acquiring the lock to
		 * avoid circular locking dependency on PREEMPT_RT.
		 */
		spin_lock(&ilb2->lock);
		if (sk_unhashed(sk)) {
			spin_unlock(&ilb2->lock);
			return;
		}

		if (rcu_access_pointer(sk->sk_reuseport_cb))
			reuseport_stop_listen_sock(sk);

		__sk_nulls_del_node_init_rcu(sk);
		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
		spin_unlock(&ilb2->lock);
	} else {
		spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash);

		spin_lock_bh(lock);
		if (sk_unhashed(sk)) {
			spin_unlock_bh(lock);
			return;
		}
		__sk_nulls_del_node_init_rcu(sk);
		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
		spin_unlock_bh(lock);
	}
}
EXPORT_SYMBOL_GPL(inet_unhash);

static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb,
				    const struct net *net, unsigned short port,
				    int l3mdev, const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family != tb->family)
		return false;

	if (sk->sk_family == AF_INET6)
		return net_eq(ib2_net(tb), net) && tb->port == port &&
			tb->l3mdev == l3mdev &&
			ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr);
	else
#endif
		return net_eq(ib2_net(tb), net) && tb->port == port &&
			tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
}

bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net,
				      unsigned short port, int l3mdev, const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family != tb->family) {
		if (sk->sk_family == AF_INET)
			return net_eq(ib2_net(tb), net) && tb->port == port &&
				tb->l3mdev == l3mdev &&
				ipv6_addr_any(&tb->v6_rcv_saddr);

		return false;
	}

	if (sk->sk_family == AF_INET6)
		return net_eq(ib2_net(tb), net) && tb->port == port &&
			tb->l3mdev == l3mdev &&
			ipv6_addr_any(&tb->v6_rcv_saddr);
	else
#endif
		return net_eq(ib2_net(tb), net) && tb->port == port &&
			tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
}

/* The socket's bhash2 hashbucket spinlock must be held when this is called */
struct inet_bind2_bucket *
inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net,
		       unsigned short port, int l3mdev, const struct sock *sk)
{
	struct inet_bind2_bucket *bhash2 = NULL;

	inet_bind_bucket_for_each(bhash2, &head->chain)
		if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
			break;

	return bhash2;
}

struct inet_bind_hashbucket *
inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port)
{
	struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
	u32 hash;

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		hash = ipv6_portaddr_hash(net, &in6addr_any, port);
	else
#endif
		hash = ipv4_portaddr_hash(net, 0, port);

	return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
}

static void inet_update_saddr(struct sock *sk, void *saddr, int family)
{
	if (family == AF_INET) {
		inet_sk(sk)->inet_saddr = *(__be32 *)saddr;
		sk_rcv_saddr_set(sk, inet_sk(sk)->inet_saddr);
	}
#if IS_ENABLED(CONFIG_IPV6)
	else {
		sk->sk_v6_rcv_saddr = *(struct in6_addr *)saddr;
	}
#endif
}
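
/* When the kernel rewrites a bound socket's source address (e.g. while
 * connecting), the socket must move to the bhash2 bucket matching the new
 * (port, address) pair.  The helper below preallocates a replacement
 * bucket so an allocation failure cannot leave bhash2 inconsistent.
 */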

static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, bool reset)
{
	struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
	struct inet_bind_hashbucket *head, *head2;
	struct inet_bind2_bucket *tb2, *new_tb2;
	int l3mdev = inet_sk_bound_l3mdev(sk);
	int port = inet_sk(sk)->inet_num;
	struct net *net = sock_net(sk);
	int bhash;

	if (!inet_csk(sk)->icsk_bind2_hash) {
		/* Not bind()ed before. */
		if (reset)
			inet_reset_saddr(sk);
		else
			inet_update_saddr(sk, saddr, family);

		return 0;
	}

	/* Allocate a bind2 bucket ahead of time to avoid permanently putting
	 * the bhash2 table in an inconsistent state if a new tb2 bucket
	 * allocation fails.
	 */
	new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC);
	if (!new_tb2) {
		if (reset) {
			/* The (INADDR_ANY, port) bucket might have already
			 * been freed, in which case we cannot fix up
			 * icsk_bind2_hash, so give up and unlink sk from
			 * bhash/bhash2 rather than leave bhash2 inconsistent.
			 */
			inet_put_port(sk);
			inet_reset_saddr(sk);
		}

		return -ENOMEM;
	}

	bhash = inet_bhashfn(net, port, hinfo->bhash_size);
	head = &hinfo->bhash[bhash];
	head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);

	/* If we change saddr locklessly, another thread
	 * iterating over bhash might see a corrupted address.
	 */
	spin_lock_bh(&head->lock);

	spin_lock(&head2->lock);
	__sk_del_bind2_node(sk);
	inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash);
	spin_unlock(&head2->lock);

	if (reset)
		inet_reset_saddr(sk);
	else
		inet_update_saddr(sk, saddr, family);

	head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);

	spin_lock(&head2->lock);
	tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
	if (!tb2) {
		tb2 = new_tb2;
		inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk);
	}
	sk_add_bind2_node(sk, &tb2->owners);
	inet_csk(sk)->icsk_bind2_hash = tb2;
	spin_unlock(&head2->lock);

	spin_unlock_bh(&head->lock);

	if (tb2 != new_tb2)
		kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2);

	return 0;
}

int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family)
{
	return __inet_bhash2_update_saddr(sk, saddr, family, false);
}
EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr);

void inet_bhash2_reset_saddr(struct sock *sk)
{
	if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK))
		__inet_bhash2_update_saddr(sk, NULL, 0, true);
}
EXPORT_SYMBOL_GPL(inet_bhash2_reset_saddr);

/* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm
 * Note that we use 32bit integers (vs RFC 'short integers')
 * because 2^16 is not a multiple of num_ephemeral and this
 * property might be used by a clever attacker.
 *
 * RFC claims using TABLE_LENGTH=10 buckets gives an improvement, though
 * attacks were since demonstrated, thus we use 65536 by default instead
 * to really give more isolation and privacy, at the expense of 256kB
 * of kernel memory.
 */
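
/* In __inet_hash_connect() below this works out to, roughly:
 *
 *	index  = port_offset & (INET_TABLE_PERTURB_SIZE - 1);
 *	offset = (table_perturb[index] + (port_offset >> 32)) % remaining;
 *	port   = low + offset;	(scanning even offsets first, then odd)
 *
 * where port_offset is a keyed hash of the destination (see
 * inet_sk_port_offset()), and table_perturb[index] is bumped after each
 * successful connect so consecutive connections to the same destination
 * do not pick predictable consecutive ports.
 */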
#define INET_TABLE_PERTURB_SIZE (1 << CONFIG_INET_TABLE_PERTURB_ORDER)
static u32 *table_perturb;

int __inet_hash_connect(struct inet_timewait_death_row *death_row,
		struct sock *sk, u64 port_offset,
		int (*check_established)(struct inet_timewait_death_row *,
			struct sock *, __u16, struct inet_timewait_sock **))
{
	struct inet_hashinfo *hinfo = death_row->hashinfo;
	struct inet_bind_hashbucket *head, *head2;
	struct inet_timewait_sock *tw = NULL;
	int port = inet_sk(sk)->inet_num;
	struct net *net = sock_net(sk);
	struct inet_bind2_bucket *tb2;
	struct inet_bind_bucket *tb;
	bool tb_created = false;
	u32 remaining, offset;
	int ret, i, low, high;
	int l3mdev;
	u32 index;

	if (port) {
		local_bh_disable();
		ret = check_established(death_row, sk, port, NULL);
		local_bh_enable();
		return ret;
	}

	l3mdev = inet_sk_bound_l3mdev(sk);

	inet_sk_get_local_port_range(sk, &low, &high);
	high++; /* [32768, 60999] -> [32768, 61000[ */
	remaining = high - low;
	if (likely(remaining > 1))
		remaining &= ~1U;

	get_random_sleepable_once(table_perturb,
				  INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb));
	index = port_offset & (INET_TABLE_PERTURB_SIZE - 1);

	offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32);
	offset %= remaining;

	/* In first pass we try ports of @low parity.
	 * inet_csk_get_port() does the opposite choice.
	 */
	offset &= ~1U;
other_parity_scan:
	port = low + offset;
	for (i = 0; i < remaining; i += 2, port += 2) {
		if (unlikely(port >= high))
			port -= remaining;
		if (inet_is_local_reserved_port(net, port))
			continue;
		head = &hinfo->bhash[inet_bhashfn(net, port,
						  hinfo->bhash_size)];
		spin_lock_bh(&head->lock);

		/* Does not bother with rcv_saddr checks, because
		 * the established check is already unique enough.
		 */
		inet_bind_bucket_for_each(tb, &head->chain) {
			if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
				if (tb->fastreuse >= 0 ||
				    tb->fastreuseport >= 0)
					goto next_port;
				WARN_ON(hlist_empty(&tb->owners));
				if (!check_established(death_row, sk,
						       port, &tw))
					goto ok;
				goto next_port;
			}
		}

		tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
					     net, head, port, l3mdev);
		if (!tb) {
			spin_unlock_bh(&head->lock);
			return -ENOMEM;
		}
		tb_created = true;
		tb->fastreuse = -1;
		tb->fastreuseport = -1;
		goto ok;
next_port:
		spin_unlock_bh(&head->lock);
		cond_resched();
	}

	offset++;
	if ((offset & 1) && remaining > 1)
		goto other_parity_scan;

	return -EADDRNOTAVAIL;

ok:
	/* Find the corresponding tb2 bucket since we need to
	 * add the socket to the bhash2 table as well.
	 */
	head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
	spin_lock(&head2->lock);

	tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
	if (!tb2) {
		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
					       head2, port, l3mdev, sk);
		if (!tb2)
			goto error;
	}

	/* Here we want to add a little bit of randomness to the next source
	 * port that will be chosen. We use a max() with a random here so that
	 * on low contention the randomness is maximal and on high contention
	 * it may be nonexistent.
	 */
	i = max_t(int, i, get_random_u32_below(8) * 2);
	WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);

	/* Head lock still held and bh's disabled */
	inet_bind_hash(sk, tb, tb2, port);

	if (sk_unhashed(sk)) {
		inet_sk(sk)->inet_sport = htons(port);
		inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
	}
	if (tw)
		inet_twsk_bind_unhash(tw, hinfo);

	spin_unlock(&head2->lock);
	spin_unlock(&head->lock);

	if (tw)
		inet_twsk_deschedule_put(tw);
	local_bh_enable();
	return 0;

error:
	spin_unlock(&head2->lock);
	if (tb_created)
		inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
	spin_unlock_bh(&head->lock);
	return -ENOMEM;
}

/*
 * Bind a port for a connect operation and hash it.
 */
int inet_hash_connect(struct inet_timewait_death_row *death_row,
		      struct sock *sk)
{
	u64 port_offset = 0;

	if (!inet_sk(sk)->inet_num)
		port_offset = inet_sk_port_offset(sk);
	return __inet_hash_connect(death_row, sk, port_offset,
				   __inet_check_established);
}
EXPORT_SYMBOL_GPL(inet_hash_connect);

static void init_hashinfo_lhash2(struct inet_hashinfo *h)
{
	int i;

	for (i = 0; i <= h->lhash2_mask; i++) {
		spin_lock_init(&h->lhash2[i].lock);
		INIT_HLIST_NULLS_HEAD(&h->lhash2[i].nulls_head,
				      i + LISTENING_NULLS_BASE);
	}
}

void __init inet_hashinfo2_init(struct inet_hashinfo *h, const char *name,
				unsigned long numentries, int scale,
				unsigned long low_limit,
				unsigned long high_limit)
{
	h->lhash2 = alloc_large_system_hash(name,
					    sizeof(*h->lhash2),
					    numentries,
					    scale,
					    0,
					    NULL,
					    &h->lhash2_mask,
					    low_limit,
					    high_limit);
	init_hashinfo_lhash2(h);

	/* this one is used for source ports of outgoing connections */
	table_perturb = alloc_large_system_hash("Table-perturb",
						sizeof(*table_perturb),
						INET_TABLE_PERTURB_SIZE,
						0, 0, NULL, NULL,
						INET_TABLE_PERTURB_SIZE,
						INET_TABLE_PERTURB_SIZE);
}

int inet_hashinfo2_init_mod(struct inet_hashinfo *h)
{
	h->lhash2 = kmalloc_array(INET_LHTABLE_SIZE, sizeof(*h->lhash2), GFP_KERNEL);
	if (!h->lhash2)
		return -ENOMEM;

	h->lhash2_mask = INET_LHTABLE_SIZE - 1;
	/* INET_LHTABLE_SIZE must be a power of 2 */
	BUG_ON(INET_LHTABLE_SIZE & h->lhash2_mask);

	init_hashinfo_lhash2(h);
	return 0;
}
EXPORT_SYMBOL_GPL(inet_hashinfo2_init_mod);

int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo)
{
	unsigned int locksz = sizeof(spinlock_t);
	unsigned int i, nblocks = 1;

	if (locksz != 0) {
		/* allocate 2 cache lines or at least one spinlock per cpu */
		nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U);
		nblocks = roundup_pow_of_two(nblocks * num_possible_cpus());

		/* no more locks than number of hash buckets */
		nblocks = min(nblocks, hashinfo->ehash_mask + 1);

		hashinfo->ehash_locks = kvmalloc_array(nblocks, locksz, GFP_KERNEL);
		if (!hashinfo->ehash_locks)
			return -ENOMEM;

		for (i = 0; i < nblocks; i++)
			spin_lock_init(&hashinfo->ehash_locks[i]);
	}
	hashinfo->ehash_locks_mask = nblocks - 1;
	return 0;
}
EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc);
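
/* A per-netns ehash duplicates the global hashinfo but replaces only the
 * ehash array and its lock array; the bind and listener tables in the
 * copy keep pointing at the initial netns structures.
 */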

struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo,
						 unsigned int ehash_entries)
{
	struct inet_hashinfo *new_hashinfo;
	int i;

	new_hashinfo = kmemdup(hashinfo, sizeof(*hashinfo), GFP_KERNEL);
	if (!new_hashinfo)
		goto err;

	new_hashinfo->ehash = vmalloc_huge(ehash_entries * sizeof(struct inet_ehash_bucket),
					   GFP_KERNEL_ACCOUNT);
	if (!new_hashinfo->ehash)
		goto free_hashinfo;

	new_hashinfo->ehash_mask = ehash_entries - 1;

	if (inet_ehash_locks_alloc(new_hashinfo))
		goto free_ehash;

	for (i = 0; i < ehash_entries; i++)
		INIT_HLIST_NULLS_HEAD(&new_hashinfo->ehash[i].chain, i);

	new_hashinfo->pernet = true;

	return new_hashinfo;

free_ehash:
	vfree(new_hashinfo->ehash);
free_hashinfo:
	kfree(new_hashinfo);
err:
	return NULL;
}
EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_alloc);

void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo)
{
	if (!hashinfo->pernet)
		return;

	inet_ehash_locks_free(hashinfo);
	vfree(hashinfo->ehash);
	kfree(hashinfo);
}
EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_free);