// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2020, Red Hat, Inc.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/inet.h>
#include <linux/kernel.h>
#include <net/tcp.h>
#include <net/netns/generic.h>
#include <net/mptcp.h>
#include <net/genetlink.h>
#include <uapi/linux/mptcp.h>

#include "protocol.h"
#include "mib.h"

/* forward declaration */
static struct genl_family mptcp_genl_family;

static int pm_nl_pernet_id;

/* A local endpoint configured via netlink; linked on the pernet
 * local_addr_list (RCU-protected). @lsk is the in-kernel listening
 * socket created when the endpoint carries an explicit port.
 */
struct mptcp_pm_addr_entry {
	struct list_head list;
	struct mptcp_addr_info addr;
	struct rcu_head rcu;
	struct socket *lsk;
};

/* Per-msk state for one in-flight ADD_ADDR announce, including its
 * retransmission timer and counter.
 */
struct mptcp_pm_add_entry {
	struct list_head list;
	struct mptcp_addr_info addr;
	struct timer_list add_timer;
	struct mptcp_sock *sock;
	u8 retrans_times;
};

#define MAX_ADDR_ID		255
#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)

/* Per-netns path-manager state: configured endpoints plus limits.
 * The *_max fields are read locklessly (READ_ONCE) and updated under
 * @lock with WRITE_ONCE.
 */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t lock;
	struct list_head local_addr_list;
	unsigned int addrs;
	unsigned int add_addr_signal_max;
	unsigned int add_addr_accept_max;
	unsigned int local_addr_max;
	unsigned int subflows_max;
	unsigned int next_id;
	unsigned long id_bitmap[BITMAP_SZ];
};

#define MPTCP_PM_ADDR_MAX	8
#define ADD_ADDR_RETRANS_MAX	3

static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk);

/* Compare two addresses, treating a v4 address and the corresponding
 * v4-mapped v6 address as equal; ports are compared only when
 * @use_port is true.
 */
static bool addresses_equal(const struct mptcp_addr_info *a,
			    struct mptcp_addr_info *b, bool use_port)
{
	bool addr_equals = false;

	if (a->family == b->family) {
		if (a->family == AF_INET)
			addr_equals = a->addr.s_addr == b->addr.s_addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
		else
			addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
	} else if (a->family == AF_INET) {
		if (ipv6_addr_v4mapped(&b->addr6))
			addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
	} else if (b->family == AF_INET) {
		if (ipv6_addr_v4mapped(&a->addr6))
			addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
#endif
	}

	if (!addr_equals)
		return false;
	if (!use_port)
		return true;

	return a->port == b->port;
}

/* Return true when @addr is the all-zeroes address (with port 0) of
 * its own family.
 */
static bool address_zero(const struct mptcp_addr_info *addr)
{
	struct mptcp_addr_info zero;

	memset(&zero, 0, sizeof(zero));
	zero.family = addr->family;

	return addresses_equal(addr, &zero, true);
}

/* Extract the local address and port of @skc into @addr. */
static void local_address(const struct sock_common *skc,
			  struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = htons(skc->skc_num);
	if (addr->family == AF_INET)
		addr->addr.s_addr = skc->skc_rcv_saddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6)
		addr->addr6 = skc->skc_v6_rcv_saddr;
#endif
}

/* Extract the remote address and port of @skc into @addr. */
static void remote_address(const struct sock_common *skc,
			   struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = skc->skc_dport;
	if (addr->family == AF_INET)
		addr->addr.s_addr = skc->skc_daddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6)
		addr->addr6 = skc->skc_v6_daddr;
#endif
}

/* Return true if any subflow on @list uses @saddr as its local
 * address; the port is matched only when @saddr carries a non-zero
 * port (it is passed as the use_port boolean).
 */
static bool lookup_subflow_by_saddr(const struct list_head *list,
				    struct mptcp_addr_info *saddr)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_addr_info cur;
	struct sock_common *skc;

	list_for_each_entry(subflow, list, node) {
		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);

		local_address(skc, &cur);
		if (addresses_equal(&cur, saddr, saddr->port))
			return true;
	}

	return false;
}

/* Pick the next configured endpoint flagged MPTCP_PM_ADDR_FLAG_SUBFLOW
 * that is compatible with the connection family and not already used
 * by an existing subflow. Traverses the endpoint list under RCU.
 */
static struct mptcp_pm_addr_entry *
select_local_address(const struct pm_nl_pernet *pernet,
		     struct mptcp_sock *msk)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	struct sock *sk = (struct sock *)msk;

	msk_owned_by_me(msk);

	rcu_read_lock();
	__mptcp_flush_join_list(msk);
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
			continue;

		/* families must match, unless one side is a v4-mapped v6 */
		if (entry->addr.family != sk->sk_family) {
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
			if ((entry->addr.family == AF_INET &&
			     !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
			    (sk->sk_family == AF_INET &&
			     !ipv6_addr_v4mapped(&entry->addr.addr6)))
#endif
				continue;
		}

		/* avoid any address already in use by subflows and
		 * pending join
		 */
		if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}

/* Return the @pos-th endpoint flagged MPTCP_PM_ADDR_FLAG_SIGNAL, or
 * NULL when fewer signal endpoints exist.
 */
static struct mptcp_pm_addr_entry *
select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	int i = 0;

	rcu_read_lock();
	/* do not keep any additional per socket state, just signal
	 * the address list in order.
	 * Note: removal from the local address list during the msk life-cycle
	 * can lead to additional addresses not being announced.
	 */
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
			continue;
		if (i++ == pos) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}

/* Pernet limit accessors: lockless readers paired with WRITE_ONCE()
 * updates performed on the netlink configuration path.
 */
unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->add_addr_signal_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);

unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->add_addr_accept_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);

unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->subflows_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);

static unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->local_addr_max);
}

/* Clear pm.work_pending once the announce budget is consumed and
 * either the local-address or the subflow budget is consumed too.
 */
