1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020, Tessares SA. */ 3 /* Copyright (c) 2022, SUSE. */ 4 5 #include <linux/const.h> 6 #include <netinet/in.h> 7 #include <test_progs.h> 8 #include <unistd.h> 9 #include <errno.h> 10 #include "cgroup_helpers.h" 11 #include "network_helpers.h" 12 #include "mptcp_sock.skel.h" 13 #include "mptcpify.skel.h" 14 #include "mptcp_subflow.skel.h" 15 #include "mptcp_sockmap.skel.h" 16 17 #define NS_TEST "mptcp_ns" 18 #define ADDR_1 "10.0.1.1" 19 #define ADDR_2 "10.0.1.2" 20 #define PORT_1 10001 21 22 #ifndef IPPROTO_MPTCP 23 #define IPPROTO_MPTCP 262 24 #endif 25 26 #ifndef SOL_MPTCP 27 #define SOL_MPTCP 284 28 #endif 29 #ifndef MPTCP_INFO 30 #define MPTCP_INFO 1 31 #endif 32 #ifndef MPTCP_INFO_FLAG_FALLBACK 33 #define MPTCP_INFO_FLAG_FALLBACK _BITUL(0) 34 #endif 35 #ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED 36 #define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1) 37 #endif 38 39 #ifndef TCP_CA_NAME_MAX 40 #define TCP_CA_NAME_MAX 16 41 #endif 42 43 struct __mptcp_info { 44 __u8 mptcpi_subflows; 45 __u8 mptcpi_add_addr_signal; 46 __u8 mptcpi_add_addr_accepted; 47 __u8 mptcpi_subflows_max; 48 __u8 mptcpi_add_addr_signal_max; 49 __u8 mptcpi_add_addr_accepted_max; 50 __u32 mptcpi_flags; 51 __u32 mptcpi_token; 52 __u64 mptcpi_write_seq; 53 __u64 mptcpi_snd_una; 54 __u64 mptcpi_rcv_nxt; 55 __u8 mptcpi_local_addr_used; 56 __u8 mptcpi_local_addr_max; 57 __u8 mptcpi_csum_enabled; 58 __u32 mptcpi_retransmits; 59 __u64 mptcpi_bytes_retrans; 60 __u64 mptcpi_bytes_sent; 61 __u64 mptcpi_bytes_received; 62 __u64 mptcpi_bytes_acked; 63 }; 64 65 struct mptcp_storage { 66 __u32 invoked; 67 __u32 is_mptcp; 68 struct sock *sk; 69 __u32 token; 70 struct sock *first; 71 char ca_name[TCP_CA_NAME_MAX]; 72 }; 73 74 static int start_mptcp_server(int family, const char *addr_str, __u16 port, 75 int timeout_ms) 76 { 77 struct network_helper_opts opts = { 78 .timeout_ms = timeout_ms, 79 .proto = IPPROTO_MPTCP, 80 }; 81 82 return start_server_str(family, SOCK_STREAM, addr_str, port, &opts); 83 } 84 85 static int verify_tsk(int map_fd, int client_fd) 86 { 87 int err, cfd = client_fd; 88 struct mptcp_storage val; 89 90 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 91 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 92 return err; 93 94 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 95 err++; 96 97 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp")) 98 err++; 99 100 return err; 101 } 102 103 static void get_msk_ca_name(char ca_name[]) 104 { 105 size_t len; 106 int fd; 107 108 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY); 109 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control")) 110 return; 111 112 len = read(fd, ca_name, TCP_CA_NAME_MAX); 113 if (!ASSERT_GT(len, 0, "failed to read ca_name")) 114 goto err; 115 116 if (len > 0 && ca_name[len - 1] == '\n') 117 ca_name[len - 1] = '\0'; 118 119 err: 120 close(fd); 121 } 122 123 static int verify_msk(int map_fd, int client_fd, __u32 token) 124 { 125 char ca_name[TCP_CA_NAME_MAX]; 126 int err, cfd = client_fd; 127 struct mptcp_storage val; 128 129 if (!ASSERT_GT(token, 0, "invalid token")) 130 return -1; 131 132 get_msk_ca_name(ca_name); 133 134 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 135 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 136 return err; 137 138 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 139 err++; 140 141 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp")) 142 err++; 143 144 if (!ASSERT_EQ(val.token, token, "unexpected token")) 145 err++; 146 147 if (!ASSERT_EQ(val.first, val.sk, "unexpected first")) 148 err++; 149 150 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name")) 151 err++; 152 153 return err; 154 } 155 156 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp) 157 { 158 int client_fd, prog_fd, map_fd, err; 159 struct mptcp_sock *sock_skel; 160 161 sock_skel = mptcp_sock__open_and_load(); 162 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load")) 163 return libbpf_get_error(sock_skel); 164 165 err = mptcp_sock__attach(sock_skel); 166 if (!ASSERT_OK(err, "skel_attach")) 167 goto out; 168 169 prog_fd = bpf_program__fd(sock_skel->progs._sockops); 170 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map); 171 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 172 if (!ASSERT_OK(err, "bpf_prog_attach")) 173 goto out; 174 175 client_fd = connect_to_fd(server_fd, 0); 176 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 177 err = -EIO; 178 goto out; 179 } 180 181 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) : 182 verify_tsk(map_fd, client_fd); 183 184 close(client_fd); 185 186 out: 187 mptcp_sock__destroy(sock_skel); 188 return err; 189 } 190 191 static void test_base(void) 192 { 193 struct netns_obj *netns = NULL; 194 int server_fd, cgroup_fd; 195 196 cgroup_fd = test__join_cgroup("/mptcp"); 197 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 198 return; 199 200 netns = netns_new(NS_TEST, true); 201 if (!ASSERT_OK_PTR(netns, "netns_new")) 202 goto fail; 203 204 /* without MPTCP */ 205 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 206 if (!ASSERT_GE(server_fd, 0, "start_server")) 207 goto with_mptcp; 208 209 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp"); 210 211 close(server_fd); 212 213 with_mptcp: 214 /* with MPTCP */ 215 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 216 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) 217 goto fail; 218 219 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp"); 220 221 close(server_fd); 222 223 fail: 224 netns_free(netns); 225 close(cgroup_fd); 226 } 227 228 static void send_byte(int fd) 229 { 230 char b = 0x55; 231 232 ASSERT_EQ(write(fd, &b, sizeof(b)), 1, "send single byte"); 233 } 234 235 static int verify_mptcpify(int server_fd, int client_fd) 236 { 237 struct __mptcp_info info; 238 socklen_t optlen; 239 int protocol; 240 int err = 0; 241 242 optlen = sizeof(protocol); 243 if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen), 244 "getsockopt(SOL_PROTOCOL)")) 245 return -1; 246 247 if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP")) 248 err++; 249 250 optlen = sizeof(info); 251 if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen), 252 "getsockopt(MPTCP_INFO)")) 253 return -1; 254 255 if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags")) 256 err++; 257 if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK, 258 "MPTCP fallback")) 259 err++; 260 if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED, 261 "no remote key received")) 262 err++; 263 264 return err; 265 } 266 267 static int run_mptcpify(int cgroup_fd) 268 { 269 int server_fd, client_fd, err = 0; 270 struct mptcpify *mptcpify_skel; 271 272 mptcpify_skel = mptcpify__open_and_load(); 273 if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load")) 274 return libbpf_get_error(mptcpify_skel); 275 276 mptcpify_skel->bss->pid = getpid(); 277 278 err = mptcpify__attach(mptcpify_skel); 279 if (!ASSERT_OK(err, "skel_attach")) 280 goto out; 281 282 /* without MPTCP */ 283 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 284 if (!ASSERT_GE(server_fd, 0, "start_server")) { 285 err = -EIO; 286 goto out; 287 } 288 289 client_fd = connect_to_fd(server_fd, 0); 290 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 291 err = -EIO; 292 goto close_server; 293 } 294 295 send_byte(client_fd); 296 297 err = verify_mptcpify(server_fd, client_fd); 298 299 close(client_fd); 300 close_server: 301 close(server_fd); 302 out: 303 mptcpify__destroy(mptcpify_skel); 304 return err; 305 } 306 307 static void test_mptcpify(void) 308 { 309 struct netns_obj *netns = NULL; 310 int cgroup_fd; 311 312 cgroup_fd = test__join_cgroup("/mptcpify"); 313 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 314 return; 315 316 netns = netns_new(NS_TEST, true); 317 if (!ASSERT_OK_PTR(netns, "netns_new")) 318 goto fail; 319 320 ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify"); 321 322 fail: 323 netns_free(netns); 324 close(cgroup_fd); 325 } 326 327 static int endpoint_init(char *flags) 328 { 329 SYS(fail, "ip -net %s link add veth1 type veth peer name veth2", NS_TEST); 330 SYS(fail, "ip -net %s addr add %s/24 dev veth1", NS_TEST, ADDR_1); 331 SYS(fail, "ip -net %s link set dev veth1 up", NS_TEST); 332 SYS(fail, "ip -net %s addr add %s/24 dev veth2", NS_TEST, ADDR_2); 333 SYS(fail, "ip -net %s link set dev veth2 up", NS_TEST); 334 if (SYS_NOFAIL("ip -net %s mptcp endpoint add %s %s", NS_TEST, ADDR_2, flags)) { 335 printf("'ip mptcp' not supported, skip this test.\n"); 336 test__skip(); 337 goto fail; 338 } 339 340 return 0; 341 fail: 342 return -1; 343 } 344 345 static void wait_for_new_subflows(int fd) 346 { 347 socklen_t len; 348 u8 subflows; 349 int err, i; 350 351 len = sizeof(subflows); 352 /* Wait max 5 sec for new subflows to be created */ 353 for (i = 0; i < 50; i++) { 354 err = getsockopt(fd, SOL_MPTCP, MPTCP_INFO, &subflows, &len); 355 if (!err && subflows > 0) 356 break; 357 358 usleep(100000); /* 0.1s */ 359 } 360 } 361 362 static void run_subflow(void) 363 { 364 int server_fd, client_fd, err; 365 char new[TCP_CA_NAME_MAX]; 366 char cc[TCP_CA_NAME_MAX]; 367 unsigned int mark; 368 socklen_t len; 369 370 server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, 0); 371 if (!ASSERT_OK_FD(server_fd, "start_mptcp_server")) 372 return; 373 374 client_fd = connect_to_fd(server_fd, 0); 375 if (!ASSERT_OK_FD(client_fd, "connect_to_fd")) 376 goto close_server; 377 378 send_byte(client_fd); 379 wait_for_new_subflows(client_fd); 380 381 len = sizeof(mark); 382 err = getsockopt(client_fd, SOL_SOCKET, SO_MARK, &mark, &len); 383 if (ASSERT_OK(err, "getsockopt(client_fd, SO_MARK)")) 384 ASSERT_EQ(mark, 0, "mark"); 385 386 len = sizeof(new); 387 err = getsockopt(client_fd, SOL_TCP, TCP_CONGESTION, new, &len); 388 if (ASSERT_OK(err, "getsockopt(client_fd, TCP_CONGESTION)")) { 389 get_msk_ca_name(cc); 390 ASSERT_STREQ(new, cc, "cc"); 391 } 392 393 close(client_fd); 394 close_server: 395 close(server_fd); 396 } 397 398 static void test_subflow(void) 399 { 400 struct mptcp_subflow *skel; 401 struct netns_obj *netns; 402 int cgroup_fd; 403 404 cgroup_fd = test__join_cgroup("/mptcp_subflow"); 405 if (!ASSERT_OK_FD(cgroup_fd, "join_cgroup: mptcp_subflow")) 406 return; 407 408 skel = mptcp_subflow__open_and_load(); 409 if (!ASSERT_OK_PTR(skel, "skel_open_load: mptcp_subflow")) 410 goto close_cgroup; 411 412 skel->bss->pid = getpid(); 413 414 skel->links.mptcp_subflow = 415 bpf_program__attach_cgroup(skel->progs.mptcp_subflow, cgroup_fd); 416 if (!ASSERT_OK_PTR(skel->links.mptcp_subflow, "attach mptcp_subflow")) 417 goto skel_destroy; 418 419 skel->links._getsockopt_subflow = 420 bpf_program__attach_cgroup(skel->progs._getsockopt_subflow, cgroup_fd); 421 if (!ASSERT_OK_PTR(skel->links._getsockopt_subflow, "attach _getsockopt_subflow")) 422 goto skel_destroy; 423 424 netns = netns_new(NS_TEST, true); 425 if (!ASSERT_OK_PTR(netns, "netns_new: mptcp_subflow")) 426 goto skel_destroy; 427 428 if (endpoint_init("subflow") < 0) 429 goto close_netns; 430 431 run_subflow(); 432 433 close_netns: 434 netns_free(netns); 435 skel_destroy: 436 mptcp_subflow__destroy(skel); 437 close_cgroup: 438 close(cgroup_fd); 439 } 440 441 /* Test sockmap on MPTCP server handling non-mp-capable clients. */ 442 static void test_sockmap_with_mptcp_fallback(struct mptcp_sockmap *skel) 443 { 444 int listen_fd = -1, client_fd1 = -1, client_fd2 = -1; 445 int server_fd1 = -1, server_fd2 = -1, sent, recvd; 446 char snd[9] = "123456789"; 447 char rcv[10]; 448 449 /* start server with MPTCP enabled */ 450 listen_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 451 if (!ASSERT_OK_FD(listen_fd, "sockmap-fb:start_mptcp_server")) 452 return; 453 454 skel->bss->trace_port = ntohs(get_socket_local_port(listen_fd)); 455 skel->bss->sk_index = 0; 456 /* create client without MPTCP enabled */ 457 client_fd1 = connect_to_fd_opts(listen_fd, NULL); 458 if (!ASSERT_OK_FD(client_fd1, "sockmap-fb:connect_to_fd")) 459 goto end; 460 461 server_fd1 = accept(listen_fd, NULL, 0); 462 skel->bss->sk_index = 1; 463 client_fd2 = connect_to_fd_opts(listen_fd, NULL); 464 if (!ASSERT_OK_FD(client_fd2, "sockmap-fb:connect_to_fd")) 465 goto end; 466 467 server_fd2 = accept(listen_fd, NULL, 0); 468 /* test normal redirect behavior: data sent by client_fd1 can be 469 * received by client_fd2 470 */ 471 skel->bss->redirect_idx = 1; 472 sent = send(client_fd1, snd, sizeof(snd), 0); 473 if (!ASSERT_EQ(sent, sizeof(snd), "sockmap-fb:send(client_fd1)")) 474 goto end; 475 476 /* try to recv more bytes to avoid truncation check */ 477 recvd = recv(client_fd2, rcv, sizeof(rcv), 0); 478 if (!ASSERT_EQ(recvd, sizeof(snd), "sockmap-fb:recv(client_fd2)")) 479 goto end; 480 481 end: 482 if (client_fd1 >= 0) 483 close(client_fd1); 484 if (client_fd2 >= 0) 485 close(client_fd2); 486 if (server_fd1 >= 0) 487 close(server_fd1); 488 if (server_fd2 >= 0) 489 close(server_fd2); 490 close(listen_fd); 491 } 492 493 /* Test sockmap rejection of MPTCP sockets - both server and client sides. */ 494 static void test_sockmap_reject_mptcp(struct mptcp_sockmap *skel) 495 { 496 int listen_fd = -1, server_fd = -1, client_fd1 = -1; 497 int err, zero = 0; 498 499 /* start server with MPTCP enabled */ 500 listen_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 501 if (!ASSERT_OK_FD(listen_fd, "start_mptcp_server")) 502 return; 503 504 skel->bss->trace_port = ntohs(get_socket_local_port(listen_fd)); 505 skel->bss->sk_index = 0; 506 /* create client with MPTCP enabled */ 507 client_fd1 = connect_to_fd(listen_fd, 0); 508 if (!ASSERT_OK_FD(client_fd1, "connect_to_fd client_fd1")) 509 goto end; 510 511 /* bpf_sock_map_update() called from sockops should reject MPTCP sk */ 512 if (!ASSERT_EQ(skel->bss->helper_ret, -EOPNOTSUPP, "should reject")) 513 goto end; 514 515 server_fd = accept(listen_fd, NULL, 0); 516 err = bpf_map_update_elem(bpf_map__fd(skel->maps.sock_map), 517 &zero, &server_fd, BPF_NOEXIST); 518 if (!ASSERT_EQ(err, -EOPNOTSUPP, "server should be disallowed")) 519 goto end; 520 521 /* MPTCP client should also be disallowed */ 522 err = bpf_map_update_elem(bpf_map__fd(skel->maps.sock_map), 523 &zero, &client_fd1, BPF_NOEXIST); 524 if (!ASSERT_EQ(err, -EOPNOTSUPP, "client should be disallowed")) 525 goto end; 526 end: 527 if (client_fd1 >= 0) 528 close(client_fd1); 529 if (server_fd >= 0) 530 close(server_fd); 531 close(listen_fd); 532 } 533 534 static void test_mptcp_sockmap(void) 535 { 536 struct mptcp_sockmap *skel; 537 struct netns_obj *netns; 538 int cgroup_fd, err; 539 540 cgroup_fd = test__join_cgroup("/mptcp_sockmap"); 541 if (!ASSERT_OK_FD(cgroup_fd, "join_cgroup: mptcp_sockmap")) 542 return; 543 544 skel = mptcp_sockmap__open_and_load(); 545 if (!ASSERT_OK_PTR(skel, "skel_open_load: mptcp_sockmap")) 546 goto close_cgroup; 547 548 skel->links.mptcp_sockmap_inject = 549 bpf_program__attach_cgroup(skel->progs.mptcp_sockmap_inject, cgroup_fd); 550 if (!ASSERT_OK_PTR(skel->links.mptcp_sockmap_inject, "attach sockmap")) 551 goto skel_destroy; 552 553 err = bpf_prog_attach(bpf_program__fd(skel->progs.mptcp_sockmap_redirect), 554 bpf_map__fd(skel->maps.sock_map), 555 BPF_SK_SKB_STREAM_VERDICT, 0); 556 if (!ASSERT_OK(err, "bpf_prog_attach stream verdict")) 557 goto skel_destroy; 558 559 netns = netns_new(NS_TEST, true); 560 if (!ASSERT_OK_PTR(netns, "netns_new: mptcp_sockmap")) 561 goto skel_destroy; 562 563 if (endpoint_init("subflow") < 0) 564 goto close_netns; 565 566 test_sockmap_with_mptcp_fallback(skel); 567 test_sockmap_reject_mptcp(skel); 568 569 close_netns: 570 netns_free(netns); 571 skel_destroy: 572 mptcp_sockmap__destroy(skel); 573 close_cgroup: 574 close(cgroup_fd); 575 } 576 577 void test_mptcp(void) 578 { 579 if (test__start_subtest("base")) 580 test_base(); 581 if (test__start_subtest("mptcpify")) 582 test_mptcpify(); 583 if (test__start_subtest("subflow")) 584 test_subflow(); 585 if (test__start_subtest("sockmap")) 586 test_mptcp_sockmap(); 587 } 588