1 // SPDX-License-Identifier: GPL-2.0 2 /* Multipath TCP 3 * 4 * Copyright (c) 2020, Red Hat, Inc. 5 */ 6 7 #define pr_fmt(fmt) "MPTCP: " fmt 8 9 #include <linux/inet.h> 10 #include <linux/kernel.h> 11 #include <net/tcp.h> 12 #include <net/netns/generic.h> 13 #include <net/mptcp.h> 14 #include <net/genetlink.h> 15 #include <uapi/linux/mptcp.h> 16 17 #include "protocol.h" 18 #include "mib.h" 19 20 /* forward declaration */ 21 static struct genl_family mptcp_genl_family; 22 23 static int pm_nl_pernet_id; 24 25 struct mptcp_pm_addr_entry { 26 struct list_head list; 27 struct mptcp_addr_info addr; 28 struct rcu_head rcu; 29 }; 30 31 struct mptcp_pm_add_entry { 32 struct list_head list; 33 struct mptcp_addr_info addr; 34 struct timer_list add_timer; 35 struct mptcp_sock *sock; 36 u8 retrans_times; 37 }; 38 39 struct pm_nl_pernet { 40 /* protects pernet updates */ 41 spinlock_t lock; 42 struct list_head local_addr_list; 43 unsigned int addrs; 44 unsigned int add_addr_signal_max; 45 unsigned int add_addr_accept_max; 46 unsigned int local_addr_max; 47 unsigned int subflows_max; 48 unsigned int next_id; 49 }; 50 51 #define MPTCP_PM_ADDR_MAX 8 52 #define ADD_ADDR_RETRANS_MAX 3 53 54 static bool addresses_equal(const struct mptcp_addr_info *a, 55 struct mptcp_addr_info *b, bool use_port) 56 { 57 bool addr_equals = false; 58 59 if (a->family != b->family) 60 return false; 61 62 if (a->family == AF_INET) 63 addr_equals = a->addr.s_addr == b->addr.s_addr; 64 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 65 else 66 addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6); 67 #endif 68 69 if (!addr_equals) 70 return false; 71 if (!use_port) 72 return true; 73 74 return a->port == b->port; 75 } 76 77 static bool address_zero(const struct mptcp_addr_info *addr) 78 { 79 struct mptcp_addr_info zero; 80 81 memset(&zero, 0, sizeof(zero)); 82 zero.family = addr->family; 83 84 return addresses_equal(addr, &zero, false); 85 } 86 87 static void local_address(const struct sock_common *skc, 88 struct mptcp_addr_info *addr) 89 { 90 addr->port = 0; 91 addr->family = skc->skc_family; 92 if (addr->family == AF_INET) 93 addr->addr.s_addr = skc->skc_rcv_saddr; 94 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 95 else if (addr->family == AF_INET6) 96 addr->addr6 = skc->skc_v6_rcv_saddr; 97 #endif 98 } 99 100 static void remote_address(const struct sock_common *skc, 101 struct mptcp_addr_info *addr) 102 { 103 addr->family = skc->skc_family; 104 addr->port = skc->skc_dport; 105 if (addr->family == AF_INET) 106 addr->addr.s_addr = skc->skc_daddr; 107 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 108 else if (addr->family == AF_INET6) 109 addr->addr6 = skc->skc_v6_daddr; 110 #endif 111 } 112 113 static bool lookup_subflow_by_saddr(const struct list_head *list, 114 struct mptcp_addr_info *saddr) 115 { 116 struct mptcp_subflow_context *subflow; 117 struct mptcp_addr_info cur; 118 struct sock_common *skc; 119 120 list_for_each_entry(subflow, list, node) { 121 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); 122 123 local_address(skc, &cur); 124 if (addresses_equal(&cur, saddr, false)) 125 return true; 126 } 127 128 return false; 129 } 130 131 static struct mptcp_pm_addr_entry * 132 select_local_address(const struct pm_nl_pernet *pernet, 133 struct mptcp_sock *msk) 134 { 135 struct mptcp_pm_addr_entry *entry, *ret = NULL; 136 137 rcu_read_lock(); 138 spin_lock_bh(&msk->join_list_lock); 139 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { 140 if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)) 141 continue; 142 143 /* avoid any address already in use by subflows and 144 * pending join 145 */ 146 if (entry->addr.family == ((struct sock *)msk)->sk_family && 147 !lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) && 148 !lookup_subflow_by_saddr(&msk->join_list, &entry->addr)) { 149 ret = entry; 150 break; 151 } 152 } 153 spin_unlock_bh(&msk->join_list_lock); 154 rcu_read_unlock(); 155 return ret; 156 } 157 158 static struct mptcp_pm_addr_entry * 159 select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos) 160 { 161 struct mptcp_pm_addr_entry *entry, *ret = NULL; 162 int i = 0; 163 164 rcu_read_lock(); 165 /* do not keep any additional per socket state, just signal 166 * the address list in order. 167 * Note: removal from the local address list during the msk life-cycle 168 * can lead to additional addresses not being announced. 169 */ 170 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { 171 if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) 172 continue; 173 if (i++ == pos) { 174 ret = entry; 175 break; 176 } 177 } 178 rcu_read_unlock(); 179 return ret; 180 } 181 182 static void check_work_pending(struct mptcp_sock *msk) 183 { 184 if (msk->pm.add_addr_signaled == msk->pm.add_addr_signal_max && 185 (msk->pm.local_addr_used == msk->pm.local_addr_max || 186 msk->pm.subflows == msk->pm.subflows_max)) 187 WRITE_ONCE(msk->pm.work_pending, false); 188 } 189 190 static struct mptcp_pm_add_entry * 191 lookup_anno_list_by_saddr(struct mptcp_sock *msk, 192 struct mptcp_addr_info *addr) 193 { 194 struct mptcp_pm_add_entry *entry; 195 196 list_for_each_entry(entry, &msk->pm.anno_list, list) { 197 if (addresses_equal(&entry->addr, addr, false)) 198 return entry; 199 } 200 201 return NULL; 202 } 203 204 static void mptcp_pm_add_timer(struct timer_list *timer) 205 { 206 struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer); 207 struct mptcp_sock *msk = entry->sock; 208 struct sock *sk = (struct sock *)msk; 209 210 pr_debug("msk=%p", msk); 211 212 if (!msk) 213 return; 214 215 if (inet_sk_state_load(sk) == TCP_CLOSE) 216 return; 217 218 if (!entry->addr.id) 219 return; 220 221 if (mptcp_pm_should_add_signal(msk)) { 222 sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8); 223 goto out; 224 } 225 226 spin_lock_bh(&msk->pm.lock); 227 228 if (!mptcp_pm_should_add_signal(msk)) { 229 pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id); 230 mptcp_pm_announce_addr(msk, &entry->addr, false); 231 entry->retrans_times++; 232 } 233 234 if (entry->retrans_times < ADD_ADDR_RETRANS_MAX) 235 sk_reset_timer(sk, timer, 236 jiffies + mptcp_get_add_addr_timeout(sock_net(sk))); 237 238 spin_unlock_bh(&msk->pm.lock); 239 240 out: 241 __sock_put(sk); 242 } 243 244 struct mptcp_pm_add_entry * 245 mptcp_pm_del_add_timer(struct mptcp_sock *msk, 246 struct mptcp_addr_info *addr) 247 { 248 struct mptcp_pm_add_entry *entry; 249 struct sock *sk = (struct sock *)msk; 250 251 spin_lock_bh(&msk->pm.lock); 252 entry = lookup_anno_list_by_saddr(msk, addr); 253 if (entry) 254 entry->retrans_times = ADD_ADDR_RETRANS_MAX; 255 spin_unlock_bh(&msk->pm.lock); 256 257 if (entry) 258 sk_stop_timer_sync(sk, &entry->add_timer); 259 260 return entry; 261 } 262 263 static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk, 264 struct mptcp_pm_addr_entry *entry) 265 { 266 struct mptcp_pm_add_entry *add_entry = NULL; 267 struct sock *sk = (struct sock *)msk; 268 struct net *net = sock_net(sk); 269 270 if (lookup_anno_list_by_saddr(msk, &entry->addr)) 271 return false; 272 273 add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC); 274 if (!add_entry) 275 return false; 276 277 list_add(&add_entry->list, &msk->pm.anno_list); 278 279 add_entry->addr = entry->addr; 280 add_entry->sock = msk; 281 add_entry->retrans_times = 0; 282 283 timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0); 284 sk_reset_timer(sk, &add_entry->add_timer, 285 jiffies + mptcp_get_add_addr_timeout(net)); 286 287 return true; 288 } 289 290 void mptcp_pm_free_anno_list(struct mptcp_sock *msk) 291 { 292 struct mptcp_pm_add_entry *entry, *tmp; 293 struct sock *sk = (struct sock *)msk; 294 LIST_HEAD(free_list); 295 296 pr_debug("msk=%p", msk); 297 298 spin_lock_bh(&msk->pm.lock); 299 list_splice_init(&msk->pm.anno_list, &free_list); 300 spin_unlock_bh(&msk->pm.lock); 301 302 list_for_each_entry_safe(entry, tmp, &free_list, list) { 303 sk_stop_timer_sync(sk, &entry->add_timer); 304 kfree(entry); 305 } 306 } 307 308 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) 309 { 310 struct mptcp_addr_info remote = { 0 }; 311 struct sock *sk = (struct sock *)msk; 312 struct mptcp_pm_addr_entry *local; 313 struct pm_nl_pernet *pernet; 314 315 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 316 317 pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", 318 msk->pm.local_addr_used, msk->pm.local_addr_max, 319 msk->pm.add_addr_signaled, msk->pm.add_addr_signal_max, 320 msk->pm.subflows, msk->pm.subflows_max); 321 322 /* check first for announce */ 323 if (msk->pm.add_addr_signaled < msk->pm.add_addr_signal_max) { 324 local = select_signal_address(pernet, 325 msk->pm.add_addr_signaled); 326 327 if (local) { 328 if (mptcp_pm_alloc_anno_list(msk, local)) { 329 msk->pm.add_addr_signaled++; 330 mptcp_pm_announce_addr(msk, &local->addr, false); 331 } 332 } else { 333 /* pick failed, avoid fourther attempts later */ 334 msk->pm.local_addr_used = msk->pm.add_addr_signal_max; 335 } 336 337 check_work_pending(msk); 338 } 339 340 /* check if should create a new subflow */ 341 if (msk->pm.local_addr_used < msk->pm.local_addr_max && 342 msk->pm.subflows < msk->pm.subflows_max) { 343 remote_address((struct sock_common *)sk, &remote); 344 345 local = select_local_address(pernet, msk); 346 if (local) { 347 msk->pm.local_addr_used++; 348 msk->pm.subflows++; 349 check_work_pending(msk); 350 spin_unlock_bh(&msk->pm.lock); 351 __mptcp_subflow_connect(sk, &local->addr, &remote); 352 spin_lock_bh(&msk->pm.lock); 353 return; 354 } 355 356 /* lookup failed, avoid fourther attempts later */ 357 msk->pm.local_addr_used = msk->pm.local_addr_max; 358 check_work_pending(msk); 359 } 360 } 361 362 void mptcp_pm_nl_fully_established(struct mptcp_sock *msk) 363 { 364 mptcp_pm_create_subflow_or_signal_addr(msk); 365 } 366 367 void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk) 368 { 369 mptcp_pm_create_subflow_or_signal_addr(msk); 370 } 371 372 void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) 373 { 374 struct sock *sk = (struct sock *)msk; 375 struct mptcp_addr_info remote; 376 struct mptcp_addr_info local; 377 378 pr_debug("accepted %d:%d remote family %d", 379 msk->pm.add_addr_accepted, msk->pm.add_addr_accept_max, 380 msk->pm.remote.family); 381 msk->pm.add_addr_accepted++; 382 msk->pm.subflows++; 383 if (msk->pm.add_addr_accepted >= msk->pm.add_addr_accept_max || 384 msk->pm.subflows >= msk->pm.subflows_max) 385 WRITE_ONCE(msk->pm.accept_addr, false); 386 387 /* connect to the specified remote address, using whatever 388 * local address the routing configuration will pick. 389 */ 390 remote = msk->pm.remote; 391 if (!remote.port) 392 remote.port = sk->sk_dport; 393 memset(&local, 0, sizeof(local)); 394 local.family = remote.family; 395 396 spin_unlock_bh(&msk->pm.lock); 397 __mptcp_subflow_connect((struct sock *)msk, &local, &remote); 398 spin_lock_bh(&msk->pm.lock); 399 400 mptcp_pm_announce_addr(msk, &remote, true); 401 } 402 403 void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk) 404 { 405 struct mptcp_subflow_context *subflow, *tmp; 406 struct sock *sk = (struct sock *)msk; 407 408 pr_debug("address rm_id %d", msk->pm.rm_id); 409 410 if (!msk->pm.rm_id) 411 return; 412 413 if (list_empty(&msk->conn_list)) 414 return; 415 416 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { 417 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 418 int how = RCV_SHUTDOWN | SEND_SHUTDOWN; 419 420 if (msk->pm.rm_id != subflow->remote_id) 421 continue; 422 423 spin_unlock_bh(&msk->pm.lock); 424 mptcp_subflow_shutdown(sk, ssk, how); 425 __mptcp_close_ssk(sk, ssk, subflow); 426 spin_lock_bh(&msk->pm.lock); 427 428 msk->pm.add_addr_accepted--; 429 msk->pm.subflows--; 430 WRITE_ONCE(msk->pm.accept_addr, true); 431 432 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR); 433 434 break; 435 } 436 } 437 438 void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id) 439 { 440 struct mptcp_subflow_context *subflow, *tmp; 441 struct sock *sk = (struct sock *)msk; 442 443 pr_debug("subflow rm_id %d", rm_id); 444 445 if (!rm_id) 446 return; 447 448 if (list_empty(&msk->conn_list)) 449 return; 450 451 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { 452 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 453 int how = RCV_SHUTDOWN | SEND_SHUTDOWN; 454 455 if (rm_id != subflow->local_id) 456 continue; 457 458 spin_unlock_bh(&msk->pm.lock); 459 mptcp_subflow_shutdown(sk, ssk, how); 460 __mptcp_close_ssk(sk, ssk, subflow); 461 spin_lock_bh(&msk->pm.lock); 462 463 msk->pm.local_addr_used--; 464 msk->pm.subflows--; 465 466 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); 467 468 break; 469 } 470 } 471 472 static bool address_use_port(struct mptcp_pm_addr_entry *entry) 473 { 474 return (entry->addr.flags & 475 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) == 476 MPTCP_PM_ADDR_FLAG_SIGNAL; 477 } 478 479 static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, 480 struct mptcp_pm_addr_entry *entry) 481 { 482 struct mptcp_pm_addr_entry *cur; 483 int ret = -EINVAL; 484 485 spin_lock_bh(&pernet->lock); 486 /* to keep the code simple, don't do IDR-like allocation for address ID, 487 * just bail when we exceed limits 488 */ 489 if (pernet->next_id > 255) 490 goto out; 491 if (pernet->addrs >= MPTCP_PM_ADDR_MAX) 492 goto out; 493 494 /* do not insert duplicate address, differentiate on port only 495 * singled addresses 496 */ 497 list_for_each_entry(cur, &pernet->local_addr_list, list) { 498 if (addresses_equal(&cur->addr, &entry->addr, 499 address_use_port(entry) && 500 address_use_port(cur))) 501 goto out; 502 } 503 504 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) 505 pernet->add_addr_signal_max++; 506 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) 507 pernet->local_addr_max++; 508 509 entry->addr.id = pernet->next_id++; 510 pernet->addrs++; 511 list_add_tail_rcu(&entry->list, &pernet->local_addr_list); 512 ret = entry->addr.id; 513 514 out: 515 spin_unlock_bh(&pernet->lock); 516 return ret; 517 } 518 519 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) 520 { 521 struct mptcp_pm_addr_entry *entry; 522 struct mptcp_addr_info skc_local; 523 struct mptcp_addr_info msk_local; 524 struct pm_nl_pernet *pernet; 525 int ret = -1; 526 527 if (WARN_ON_ONCE(!msk)) 528 return -1; 529 530 /* The 0 ID mapping is defined by the first subflow, copied into the msk 531 * addr 532 */ 533 local_address((struct sock_common *)msk, &msk_local); 534 local_address((struct sock_common *)skc, &skc_local); 535 if (addresses_equal(&msk_local, &skc_local, false)) 536 return 0; 537 538 if (address_zero(&skc_local)) 539 return 0; 540 541 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 542 543 rcu_read_lock(); 544 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { 545 if (addresses_equal(&entry->addr, &skc_local, false)) { 546 ret = entry->addr.id; 547 break; 548 } 549 } 550 rcu_read_unlock(); 551 if (ret >= 0) 552 return ret; 553 554 /* address not found, add to local list */ 555 entry = kmalloc(sizeof(*entry), GFP_ATOMIC); 556 if (!entry) 557 return -ENOMEM; 558 559 entry->addr = skc_local; 560 entry->addr.ifindex = 0; 561 entry->addr.flags = 0; 562 ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); 563 if (ret < 0) 564 kfree(entry); 565 566 return ret; 567 } 568 569 void mptcp_pm_nl_data_init(struct mptcp_sock *msk) 570 { 571 struct mptcp_pm_data *pm = &msk->pm; 572 struct pm_nl_pernet *pernet; 573 bool subflows; 574 575 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 576 577 pm->add_addr_signal_max = READ_ONCE(pernet->add_addr_signal_max); 578 pm->add_addr_accept_max = READ_ONCE(pernet->add_addr_accept_max); 579 pm->local_addr_max = READ_ONCE(pernet->local_addr_max); 580 pm->subflows_max = READ_ONCE(pernet->subflows_max); 581 subflows = !!pm->subflows_max; 582 WRITE_ONCE(pm->work_pending, (!!pm->local_addr_max && subflows) || 583 !!pm->add_addr_signal_max); 584 WRITE_ONCE(pm->accept_addr, !!pm->add_addr_accept_max && subflows); 585 WRITE_ONCE(pm->accept_subflow, subflows); 586 } 587 588 #define MPTCP_PM_CMD_GRP_OFFSET 0 589 590 static const struct genl_multicast_group mptcp_pm_mcgrps[] = { 591 [MPTCP_PM_CMD_GRP_OFFSET] = { .name = MPTCP_PM_CMD_GRP_NAME, }, 592 }; 593 594 static const struct nla_policy 595 mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = { 596 [MPTCP_PM_ADDR_ATTR_FAMILY] = { .type = NLA_U16, }, 597 [MPTCP_PM_ADDR_ATTR_ID] = { .type = NLA_U8, }, 598 [MPTCP_PM_ADDR_ATTR_ADDR4] = { .type = NLA_U32, }, 599 [MPTCP_PM_ADDR_ATTR_ADDR6] = 600 NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), 601 [MPTCP_PM_ADDR_ATTR_PORT] = { .type = NLA_U16 }, 602 [MPTCP_PM_ADDR_ATTR_FLAGS] = { .type = NLA_U32 }, 603 [MPTCP_PM_ADDR_ATTR_IF_IDX] = { .type = NLA_S32 }, 604 }; 605 606 static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = { 607 [MPTCP_PM_ATTR_ADDR] = 608 NLA_POLICY_NESTED(mptcp_pm_addr_policy), 609 [MPTCP_PM_ATTR_RCV_ADD_ADDRS] = { .type = NLA_U32, }, 610 [MPTCP_PM_ATTR_SUBFLOWS] = { .type = NLA_U32, }, 611 }; 612 613 static int mptcp_pm_family_to_addr(int family) 614 { 615 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 616 if (family == AF_INET6) 617 return MPTCP_PM_ADDR_ATTR_ADDR6; 618 #endif 619 return MPTCP_PM_ADDR_ATTR_ADDR4; 620 } 621 622 static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info, 623 bool require_family, 624 struct mptcp_pm_addr_entry *entry) 625 { 626 struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1]; 627 int err, addr_addr; 628 629 if (!attr) { 630 GENL_SET_ERR_MSG(info, "missing address info"); 631 return -EINVAL; 632 } 633 634 /* no validation needed - was already done via nested policy */ 635 err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr, 636 mptcp_pm_addr_policy, info->extack); 637 if (err) 638 return err; 639 640 memset(entry, 0, sizeof(*entry)); 641 if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) { 642 if (!require_family) 643 goto skip_family; 644 645 NL_SET_ERR_MSG_ATTR(info->extack, attr, 646 "missing family"); 647 return -EINVAL; 648 } 649 650 entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]); 651 if (entry->addr.family != AF_INET 652 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 653 && entry->addr.family != AF_INET6 654 #endif 655 ) { 656 NL_SET_ERR_MSG_ATTR(info->extack, attr, 657 "unknown address family"); 658 return -EINVAL; 659 } 660 addr_addr = mptcp_pm_family_to_addr(entry->addr.family); 661 if (!tb[addr_addr]) { 662 NL_SET_ERR_MSG_ATTR(info->extack, attr, 663 "missing address data"); 664 return -EINVAL; 665 } 666 667 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 668 if (entry->addr.family == AF_INET6) 669 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]); 670 else 671 #endif 672 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]); 673 674 skip_family: 675 if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) { 676 u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]); 677 678 entry->addr.ifindex = val; 679 } 680 681 if (tb[MPTCP_PM_ADDR_ATTR_ID]) 682 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]); 683 684 if (tb[MPTCP_PM_ADDR_ATTR_FLAGS]) 685 entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]); 686 687 return 0; 688 } 689 690 static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info) 691 { 692 return net_generic(genl_info_net(info), pm_nl_pernet_id); 693 } 694 695 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) 696 { 697 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 698 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 699 struct mptcp_pm_addr_entry addr, *entry; 700 int ret; 701 702 ret = mptcp_pm_parse_addr(attr, info, true, &addr); 703 if (ret < 0) 704 return ret; 705 706 entry = kmalloc(sizeof(*entry), GFP_KERNEL); 707 if (!entry) { 708 GENL_SET_ERR_MSG(info, "can't allocate addr"); 709 return -ENOMEM; 710 } 711 712 *entry = addr; 713 ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); 714 if (ret < 0) { 715 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one"); 716 kfree(entry); 717 return ret; 718 } 719 720 return 0; 721 } 722 723 static struct mptcp_pm_addr_entry * 724 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id) 725 { 726 struct mptcp_pm_addr_entry *entry; 727 728 list_for_each_entry(entry, &pernet->local_addr_list, list) { 729 if (entry->addr.id == id) 730 return entry; 731 } 732 return NULL; 733 } 734 735 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, 736 struct mptcp_addr_info *addr) 737 { 738 struct mptcp_pm_add_entry *entry; 739 740 entry = mptcp_pm_del_add_timer(msk, addr); 741 if (entry) { 742 list_del(&entry->list); 743 kfree(entry); 744 return true; 745 } 746 747 return false; 748 } 749 750 static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, 751 struct mptcp_addr_info *addr, 752 bool force) 753 { 754 bool ret; 755 756 ret = remove_anno_list_by_saddr(msk, addr); 757 if (ret || force) { 758 spin_lock_bh(&msk->pm.lock); 759 mptcp_pm_remove_addr(msk, addr->id); 760 spin_unlock_bh(&msk->pm.lock); 761 } 762 return ret; 763 } 764 765 static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, 766 struct mptcp_addr_info *addr) 767 { 768 struct mptcp_sock *msk; 769 long s_slot = 0, s_num = 0; 770 771 pr_debug("remove_id=%d", addr->id); 772 773 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 774 struct sock *sk = (struct sock *)msk; 775 bool remove_subflow; 776 777 if (list_empty(&msk->conn_list)) { 778 mptcp_pm_remove_anno_addr(msk, addr, false); 779 goto next; 780 } 781 782 lock_sock(sk); 783 remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr); 784 mptcp_pm_remove_anno_addr(msk, addr, remove_subflow); 785 if (remove_subflow) 786 mptcp_pm_remove_subflow(msk, addr->id); 787 release_sock(sk); 788 789 next: 790 sock_put(sk); 791 cond_resched(); 792 } 793 794 return 0; 795 } 796 797 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) 798 { 799 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 800 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 801 struct mptcp_pm_addr_entry addr, *entry; 802 int ret; 803 804 ret = mptcp_pm_parse_addr(attr, info, false, &addr); 805 if (ret < 0) 806 return ret; 807 808 spin_lock_bh(&pernet->lock); 809 entry = __lookup_addr_by_id(pernet, addr.addr.id); 810 if (!entry) { 811 GENL_SET_ERR_MSG(info, "address not found"); 812 spin_unlock_bh(&pernet->lock); 813 return -EINVAL; 814 } 815 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) 816 pernet->add_addr_signal_max--; 817 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) 818 pernet->local_addr_max--; 819 820 pernet->addrs--; 821 list_del_rcu(&entry->list); 822 spin_unlock_bh(&pernet->lock); 823 824 mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); 825 kfree_rcu(entry, rcu); 826 827 return ret; 828 } 829 830 static void __flush_addrs(struct pm_nl_pernet *pernet) 831 { 832 while (!list_empty(&pernet->local_addr_list)) { 833 struct mptcp_pm_addr_entry *cur; 834 835 cur = list_entry(pernet->local_addr_list.next, 836 struct mptcp_pm_addr_entry, list); 837 list_del_rcu(&cur->list); 838 kfree_rcu(cur, rcu); 839 } 840 } 841 842 static void __reset_counters(struct pm_nl_pernet *pernet) 843 { 844 pernet->add_addr_signal_max = 0; 845 pernet->add_addr_accept_max = 0; 846 pernet->local_addr_max = 0; 847 pernet->addrs = 0; 848 } 849 850 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info) 851 { 852 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 853 854 spin_lock_bh(&pernet->lock); 855 __flush_addrs(pernet); 856 __reset_counters(pernet); 857 spin_unlock_bh(&pernet->lock); 858 return 0; 859 } 860 861 static int mptcp_nl_fill_addr(struct sk_buff *skb, 862 struct mptcp_pm_addr_entry *entry) 863 { 864 struct mptcp_addr_info *addr = &entry->addr; 865 struct nlattr *attr; 866 867 attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR); 868 if (!attr) 869 return -EMSGSIZE; 870 871 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family)) 872 goto nla_put_failure; 873 if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id)) 874 goto nla_put_failure; 875 if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags)) 876 goto nla_put_failure; 877 if (entry->addr.ifindex && 878 nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex)) 879 goto nla_put_failure; 880 881 if (addr->family == AF_INET && 882 nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4, 883 addr->addr.s_addr)) 884 goto nla_put_failure; 885 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 886 else if (addr->family == AF_INET6 && 887 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6)) 888 goto nla_put_failure; 889 #endif 890 nla_nest_end(skb, attr); 891 return 0; 892 893 nla_put_failure: 894 nla_nest_cancel(skb, attr); 895 return -EMSGSIZE; 896 } 897 898 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info) 899 { 900 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 901 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 902 struct mptcp_pm_addr_entry addr, *entry; 903 struct sk_buff *msg; 904 void *reply; 905 int ret; 906 907 ret = mptcp_pm_parse_addr(attr, info, false, &addr); 908 if (ret < 0) 909 return ret; 910 911 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 912 if (!msg) 913 return -ENOMEM; 914 915 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 916 info->genlhdr->cmd); 917 if (!reply) { 918 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 919 ret = -EMSGSIZE; 920 goto fail; 921 } 922 923 spin_lock_bh(&pernet->lock); 924 entry = __lookup_addr_by_id(pernet, addr.addr.id); 925 if (!entry) { 926 GENL_SET_ERR_MSG(info, "address not found"); 927 ret = -EINVAL; 928 goto unlock_fail; 929 } 930 931 ret = mptcp_nl_fill_addr(msg, entry); 932 if (ret) 933 goto unlock_fail; 934 935 genlmsg_end(msg, reply); 936 ret = genlmsg_reply(msg, info); 937 spin_unlock_bh(&pernet->lock); 938 return ret; 939 940 unlock_fail: 941 spin_unlock_bh(&pernet->lock); 942 943 fail: 944 nlmsg_free(msg); 945 return ret; 946 } 947 948 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg, 949 struct netlink_callback *cb) 950 { 951 struct net *net = sock_net(msg->sk); 952 struct mptcp_pm_addr_entry *entry; 953 struct pm_nl_pernet *pernet; 954 int id = cb->args[0]; 955 void *hdr; 956 957 pernet = net_generic(net, pm_nl_pernet_id); 958 959 spin_lock_bh(&pernet->lock); 960 list_for_each_entry(entry, &pernet->local_addr_list, list) { 961 if (entry->addr.id <= id) 962 continue; 963 964 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, 965 cb->nlh->nlmsg_seq, &mptcp_genl_family, 966 NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR); 967 if (!hdr) 968 break; 969 970 if (mptcp_nl_fill_addr(msg, entry) < 0) { 971 genlmsg_cancel(msg, hdr); 972 break; 973 } 974 975 id = entry->addr.id; 976 genlmsg_end(msg, hdr); 977 } 978 spin_unlock_bh(&pernet->lock); 979 980 cb->args[0] = id; 981 return msg->len; 982 } 983 984 static int parse_limit(struct genl_info *info, int id, unsigned int *limit) 985 { 986 struct nlattr *attr = info->attrs[id]; 987 988 if (!attr) 989 return 0; 990 991 *limit = nla_get_u32(attr); 992 if (*limit > MPTCP_PM_ADDR_MAX) { 993 GENL_SET_ERR_MSG(info, "limit greater than maximum"); 994 return -EINVAL; 995 } 996 return 0; 997 } 998 999 static int 1000 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info) 1001 { 1002 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1003 unsigned int rcv_addrs, subflows; 1004 int ret; 1005 1006 spin_lock_bh(&pernet->lock); 1007 rcv_addrs = pernet->add_addr_accept_max; 1008 ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs); 1009 if (ret) 1010 goto unlock; 1011 1012 subflows = pernet->subflows_max; 1013 ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows); 1014 if (ret) 1015 goto unlock; 1016 1017 WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs); 1018 WRITE_ONCE(pernet->subflows_max, subflows); 1019 1020 unlock: 1021 spin_unlock_bh(&pernet->lock); 1022 return ret; 1023 } 1024 1025 static int 1026 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info) 1027 { 1028 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1029 struct sk_buff *msg; 1030 void *reply; 1031 1032 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1033 if (!msg) 1034 return -ENOMEM; 1035 1036 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1037 MPTCP_PM_CMD_GET_LIMITS); 1038 if (!reply) 1039 goto fail; 1040 1041 if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS, 1042 READ_ONCE(pernet->add_addr_accept_max))) 1043 goto fail; 1044 1045 if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS, 1046 READ_ONCE(pernet->subflows_max))) 1047 goto fail; 1048 1049 genlmsg_end(msg, reply); 1050 return genlmsg_reply(msg, info); 1051 1052 fail: 1053 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1054 nlmsg_free(msg); 1055 return -EMSGSIZE; 1056 } 1057 1058 static const struct genl_small_ops mptcp_pm_ops[] = { 1059 { 1060 .cmd = MPTCP_PM_CMD_ADD_ADDR, 1061 .doit = mptcp_nl_cmd_add_addr, 1062 .flags = GENL_ADMIN_PERM, 1063 }, 1064 { 1065 .cmd = MPTCP_PM_CMD_DEL_ADDR, 1066 .doit = mptcp_nl_cmd_del_addr, 1067 .flags = GENL_ADMIN_PERM, 1068 }, 1069 { 1070 .cmd = MPTCP_PM_CMD_FLUSH_ADDRS, 1071 .doit = mptcp_nl_cmd_flush_addrs, 1072 .flags = GENL_ADMIN_PERM, 1073 }, 1074 { 1075 .cmd = MPTCP_PM_CMD_GET_ADDR, 1076 .doit = mptcp_nl_cmd_get_addr, 1077 .dumpit = mptcp_nl_cmd_dump_addrs, 1078 }, 1079 { 1080 .cmd = MPTCP_PM_CMD_SET_LIMITS, 1081 .doit = mptcp_nl_cmd_set_limits, 1082 .flags = GENL_ADMIN_PERM, 1083 }, 1084 { 1085 .cmd = MPTCP_PM_CMD_GET_LIMITS, 1086 .doit = mptcp_nl_cmd_get_limits, 1087 }, 1088 }; 1089 1090 static struct genl_family mptcp_genl_family __ro_after_init = { 1091 .name = MPTCP_PM_NAME, 1092 .version = MPTCP_PM_VER, 1093 .maxattr = MPTCP_PM_ATTR_MAX, 1094 .policy = mptcp_pm_policy, 1095 .netnsok = true, 1096 .module = THIS_MODULE, 1097 .small_ops = mptcp_pm_ops, 1098 .n_small_ops = ARRAY_SIZE(mptcp_pm_ops), 1099 .mcgrps = mptcp_pm_mcgrps, 1100 .n_mcgrps = ARRAY_SIZE(mptcp_pm_mcgrps), 1101 }; 1102 1103 static int __net_init pm_nl_init_net(struct net *net) 1104 { 1105 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1106 1107 INIT_LIST_HEAD_RCU(&pernet->local_addr_list); 1108 __reset_counters(pernet); 1109 pernet->next_id = 1; 1110 spin_lock_init(&pernet->lock); 1111 return 0; 1112 } 1113 1114 static void __net_exit pm_nl_exit_net(struct list_head *net_list) 1115 { 1116 struct net *net; 1117 1118 list_for_each_entry(net, net_list, exit_list) { 1119 /* net is removed from namespace list, can't race with 1120 * other modifiers 1121 */ 1122 __flush_addrs(net_generic(net, pm_nl_pernet_id)); 1123 } 1124 } 1125 1126 static struct pernet_operations mptcp_pm_pernet_ops = { 1127 .init = pm_nl_init_net, 1128 .exit_batch = pm_nl_exit_net, 1129 .id = &pm_nl_pernet_id, 1130 .size = sizeof(struct pm_nl_pernet), 1131 }; 1132 1133 void __init mptcp_pm_nl_init(void) 1134 { 1135 if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0) 1136 panic("Failed to register MPTCP PM pernet subsystem.\n"); 1137 1138 if (genl_register_family(&mptcp_genl_family)) 1139 panic("Failed to register MPTCP PM netlink family\n"); 1140 } 1141