static void check_work_pending(struct mptcp_sock *msk)
{
	if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
	    (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
	     msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
		WRITE_ONCE(msk->pm.work_pending, false);
}

/* Find the pending announce entry matching @addr (address and port).
 * Caller must hold the pm lock.
 */
static struct mptcp_pm_add_entry *
lookup_anno_list_by_saddr(struct mptcp_sock *msk,
			  struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	lockdep_assert_held(&msk->pm.lock);

	list_for_each_entry(entry,
			    &msk->pm.anno_list, list) {
		if (addresses_equal(&entry->addr, addr, true))
			return entry;
	}

	return NULL;
}

/* Return true when the local address of @sk is currently being
 * announced, i.e. it is present on the msk anno_list.
 */
bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
{
	struct mptcp_pm_add_entry *entry;
	struct mptcp_addr_info saddr;
	bool ret = false;

	local_address((struct sock_common *)sk, &saddr);

	spin_lock_bh(&msk->pm.lock);
	list_for_each_entry(entry, &msk->pm.anno_list, list) {
		if (addresses_equal(&entry->addr, &saddr, true)) {
			ret = true;
			goto out;
		}
	}

out:
	spin_unlock_bh(&msk->pm.lock);
	return ret;
}

/* ADD_ADDR retransmission timer: re-announce the address up to
 * ADD_ADDR_RETRANS_MAX times while no echo has been received.
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
	struct mptcp_sock *msk = entry->sock;
	struct sock *sk = (struct sock *)msk;

	pr_debug("msk=%p", msk);

	if (!msk)
		return;

	if (inet_sk_state_load(sk) == TCP_CLOSE)
		return;

	if (!entry->addr.id)
		return;

	/* a signal is still pending: check back again shortly */
	if (mptcp_pm_should_add_signal(msk)) {
		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
		goto out;
	}

	spin_lock_bh(&msk->pm.lock);

	/* re-check under the pm lock before retransmitting */
	if (!mptcp_pm_should_add_signal(msk)) {
		pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
		mptcp_pm_announce_addr(msk, &entry->addr, false, entry->addr.port);
		mptcp_pm_add_addr_send_ack(msk);
		entry->retrans_times++;
	}

	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
		sk_reset_timer(sk, timer,
			       jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

	spin_unlock_bh(&msk->pm.lock);

out:
	__sock_put(sk);
}

/* Stop the ADD_ADDR retransmit timer for @addr and return the (still
 * allocated) announce entry, or NULL when no announce is pending.
 * The retrans counter is saturated first, under the pm lock, so a
 * concurrent timer run will not re-arm itself.
 */
struct mptcp_pm_add_entry *
mptcp_pm_del_add_timer(struct mptcp_sock *msk,
		       struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;
	struct sock *sk = (struct sock *)msk;

	spin_lock_bh(&msk->pm.lock);
	entry = lookup_anno_list_by_saddr(msk, addr);
	if (entry)
		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
	spin_unlock_bh(&msk->pm.lock);

	/* stop the timer outside the pm lock, as the timer callback
	 * takes that same lock
	 */
	if (entry)
		sk_stop_timer_sync(sk, &entry->add_timer);

	return entry;
}

/* Queue an ADD_ADDR announce for @entry and arm its retransmission
 * timer. Returns false when the address is already being announced or
 * on allocation failure. Caller must hold the pm lock.
 */
static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
				     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_add_entry *add_entry = NULL;
	struct sock *sk = (struct sock *)msk;
	struct net *net = sock_net(sk);

	lockdep_assert_held(&msk->pm.lock);

	if (lookup_anno_list_by_saddr(msk, &entry->addr))
		return false;

	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
	if (!add_entry)
		return false;

	list_add(&add_entry->list, &msk->pm.anno_list);

	add_entry->addr = entry->addr;
	add_entry->sock = msk;
	add_entry->retrans_times = 0;

	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
	sk_reset_timer(sk, &add_entry->add_timer,
		       jiffies + mptcp_get_add_addr_timeout(net));

	return true;
}

/* Release every pending announce for @msk; the entries are first moved
 * to a private list so the timers can be stopped without holding the
 * pm lock the timer callback also takes.
 */
void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
{
	struct mptcp_pm_add_entry *entry, *tmp;
	struct sock *sk = (struct sock *)msk;
	LIST_HEAD(free_list);

	pr_debug("msk=%p", msk);

	spin_lock_bh(&msk->pm.lock);
	list_splice_init(&msk->pm.anno_list, &free_list);
	spin_unlock_bh(&msk->pm.lock);

	list_for_each_entry_safe(entry, tmp, &free_list, list) {
		sk_stop_timer_sync(sk, &entry->add_timer);
		kfree(entry);
	}
}

/* Core path-manager step: announce the next signal endpoint and/or
 * create a subflow from the next local endpoint, within the configured
 * limits. Called with the pm lock held; the lock is temporarily
 * dropped around the subflow connect.
 */
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	struct mptcp_pm_addr_entry *local;
	unsigned int add_addr_signal_max;
	unsigned int local_addr_max;
	struct pm_nl_pernet *pernet;
	unsigned int subflows_max;

	pernet = net_generic(sock_net(sk), pm_nl_pernet_id);

	add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
	local_addr_max = mptcp_pm_get_local_addr_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
		 msk->pm.local_addr_used, local_addr_max,
		 msk->pm.add_addr_signaled, add_addr_signal_max,
		 msk->pm.subflows, subflows_max);

	/* check first for announce */
	if (msk->pm.add_addr_signaled < add_addr_signal_max) {
		local = select_signal_address(pernet,
					      msk->pm.add_addr_signaled);

		if (local) {
			if (mptcp_pm_alloc_anno_list(msk, local)) {
				msk->pm.add_addr_signaled++;
				mptcp_pm_announce_addr(msk, &local->addr, false, local->addr.port);
				mptcp_pm_nl_add_addr_send_ack(msk);
			}
		} else {
			/* pick failed, avoid further attempts later.
			 * NOTE(review): this clamps local_addr_used with the
			 * signal budget while the guard above checks
			 * add_addr_signaled — confirm the intended field.
			 */
			msk->pm.local_addr_used = add_addr_signal_max;
		}

		check_work_pending(msk);
	}

	/* check if should create a new subflow */
	if (msk->pm.local_addr_used < local_addr_max &&
	    msk->pm.subflows < subflows_max) {
		local = select_local_address(pernet, msk);
		if (local) {
			struct mptcp_addr_info remote = { 0 };

			msk->pm.local_addr_used++;
			msk->pm.subflows++;
			check_work_pending(msk);
			remote_address((struct sock_common *)sk, &remote);
			/* drop the pm lock around the connect attempt */
			spin_unlock_bh(&msk->pm.lock);
			__mptcp_subflow_connect(sk, &local->addr, &remote);
			spin_lock_bh(&msk->pm.lock);
			return;
		}

		/* lookup failed, avoid further attempts later */
		msk->pm.local_addr_used = local_addr_max;
		check_work_pending(msk);
	}
}

/* Event hooks: both simply (re-)run the main path-manager step. */
static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}

static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}

/* Handle a peer ADD_ADDR: account for it, connect a new subflow
 * towards the announced address and echo the ADD_ADDR back.
 * Called with the pm lock held; dropped around the connect.
 */
static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	unsigned int add_addr_accept_max;
	struct mptcp_addr_info remote;
	struct mptcp_addr_info local;
	unsigned int subflows_max;
	bool use_port = false;

	add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("accepted %d:%d remote family %d",
		 msk->pm.add_addr_accepted, add_addr_accept_max,
		 msk->pm.remote.family);
	msk->pm.add_addr_accepted++;
	msk->pm.subflows++;
	/* stop accepting new announces once either budget is exhausted */
	if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
	    msk->pm.subflows >= subflows_max)
		WRITE_ONCE(msk->pm.accept_addr, false);

	/* connect to the specified remote address, using whatever
	 * local address the routing configuration will pick.
	 */
	remote = msk->pm.remote;
	if (!remote.port)
		remote.port = sk->sk_dport;
	else
		use_port = true;
	memset(&local, 0, sizeof(local));
	local.family = remote.family;

	spin_unlock_bh(&msk->pm.lock);
	__mptcp_subflow_connect(sk, &local, &remote);
	spin_lock_bh(&msk->pm.lock);

	/* echo the announce back to the peer */
	mptcp_pm_announce_addr(msk, &remote, true, use_port);
	mptcp_pm_nl_add_addr_send_ack(msk);
}

/* Send a bare ACK on the first subflow so a pending ADD_ADDR signal is
 * transmitted, then clear the consumed ipv6/port signal bits.
 * Called with the pm lock held; dropped around the ack transmission.
 */
static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	msk_owned_by_me(msk);
	lockdep_assert_held(&msk->pm.lock);

	if (!mptcp_pm_should_add_signal(msk))
		return;

	__mptcp_flush_join_list(msk);
	subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
	if (subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		u8 add_addr;

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for add_addr%s%s",
			 mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
			 mptcp_pm_should_add_signal_port(msk) ?
" [port]" : ""); 523 524 lock_sock(ssk); 525 tcp_send_ack(ssk); 526 release_sock(ssk); 527 spin_lock_bh(&msk->pm.lock); 528 529 add_addr = READ_ONCE(msk->pm.addr_signal); 530 if (mptcp_pm_should_add_signal_ipv6(msk)) 531 add_addr &= ~BIT(MPTCP_ADD_ADDR_IPV6); 532 if (mptcp_pm_should_add_signal_port(msk)) 533 add_addr &= ~BIT(MPTCP_ADD_ADDR_PORT); 534 WRITE_ONCE(msk->pm.addr_signal, add_addr); 535 } 536 } 537 538 int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, 539 struct mptcp_addr_info *addr, 540 u8 bkup) 541 { 542 struct mptcp_subflow_context *subflow; 543 544 pr_debug("bkup=%d", bkup); 545 546 mptcp_for_each_subflow(msk, subflow) { 547 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 548 struct sock *sk = (struct sock *)msk; 549 struct mptcp_addr_info local; 550 551 local_address((struct sock_common *)ssk, &local); 552 if (!addresses_equal(&local, addr, addr->port)) 553 continue; 554 555 subflow->backup = bkup; 556 subflow->send_mp_prio = 1; 557 subflow->request_bkup = bkup; 558 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX); 559 560 spin_unlock_bh(&msk->pm.lock); 561 pr_debug("send ack for mp_prio"); 562 lock_sock(ssk); 563 tcp_send_ack(ssk); 564 release_sock(ssk); 565 spin_lock_bh(&msk->pm.lock); 566 567 return 0; 568 } 569 570 return -EINVAL; 571 } 572 573 static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk) 574 { 575 struct mptcp_subflow_context *subflow, *tmp; 576 struct sock *sk = (struct sock *)msk; 577 578 pr_debug("address rm_id %d", msk->pm.rm_id); 579 580 msk_owned_by_me(msk); 581 582 if (!msk->pm.rm_id) 583 return; 584 585 if (list_empty(&msk->conn_list)) 586 return; 587 588 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { 589 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 590 int how = RCV_SHUTDOWN | SEND_SHUTDOWN; 591 592 if (msk->pm.rm_id != subflow->remote_id) 593 continue; 594 595 spin_unlock_bh(&msk->pm.lock); 596 mptcp_subflow_shutdown(sk, ssk, how); 597 mptcp_close_ssk(sk, ssk, subflow); 598 
		spin_lock_bh(&msk->pm.lock);

		msk->pm.add_addr_accepted--;
		msk->pm.subflows--;
		WRITE_ONCE(msk->pm.accept_addr, true);

		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR);

		break;
	}
}

/* Dispatch every pending path-manager event for @msk; runs with the
 * msk socket owned, taking the pm lock for the whole dispatch.
 */
void mptcp_pm_nl_work(struct mptcp_sock *msk)
{
	struct mptcp_pm_data *pm = &msk->pm;

	msk_owned_by_me(msk);

	spin_lock_bh(&msk->pm.lock);

	pr_debug("msk=%p status=%x", msk, pm->status);
	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
		mptcp_pm_nl_add_addr_received(msk);
	}
	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
		mptcp_pm_nl_add_addr_send_ack(msk);
	}
	if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
		pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
		mptcp_pm_nl_rm_addr_received(msk);
	}
	if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
		pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
		mptcp_pm_nl_fully_established(msk);
	}
	if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
		pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
		mptcp_pm_nl_subflow_established(msk);
	}

	spin_unlock_bh(&msk->pm.lock);
}

/* Close the subflow whose local id matches @rm_id, following removal
 * of a local endpoint. Called with the pm lock held; dropped around
 * the subflow teardown.
 */
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;

	pr_debug("subflow rm_id %d", rm_id);

	msk_owned_by_me(msk);

	/* id 0 (the initial subflow) is never removed this way */
	if (!rm_id)
		return;

	if (list_empty(&msk->conn_list))
		return;

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (rm_id != subflow->local_id)
			continue;

		spin_unlock_bh(&msk->pm.lock);
		mptcp_subflow_shutdown(sk, ssk, how);
		mptcp_close_ssk(sk, ssk, subflow);
		spin_lock_bh(&msk->pm.lock);

		msk->pm.local_addr_used--;
		msk->pm.subflows--;

		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);

		break;
	}
}

/* An endpoint keeps its explicit port only when it is signal-only
 * (SIGNAL set, SUBFLOW clear).
 */
static bool address_use_port(struct mptcp_pm_addr_entry *entry)
{
	return (entry->addr.flags &
		(MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
		MPTCP_PM_ADDR_FLAG_SIGNAL;
}

/* Insert @entry in the pernet endpoint list, allocating a free id when
 * the caller did not provide one. Returns the id in use, or -EINVAL on
 * duplicates, exhausted ids or exhausted endpoint slots.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
					     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_addr_entry *cur;
	unsigned int addr_max;
	int ret = -EINVAL;

	spin_lock_bh(&pernet->lock);
	/* to keep the code simple, don't do IDR-like allocation for address ID,
	 * just bail when we exceed limits
	 */
	if (pernet->next_id == MAX_ADDR_ID)
		pernet->next_id = 1;
	if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
		goto out;
	if (test_bit(entry->addr.id, pernet->id_bitmap))
		goto out;

	/* do not insert a duplicate address; the port participates in
	 * the comparison only when both endpoints are signal-only, see
	 * address_use_port()
	 */
	list_for_each_entry(cur, &pernet->local_addr_list, list) {
		if (addresses_equal(&cur->addr, &entry->addr,
				    address_use_port(entry) &&
				    address_use_port(cur)))
			goto out;
	}

	if (!entry->addr.id) {
find_next:
		/* scan for a free id from next_id, wrapping to 1 once */
		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
						    MAX_ADDR_ID + 1,
						    pernet->next_id);
		if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
		    pernet->next_id != 1) {
			pernet->next_id = 1;
			goto find_next;
		}
	}

	if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
		goto out;

	__set_bit(entry->addr.id, pernet->id_bitmap);
	if (entry->addr.id > pernet->next_id)
		pernet->next_id = entry->addr.id;

	/* limits are read locklessly elsewhere, hence WRITE_ONCE */
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max
			   + 1);
	}

	pernet->addrs++;
	list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
	ret = entry->addr.id;

out:
	spin_unlock_bh(&pernet->lock);
	return ret;
}

/* Create, bind and listen an in-kernel MPTCP socket for an endpoint
 * configured with an explicit port; the socket is kept in entry->lsk.
 * Returns 0 or a negative error, releasing the socket on failure.
 */
static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
					    struct mptcp_pm_addr_entry *entry)
{
	struct sockaddr_storage addr;
	struct mptcp_sock *msk;
	struct socket *ssock;
	int backlog = 1024;
	int err;

	err = sock_create_kern(sock_net(sk), entry->addr.family,
			       SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
	if (err)
		return err;

	msk = mptcp_sk(entry->lsk->sk);
	if (!msk) {
		err = -EINVAL;
		goto out;
	}

	ssock = __mptcp_nmpc_socket(msk);
	if (!ssock) {
		err = -EINVAL;
		goto out;
	}

	mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
	/* NOTE(review): the bind length is always
	 * sizeof(struct sockaddr_in), even for AF_INET6 endpoints —
	 * confirm this is intended
	 */
	err = kernel_bind(ssock, (struct sockaddr *)&addr,
			  sizeof(struct sockaddr_in));
	if (err) {
		pr_warn("kernel_bind error, err=%d", err);
		goto out;
	}

	err = kernel_listen(ssock, backlog);
	if (err) {
		pr_warn("kernel_listen error, err=%d", err);
		goto out;
	}

	return 0;

out:
	sock_release(entry->lsk);
	return err;
}

/* Map the local address of @skc to a configured endpoint id; id 0 is
 * reserved for the address of the initial subflow. Unknown addresses
 * are appended to the pernet list on the fly. Returns the id or a
 * negative error.
 */
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
{
	struct mptcp_pm_addr_entry *entry;
	struct mptcp_addr_info skc_local;
	struct mptcp_addr_info msk_local;
	struct pm_nl_pernet *pernet;
	int ret = -1;

	if (WARN_ON_ONCE(!msk))
		return -1;

	/* The 0 ID mapping is defined by the first subflow, copied into the msk
	 * addr
	 */
	local_address((struct sock_common *)msk, &msk_local);
	local_address((struct sock_common *)skc, &skc_local);
	if (addresses_equal(&msk_local, &skc_local, false))
		return 0;

	if (address_zero(&skc_local))
		return 0;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);

	rcu_read_lock();
	list_for_each_entry_rcu(entry,
				&pernet->local_addr_list, list) {
		if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
			ret = entry->addr.id;
			break;
		}
	}
	rcu_read_unlock();
	if (ret >= 0)
		return ret;

	/* address not found, add to local list */
	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
	if (!entry)
		return -ENOMEM;

	entry->addr = skc_local;
	entry->addr.ifindex = 0;
	entry->addr.flags = 0;
	entry->addr.id = 0;
	entry->addr.port = 0;
	entry->lsk = NULL;
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0)
		kfree(entry);

	return ret;
}

/* Initialize the per-connection pm flags from the pernet limits. */
void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
{
	struct mptcp_pm_data *pm = &msk->pm;
	bool subflows;

	subflows = !!mptcp_pm_get_subflows_max(msk);
	WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
		   !!mptcp_pm_get_add_addr_signal_max(msk));
	WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
	WRITE_ONCE(pm->accept_subflow, subflows);
}

/* genetlink multicast group offsets */
#define MPTCP_PM_CMD_GRP_OFFSET	0
#define MPTCP_PM_EV_GRP_OFFSET	1

static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]	= { .name = MPTCP_PM_CMD_GRP_NAME, },
	[MPTCP_PM_EV_GRP_OFFSET]	= { .name = MPTCP_PM_EV_GRP_NAME,
					    .flags = GENL_UNS_ADMIN_PERM,
					  },
};

/* policy for the nested MPTCP_PM_ATTR_ADDR attribute */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type = NLA_U16, },
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type = NLA_U8, },
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type = NLA_U32, },
	[MPTCP_PM_ADDR_ATTR_ADDR6]	=
		NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type = NLA_U16 },
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type = NLA_U32 },
	[MPTCP_PM_ADDR_ATTR_IF_IDX]	= { .type = NLA_S32 },
};

/* top-level policy for path-manager genetlink commands */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type = NLA_U32, },
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type = NLA_U32, },
};

/* Return the nested attribute type carrying the raw address bytes for
 * the given address family.
 */
static int mptcp_pm_family_to_addr(int family)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (family == AF_INET6)
		return MPTCP_PM_ADDR_ATTR_ADDR6;
#endif
	return MPTCP_PM_ADDR_ATTR_ADDR4;
}

/* Parse a nested MPTCP_PM_ATTR_ADDR attribute into @entry. When
 * @require_family is false, a family-less attribute (e.g. id only) is
 * accepted. Returns 0, or a negative error with extack set.
 */
static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
			       bool require_family,
			       struct mptcp_pm_addr_entry *entry)
{
	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
	int err, addr_addr;

	if (!attr) {
		GENL_SET_ERR_MSG(info, "missing address info");
		return -EINVAL;
	}

	/* no validation needed - was already done via nested policy */
	err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
					  mptcp_pm_addr_policy, info->extack);
	if (err)
		return err;

	memset(entry, 0, sizeof(*entry));
	if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
		if (!require_family)
			goto skip_family;

		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing family");
		return -EINVAL;
	}

	entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
	if (entry->addr.family != AF_INET
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	    && entry->addr.family != AF_INET6
#endif
	    ) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "unknown address family");
		return -EINVAL;
	}
	addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
	if (!tb[addr_addr]) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing address data");
		return -EINVAL;
	}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (entry->addr.family == AF_INET6)
		entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
	else
#endif
		entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);

skip_family:
	if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
		u32 val =
			nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);

		entry->addr.ifindex = val;
	}

	if (tb[MPTCP_PM_ADDR_ATTR_ID])
		entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);

	if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
		entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);

	/* the netlink port is host order, the pm stores network order */
	if (tb[MPTCP_PM_ADDR_ATTR_PORT])
		entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));

	return 0;
}

/* Fetch the pernet path-manager state for a genetlink request. */
static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
{
	return net_generic(genl_info_net(info), pm_nl_pernet_id);
}

/* Walk every established msk in @net and let it react to a newly added
 * local endpoint (announce it and/or create new subflows).
 */
static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		if (!READ_ONCE(msk->fully_established))
			goto next;

		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_create_subflow_or_signal_addr(msk);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}

/* MPTCP_PM_CMD_ADD_ADDR handler: register a new local endpoint and
 * propagate it to the existing connections.
 */
static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, true, &addr);
	if (ret < 0)
		return ret;

	entry = kmalloc(sizeof(*entry), GFP_KERNEL);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "can't allocate addr");
		return -ENOMEM;
	}

	*entry = addr;
	if (entry->addr.port) {
		/* an explicit port needs a dedicated listening socket */
		ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
		if (ret) {
			GENL_SET_ERR_MSG(info, "create listen socket error");
			kfree(entry);
			return ret;
		}
	}
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0) {
		GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
		return ret;
	}

	mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));

	return 0;
}

/* Look up an endpoint entry by id; caller must hold the pernet lock. */
static struct mptcp_pm_addr_entry *
__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
{
	struct mptcp_pm_addr_entry *entry;

	list_for_each_entry(entry, &pernet->local_addr_list, list) {
		if (entry->addr.id == id)
			return entry;
	}
	return NULL;
}

/* Drop the pending announce for @addr, if any; returns true when an
 * entry was found and freed.
 */
static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	entry = mptcp_pm_del_add_timer(msk, addr);
	if (entry) {
		list_del(&entry->list);
		kfree(entry);
		return true;
	}

	return false;
}

/* Stop announcing @addr and, when it was being announced (or when
 * @force), signal RM_ADDR to the peer.
 */
static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr,
				      bool force)
{
	bool ret;

	ret = remove_anno_list_by_saddr(msk, addr);
	if (ret || force) {
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, addr->id);
		spin_unlock_bh(&msk->pm.lock);
	}
	return ret;
}

/* Walk every msk in @net and tear down announces and subflows bound to
 * the removed local endpoint @addr.
 */
static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
						   struct mptcp_addr_info *addr)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	pr_debug("remove_id=%d", addr->id);

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		bool remove_subflow;

		if (list_empty(&msk->conn_list)) {
			mptcp_pm_remove_anno_addr(msk, addr, false);
			goto next;
		}

		lock_sock(sk);
		remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
		mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
		if (remove_subflow)
			mptcp_pm_remove_subflow(msk, addr->id);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}

/* Deferred release of an endpoint entry: freed from workqueue context
 * after an RCU grace period (see queue_rcu_work() below), so RCU list
 * readers never see a freed entry.
 */
struct addr_entry_release_work {
	struct rcu_work rwork;
	struct mptcp_pm_addr_entry *entry;
};

static void mptcp_pm_release_addr_entry(struct work_struct *work)
{
	struct addr_entry_release_work *w;
	struct mptcp_pm_addr_entry *entry;

	w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork);
	entry = w->entry;
	if (entry) {
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
	}
	kfree(w);
}

/* Schedule the RCU-deferred release of @entry.
 * NOTE(review): if this GFP_ATOMIC allocation fails, @entry is never
 * freed — confirm whether the leak is acceptable here.
 */
static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
{
	struct addr_entry_release_work *w;

	w = kmalloc(sizeof(*w), GFP_ATOMIC);
	if (w) {
		INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry);
		w->entry = entry;
		queue_rcu_work(system_wq, &w->rwork);
	}
}

/* MPTCP_PM_CMD_DEL_ADDR handler: unregister a local endpoint and tear
 * down its usage in the existing connections.
 */
static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	unsigned int addr_max;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr_by_id(pernet, addr.addr.id);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "address not found");
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}
	/* limits are read locklessly elsewhere, hence WRITE_ONCE */
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
	}

	pernet->addrs--;
	list_del_rcu(&entry->list);
	__clear_bit(entry->addr.id,
pernet->id_bitmap); 1179 spin_unlock_bh(&pernet->lock); 1180 1181 mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); 1182 mptcp_pm_free_addr_entry(entry); 1183 1184 return ret; 1185 } 1186 1187 static void __flush_addrs(struct net *net, struct list_head *list) 1188 { 1189 while (!list_empty(list)) { 1190 struct mptcp_pm_addr_entry *cur; 1191 1192 cur = list_entry(list->next, 1193 struct mptcp_pm_addr_entry, list); 1194 mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr); 1195 list_del_rcu(&cur->list); 1196 mptcp_pm_free_addr_entry(cur); 1197 } 1198 } 1199 1200 static void __reset_counters(struct pm_nl_pernet *pernet) 1201 { 1202 WRITE_ONCE(pernet->add_addr_signal_max, 0); 1203 WRITE_ONCE(pernet->add_addr_accept_max, 0); 1204 WRITE_ONCE(pernet->local_addr_max, 0); 1205 pernet->addrs = 0; 1206 } 1207 1208 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info) 1209 { 1210 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1211 LIST_HEAD(free_list); 1212 1213 spin_lock_bh(&pernet->lock); 1214 list_splice_init(&pernet->local_addr_list, &free_list); 1215 __reset_counters(pernet); 1216 pernet->next_id = 1; 1217 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1); 1218 spin_unlock_bh(&pernet->lock); 1219 __flush_addrs(sock_net(skb->sk), &free_list); 1220 return 0; 1221 } 1222 1223 static int mptcp_nl_fill_addr(struct sk_buff *skb, 1224 struct mptcp_pm_addr_entry *entry) 1225 { 1226 struct mptcp_addr_info *addr = &entry->addr; 1227 struct nlattr *attr; 1228 1229 attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR); 1230 if (!attr) 1231 return -EMSGSIZE; 1232 1233 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family)) 1234 goto nla_put_failure; 1235 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port))) 1236 goto nla_put_failure; 1237 if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id)) 1238 goto nla_put_failure; 1239 if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags)) 1240 goto nla_put_failure; 
1241 if (entry->addr.ifindex && 1242 nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex)) 1243 goto nla_put_failure; 1244 1245 if (addr->family == AF_INET && 1246 nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4, 1247 addr->addr.s_addr)) 1248 goto nla_put_failure; 1249 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1250 else if (addr->family == AF_INET6 && 1251 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6)) 1252 goto nla_put_failure; 1253 #endif 1254 nla_nest_end(skb, attr); 1255 return 0; 1256 1257 nla_put_failure: 1258 nla_nest_cancel(skb, attr); 1259 return -EMSGSIZE; 1260 } 1261 1262 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info) 1263 { 1264 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1265 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1266 struct mptcp_pm_addr_entry addr, *entry; 1267 struct sk_buff *msg; 1268 void *reply; 1269 int ret; 1270 1271 ret = mptcp_pm_parse_addr(attr, info, false, &addr); 1272 if (ret < 0) 1273 return ret; 1274 1275 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1276 if (!msg) 1277 return -ENOMEM; 1278 1279 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1280 info->genlhdr->cmd); 1281 if (!reply) { 1282 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1283 ret = -EMSGSIZE; 1284 goto fail; 1285 } 1286 1287 spin_lock_bh(&pernet->lock); 1288 entry = __lookup_addr_by_id(pernet, addr.addr.id); 1289 if (!entry) { 1290 GENL_SET_ERR_MSG(info, "address not found"); 1291 ret = -EINVAL; 1292 goto unlock_fail; 1293 } 1294 1295 ret = mptcp_nl_fill_addr(msg, entry); 1296 if (ret) 1297 goto unlock_fail; 1298 1299 genlmsg_end(msg, reply); 1300 ret = genlmsg_reply(msg, info); 1301 spin_unlock_bh(&pernet->lock); 1302 return ret; 1303 1304 unlock_fail: 1305 spin_unlock_bh(&pernet->lock); 1306 1307 fail: 1308 nlmsg_free(msg); 1309 return ret; 1310 } 1311 1312 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg, 1313 struct netlink_callback *cb) 1314 { 1315 struct 
net *net = sock_net(msg->sk); 1316 struct mptcp_pm_addr_entry *entry; 1317 struct pm_nl_pernet *pernet; 1318 int id = cb->args[0]; 1319 void *hdr; 1320 int i; 1321 1322 pernet = net_generic(net, pm_nl_pernet_id); 1323 1324 spin_lock_bh(&pernet->lock); 1325 for (i = id; i < MAX_ADDR_ID + 1; i++) { 1326 if (test_bit(i, pernet->id_bitmap)) { 1327 entry = __lookup_addr_by_id(pernet, i); 1328 if (!entry) 1329 break; 1330 1331 if (entry->addr.id <= id) 1332 continue; 1333 1334 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, 1335 cb->nlh->nlmsg_seq, &mptcp_genl_family, 1336 NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR); 1337 if (!hdr) 1338 break; 1339 1340 if (mptcp_nl_fill_addr(msg, entry) < 0) { 1341 genlmsg_cancel(msg, hdr); 1342 break; 1343 } 1344 1345 id = entry->addr.id; 1346 genlmsg_end(msg, hdr); 1347 } 1348 } 1349 spin_unlock_bh(&pernet->lock); 1350 1351 cb->args[0] = id; 1352 return msg->len; 1353 } 1354 1355 static int parse_limit(struct genl_info *info, int id, unsigned int *limit) 1356 { 1357 struct nlattr *attr = info->attrs[id]; 1358 1359 if (!attr) 1360 return 0; 1361 1362 *limit = nla_get_u32(attr); 1363 if (*limit > MPTCP_PM_ADDR_MAX) { 1364 GENL_SET_ERR_MSG(info, "limit greater than maximum"); 1365 return -EINVAL; 1366 } 1367 return 0; 1368 } 1369 1370 static int 1371 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info) 1372 { 1373 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1374 unsigned int rcv_addrs, subflows; 1375 int ret; 1376 1377 spin_lock_bh(&pernet->lock); 1378 rcv_addrs = pernet->add_addr_accept_max; 1379 ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs); 1380 if (ret) 1381 goto unlock; 1382 1383 subflows = pernet->subflows_max; 1384 ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows); 1385 if (ret) 1386 goto unlock; 1387 1388 WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs); 1389 WRITE_ONCE(pernet->subflows_max, subflows); 1390 1391 unlock: 1392 spin_unlock_bh(&pernet->lock); 1393 return ret; 1394 } 
/* Netlink MPTCP_PM_CMD_GET_LIMITS handler: reply with the current
 * add_addr_accept_max / subflows_max values.
 */
static int
mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
{
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct sk_buff *msg;
	void *reply;

	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
	if (!msg)
		return -ENOMEM;

	reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
				  MPTCP_PM_CMD_GET_LIMITS);
	if (!reply)
		goto fail;

	if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
			READ_ONCE(pernet->add_addr_accept_max)))
		goto fail;

	if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
			READ_ONCE(pernet->subflows_max)))
		goto fail;

	genlmsg_end(msg, reply);
	return genlmsg_reply(msg, info);

fail:
	GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
	nlmsg_free(msg);
	return -EMSGSIZE;
}

/* Push the new backup flag for @addr to every MPTCP connection in @net
 * via an MP_PRIO + ack on the matching subflow.
 * NOTE(review): returns the value of the *last* iteration (or -EINVAL if
 * no connection was visited) — presumably intentional best-effort
 * semantics; confirm against callers.
 */
static int mptcp_nl_addr_backup(struct net *net,
				struct mptcp_addr_info *addr,
				u8 bkup)
{
	long s_slot = 0, s_num = 0;
	struct mptcp_sock *msk;
	int ret = -EINVAL;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		if (list_empty(&msk->conn_list))
			goto next;

		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return ret;
}

/* Netlink MPTCP_PM_CMD_SET_FLAGS handler: toggle the BACKUP flag on the
 * matching endpoint and propagate it to live connections.
 * NOTE(review): local_addr_list is walked here without pernet->lock or
 * rcu_read_lock(), unlike every other accessor in this file; a
 * concurrent DEL_ADDR/FLUSH frees entries via RCU work, so this looks
 * racy — verify whether genl serialization is relied upon here.
 */
static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	struct net *net = sock_net(skb->sk);
	u8 bkup = 0;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, true, &addr);
	if (ret < 0)
		return ret;

	if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
		bkup = 1;

	list_for_each_entry(entry, &pernet->local_addr_list, list) {
		if (addresses_equal(&entry->addr, &addr.addr, true)) {
			ret = mptcp_nl_addr_backup(net, &entry->addr, bkup);
			if (ret)
				return ret;

			if (bkup)
				entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
			else
				entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
		}
	}

	return 0;
}

/* Multicast @nlskb to the MPTCP PM event group of @net. */
static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
{
	genlmsg_multicast_netns(&mptcp_genl_family, net,
				nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
}

/* Append the subflow description attributes (family, addresses, ports,
 * local/remote ids) for @ssk to an event message.  Returns 0 or a
 * negative error; the caller cancels the message on failure.
 */
static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
{
	const struct inet_sock *issk = inet_sk(ssk);
	const struct mptcp_subflow_context *sf;

	if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
		return -EMSGSIZE;

	switch (ssk->sk_family) {
	case AF_INET:
		if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
			return -EMSGSIZE;
		if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
			return -EMSGSIZE;
		break;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	case AF_INET6: {
		const struct ipv6_pinfo *np = inet6_sk(ssk);

		if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
			return -EMSGSIZE;
		if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
			return -EMSGSIZE;
		break;
	}
#endif
	default:
		WARN_ON_ONCE(1);
		return -EMSGSIZE;
	}

	if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
		return -EMSGSIZE;
	if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
		return -EMSGSIZE;

	sf = mptcp_subflow_ctx(ssk);
	if (WARN_ON_ONCE(!sf))
		return -EINVAL;

	if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
		return -EMSGSIZE;

	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
		return -EMSGSIZE;

	return 0;
}

/* Append the connection token plus the full subflow description (and
 * backup flag, bound device, pending error) for subflow-level events.
 */
static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
					 const struct mptcp_sock *msk,
					 const struct sock *ssk)
{
	const struct sock *sk = (const struct sock *)msk;
	const struct mptcp_subflow_context *sf;
	u8 sk_err;

	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
		return -EMSGSIZE;

	if (mptcp_event_add_subflow(skb, ssk))
		return -EMSGSIZE;

	sf = mptcp_subflow_ctx(ssk);
	if (WARN_ON_ONCE(!sf))
		return -EINVAL;

	if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
		return -EMSGSIZE;

	if (ssk->sk_bound_dev_if &&
	    nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
		return -EMSGSIZE;

	/* only report subflow errors while the connection is established */
	sk_err = ssk->sk_err;
	if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
	    nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
		return -EMSGSIZE;

	return 0;
}

/* Fill a MPTCP_EVENT_SUB_ESTABLISHED / SUB_PRIORITY event payload. */
static int mptcp_event_sub_established(struct sk_buff *skb,
				       const struct mptcp_sock *msk,
				       const struct sock *ssk)
{
	return mptcp_event_put_token_and_ssk(skb, msk, ssk);
}

/* Fill a MPTCP_EVENT_SUB_CLOSED event payload. */
static int mptcp_event_sub_closed(struct sk_buff *skb,
				  const struct mptcp_sock *msk,
				  const struct sock *ssk)
{
	if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
		return -EMSGSIZE;

	return 0;
}

/* Fill a MPTCP_EVENT_CREATED / ESTABLISHED event payload: token plus the
 * initial subflow description.
 */
static int mptcp_event_created(struct sk_buff *skb,
			       const struct mptcp_sock *msk,
			       const struct sock *ssk)
{
	int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);

	if (err)
		return err;

	return mptcp_event_add_subflow(skb, ssk);
}

/* Emit a MPTCP_EVENT_REMOVED multicast notification for address @id on
 * connection @msk.  Silently does nothing when there are no listeners or
 * on (atomic) allocation failure.
 */
void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
{
	struct net *net = sock_net((const struct sock *)msk);
	struct nlmsghdr *nlh;
	struct sk_buff *skb;

	if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
		return;

	skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
	if (!skb)
		return;

	nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
	if (!nlh)
		goto nla_put_failure;

	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
		goto nla_put_failure;

	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
		goto nla_put_failure;

	genlmsg_end(skb, nlh);
	mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
	return;

nla_put_failure:
	kfree_skb(skb);
}

/* Emit a MPTCP_EVENT_ANNOUNCED multicast notification carrying the
 * peer-announced address @info for connection @msk.
 */
void mptcp_event_addr_announced(const struct mptcp_sock *msk,
				const struct mptcp_addr_info *info)
{
	struct net *net = sock_net((const struct sock *)msk);
	struct nlmsghdr *nlh;
	struct sk_buff *skb;

	if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
		return;

	skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
	if (!skb)
		return;

	nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
			  MPTCP_EVENT_ANNOUNCED);
	if (!nlh)
		goto nla_put_failure;

	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
		goto nla_put_failure;

	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
		goto nla_put_failure;

	if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port))
		goto nla_put_failure;

	switch (info->family) {
	case AF_INET:
		if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
			goto nla_put_failure;
		break;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	case AF_INET6:
		if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
			goto nla_put_failure;
		break;
#endif
	default:
		WARN_ON_ONCE(1);
		goto nla_put_failure;
	}

	genlmsg_end(skb, nlh);
	mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
	return;

nla_put_failure:
	kfree_skb(skb);
}

/* Generic MPTCP PM event dispatcher: build and multicast the payload for
 * @type.  ANNOUNCED/REMOVED must go through their dedicated helpers
 * above; events are dropped silently when nobody listens or memory is
 * short.
 */
void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
		 const struct sock *ssk, gfp_t gfp)
{
	struct net *net = sock_net((const struct sock *)msk);
	struct nlmsghdr *nlh;
	struct sk_buff *skb;

	if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
		return;

	skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
	if (!skb)
		return;

	nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
	if (!nlh)
		goto nla_put_failure;

	switch (type) {
	case MPTCP_EVENT_UNSPEC:
		WARN_ON_ONCE(1);
		break;
	case MPTCP_EVENT_CREATED:
	case MPTCP_EVENT_ESTABLISHED:
		if (mptcp_event_created(skb, msk, ssk) < 0)
			goto nla_put_failure;
		break;
	case MPTCP_EVENT_CLOSED:
		if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
			goto nla_put_failure;
		break;
	case MPTCP_EVENT_ANNOUNCED:
	case MPTCP_EVENT_REMOVED:
		/* call mptcp_event_addr_announced()/removed instead */
		WARN_ON_ONCE(1);
		break;
	case MPTCP_EVENT_SUB_ESTABLISHED:
	case MPTCP_EVENT_SUB_PRIORITY:
		if (mptcp_event_sub_established(skb, msk, ssk) < 0)
			goto nla_put_failure;
		break;
	case MPTCP_EVENT_SUB_CLOSED:
		if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
			goto nla_put_failure;
		break;
	}

	genlmsg_end(skb, nlh);
	mptcp_nl_mcast_send(net, skb, gfp);
	return;

nla_put_failure:
	kfree_skb(skb);
}

/* Generic netlink command table; GET_ADDR doubles as doit + dumpit. */
static const struct genl_small_ops mptcp_pm_ops[] = {
	{
		.cmd    = MPTCP_PM_CMD_ADD_ADDR,
		.doit   = mptcp_nl_cmd_add_addr,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_DEL_ADDR,
		.doit   = mptcp_nl_cmd_del_addr,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
		.doit   = mptcp_nl_cmd_flush_addrs,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_GET_ADDR,
		.doit   = mptcp_nl_cmd_get_addr,
		.dumpit   = mptcp_nl_cmd_dump_addrs,
	},
	{
		.cmd    = MPTCP_PM_CMD_SET_LIMITS,
		.doit   = mptcp_nl_cmd_set_limits,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_GET_LIMITS,
		.doit   = mptcp_nl_cmd_get_limits,
	},
	{
		.cmd    = MPTCP_PM_CMD_SET_FLAGS,
		.doit   = mptcp_nl_cmd_set_flags,
		.flags  = GENL_ADMIN_PERM,
	},
};

/* Family definition; mptcp_pm_policy and mptcp_pm_mcgrps are defined
 * earlier in this file.  netnsok: commands operate per network namespace.
 */
static struct genl_family mptcp_genl_family __ro_after_init = {
	.name		= MPTCP_PM_NAME,
	.version	= MPTCP_PM_VER,
	.maxattr	= MPTCP_PM_ATTR_MAX,
	.policy		= mptcp_pm_policy,
	.netnsok	= true,
	.module		= THIS_MODULE,
	.small_ops	= mptcp_pm_ops,
	.n_small_ops	= ARRAY_SIZE(mptcp_pm_ops),
	.mcgrps		= mptcp_pm_mcgrps,
	.n_mcgrps	= ARRAY_SIZE(mptcp_pm_mcgrps),
};

/* Per-netns init: empty endpoint list, zeroed counters/bitmap, ids
 * start at 1 (id 0 has special meaning elsewhere in the PM).
 */
static int __net_init pm_nl_init_net(struct net *net)
{
	struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);

	INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
	__reset_counters(pernet);
	pernet->next_id = 1;
	bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
	spin_lock_init(&pernet->lock);
	return 0;
}

/* Per-netns batched exit: release all remaining endpoints. */
static void __net_exit pm_nl_exit_net(struct list_head *net_list)
{
	struct net *net;

	list_for_each_entry(net, net_list, exit_list) {
		struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);

		/* net is removed from namespace list, can't race with
		 * other modifiers
		 */
		__flush_addrs(net, &pernet->local_addr_list);
	}
}

static struct pernet_operations mptcp_pm_pernet_ops = {
	.init = pm_nl_init_net,
	.exit_batch = pm_nl_exit_net,
	.id = &pm_nl_pernet_id,
	.size = sizeof(struct pm_nl_pernet),
};

/* Boot-time registration of the pernet subsystem and the generic netlink
 * family; both failures are fatal (panic), as MPTCP cannot run without
 * its path-manager interface.
 */
void __init mptcp_pm_nl_init(void)
{
	if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
		panic("Failed to register MPTCP PM pernet subsystem.\n");

	if (genl_register_family(&mptcp_genl_family))
		panic("Failed to register MPTCP PM netlink family\n");
}