1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2020, Tessares SA. */
3 /* Copyright (c) 2022, SUSE. */
4
5 #include <linux/const.h>
6 #include <netinet/in.h>
7 #include <test_progs.h>
8 #include <unistd.h>
9 #include <errno.h>
10 #include "cgroup_helpers.h"
11 #include "network_helpers.h"
12 #include "mptcp_sock.skel.h"
13 #include "mptcpify.skel.h"
14 #include "mptcp_subflow.skel.h"
15 #include "mptcp_sockmap.skel.h"
16
/* Name of the network namespace each subtest creates via netns_new() */
#define NS_TEST "mptcp_ns"
/* Addresses endpoint_init() assigns to veth1 (ADDR_1) and veth2 (ADDR_2) */
#define ADDR_1 "10.0.1.1"
#define ADDR_2 "10.0.1.2"
/* Port run_subflow() passes to start_mptcp_server() */
#define PORT_1 10001

/* Fallback definitions, only used when the included headers lack them */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif

#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif
#ifndef MPTCP_INFO
#define MPTCP_INFO 1
#endif
/* Bits tested in __mptcp_info.mptcpi_flags by verify_mptcpify() */
#ifndef MPTCP_INFO_FLAG_FALLBACK
#define MPTCP_INFO_FLAG_FALLBACK _BITUL(0)
#endif
#ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED
#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1)
#endif

/* Size of the buffers holding a congestion control name in this file */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX 16
#endif
42
/*
 * Buffer type handed to getsockopt(SOL_MPTCP, MPTCP_INFO) by
 * verify_mptcpify(), which only inspects mptcpi_flags.
 *
 * NOTE(review): presumably a local copy of the kernel UAPI
 * struct mptcp_info, in which case field order and sizes must stay in
 * sync with it -- confirm against the UAPI header before editing.
 */
struct __mptcp_info {
	__u8 mptcpi_subflows;
	__u8 mptcpi_add_addr_signal;
	__u8 mptcpi_add_addr_accepted;
	__u8 mptcpi_subflows_max;
	__u8 mptcpi_add_addr_signal_max;
	__u8 mptcpi_add_addr_accepted_max;
	/* tested against the MPTCP_INFO_FLAG_* bits */
	__u32 mptcpi_flags;
	__u32 mptcpi_token;
	__u64 mptcpi_write_seq;
	__u64 mptcpi_snd_una;
	__u64 mptcpi_rcv_nxt;
	__u8 mptcpi_local_addr_used;
	__u8 mptcpi_local_addr_max;
	__u8 mptcpi_csum_enabled;
	__u32 mptcpi_retransmits;
	__u64 mptcpi_bytes_retrans;
	__u64 mptcpi_bytes_sent;
	__u64 mptcpi_bytes_received;
	__u64 mptcpi_bytes_acked;
};
64
/*
 * Value type read back from the skeleton's socket_storage_map by
 * verify_tsk() and verify_msk().
 *
 * NOTE(review): presumably has to match the value layout used by the
 * BPF program that fills the map -- confirm against the BPF source.
 */
struct mptcp_storage {
	/* both verify helpers expect exactly 1 */
	__u32 invoked;
	/* expected 0 by verify_tsk(), 1 by verify_msk() */
	__u32 is_mptcp;
	/* verify_msk() expects 'first' to be equal to this */
	struct sock *sk;
	/* compared with the token exported in the skeleton's bss */
	__u32 token;
	struct sock *first;
	/* compared with the name returned by get_msk_ca_name() */
	char ca_name[TCP_CA_NAME_MAX];
};
73
/*
 * Start a SOCK_STREAM server through start_server_str() with the socket
 * protocol forced to IPPROTO_MPTCP.
 *
 * @family, @addr_str, @port and @timeout_ms are forwarded unchanged.
 *
 * Returns whatever start_server_str() returns; callers in this file treat
 * a negative value as failure.
 */
static int start_mptcp_server(int family, const char *addr_str, __u16 port,
			      int timeout_ms)
{
	struct network_helper_opts mptcp_opts = {
		.proto = IPPROTO_MPTCP,
		.timeout_ms = timeout_ms,
	};
	int listen_fd;

	listen_fd = start_server_str(family, SOCK_STREAM, addr_str, port,
				     &mptcp_opts);

	return listen_fd;
}
84
verify_tsk(int map_fd,int client_fd)85 static int verify_tsk(int map_fd, int client_fd)
86 {
87 int err, cfd = client_fd;
88 struct mptcp_storage val;
89
90 err = bpf_map_lookup_elem(map_fd, &cfd, &val);
91 if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
92 return err;
93
94 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
95 err++;
96
97 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
98 err++;
99
100 return err;
101 }
102
get_msk_ca_name(char ca_name[])103 static void get_msk_ca_name(char ca_name[])
104 {
105 size_t len;
106 int fd;
107
108 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
109 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
110 return;
111
112 len = read(fd, ca_name, TCP_CA_NAME_MAX);
113 if (!ASSERT_GT(len, 0, "failed to read ca_name"))
114 goto err;
115
116 if (len > 0 && ca_name[len - 1] == '\n')
117 ca_name[len - 1] = '\0';
118
119 err:
120 close(fd);
121 }
122
/*
 * Check the storage entry recorded for an MPTCP client socket.
 *
 * @token is the value the caller expects to find in the entry; it must be
 * non-zero. The entry keyed by @client_fd in @map_fd is expected to report
 * a single invocation, is_mptcp set, a matching token, first == sk and the
 * congestion control name returned by get_msk_ca_name().
 *
 * Returns -1 for a zero token, the lookup error if the lookup fails,
 * otherwise the number of failed checks (0 when everything matched).
 */
static int verify_msk(int map_fd, int client_fd, __u32 token)
{
	char expected_ca[TCP_CA_NAME_MAX];
	struct mptcp_storage storage;
	int failures = 0;
	int ret;

	if (!ASSERT_GT(token, 0, "invalid token"))
		return -1;

	get_msk_ca_name(expected_ca);

	ret = bpf_map_lookup_elem(map_fd, &client_fd, &storage);
	if (!ASSERT_OK(ret, "bpf_map_lookup_elem"))
		return ret;

	failures += !ASSERT_EQ(storage.invoked, 1, "unexpected invoked count");
	failures += !ASSERT_EQ(storage.is_mptcp, 1, "unexpected is_mptcp");
	failures += !ASSERT_EQ(storage.token, token, "unexpected token");
	failures += !ASSERT_EQ(storage.first, storage.sk, "unexpected first");
	failures += !ASSERT_STRNEQ(storage.ca_name, expected_ca,
				   TCP_CA_NAME_MAX, "unexpected ca_name");

	return failures;
}
155
run_test(int cgroup_fd,int server_fd,bool is_mptcp)156 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
157 {
158 int client_fd, prog_fd, map_fd, err;
159 struct mptcp_sock *sock_skel;
160
161 sock_skel = mptcp_sock__open_and_load();
162 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
163 return libbpf_get_error(sock_skel);
164
165 err = mptcp_sock__attach(sock_skel);
166 if (!ASSERT_OK(err, "skel_attach"))
167 goto out;
168
169 prog_fd = bpf_program__fd(sock_skel->progs._sockops);
170 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
171 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
172 if (!ASSERT_OK(err, "bpf_prog_attach"))
173 goto out;
174
175 client_fd = connect_to_fd(server_fd, 0);
176 if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
177 err = -EIO;
178 goto out;
179 }
180
181 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
182 verify_tsk(map_fd, client_fd);
183
184 close(client_fd);
185
186 out:
187 mptcp_sock__destroy(sock_skel);
188 return err;
189 }
190
test_base(void)191 static void test_base(void)
192 {
193 struct netns_obj *netns = NULL;
194 int server_fd, cgroup_fd;
195
196 cgroup_fd = test__join_cgroup("/mptcp");
197 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
198 return;
199
200 netns = netns_new(NS_TEST, true);
201 if (!ASSERT_OK_PTR(netns, "netns_new"))
202 goto fail;
203
204 /* without MPTCP */
205 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
206 if (!ASSERT_GE(server_fd, 0, "start_server"))
207 goto with_mptcp;
208
209 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
210
211 close(server_fd);
212
213 with_mptcp:
214 /* with MPTCP */
215 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
216 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
217 goto fail;
218
219 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
220
221 close(server_fd);
222
223 fail:
224 netns_free(netns);
225 close(cgroup_fd);
226 }
227
/*
 * Write a single 0x55 byte to @fd and record a test assertion that
 * exactly one byte was written.
 */
static void send_byte(int fd)
{
	char payload = 0x55;
	ssize_t written;

	written = write(fd, &payload, sizeof(payload));
	ASSERT_EQ(written, 1, "send single byte");
}
234
verify_mptcpify(int server_fd,int client_fd)235 static int verify_mptcpify(int server_fd, int client_fd)
236 {
237 struct __mptcp_info info;
238 socklen_t optlen;
239 int protocol;
240 int err = 0;
241
242 optlen = sizeof(protocol);
243 if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen),
244 "getsockopt(SOL_PROTOCOL)"))
245 return -1;
246
247 if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP"))
248 err++;
249
250 optlen = sizeof(info);
251 if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen),
252 "getsockopt(MPTCP_INFO)"))
253 return -1;
254
255 if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags"))
256 err++;
257 if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK,
258 "MPTCP fallback"))
259 err++;
260 if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED,
261 "no remote key received"))
262 err++;
263
264 return err;
265 }
266
run_mptcpify(int cgroup_fd)267 static int run_mptcpify(int cgroup_fd)
268 {
269 int server_fd, client_fd, err = 0;
270 struct mptcpify *mptcpify_skel;
271
272 mptcpify_skel = mptcpify__open_and_load();
273 if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load"))
274 return libbpf_get_error(mptcpify_skel);
275
276 mptcpify_skel->bss->pid = getpid();
277
278 err = mptcpify__attach(mptcpify_skel);
279 if (!ASSERT_OK(err, "skel_attach"))
280 goto out;
281
282 /* without MPTCP */
283 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
284 if (!ASSERT_GE(server_fd, 0, "start_server")) {
285 err = -EIO;
286 goto out;
287 }
288
289 client_fd = connect_to_fd(server_fd, 0);
290 if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
291 err = -EIO;
292 goto close_server;
293 }
294
295 send_byte(client_fd);
296
297 err = verify_mptcpify(server_fd, client_fd);
298
299 close(client_fd);
300 close_server:
301 close(server_fd);
302 out:
303 mptcpify__destroy(mptcpify_skel);
304 return err;
305 }
306
test_mptcpify(void)307 static void test_mptcpify(void)
308 {
309 struct netns_obj *netns = NULL;
310 int cgroup_fd;
311
312 cgroup_fd = test__join_cgroup("/mptcpify");
313 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
314 return;
315
316 netns = netns_new(NS_TEST, true);
317 if (!ASSERT_OK_PTR(netns, "netns_new"))
318 goto fail;
319
320 ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify");
321
322 fail:
323 netns_free(netns);
324 close(cgroup_fd);
325 }
326
endpoint_init(char * flags)327 static int endpoint_init(char *flags)
328 {
329 SYS(fail, "ip -net %s link add veth1 type veth peer name veth2", NS_TEST);
330 SYS(fail, "ip -net %s addr add %s/24 dev veth1", NS_TEST, ADDR_1);
331 SYS(fail, "ip -net %s link set dev veth1 up", NS_TEST);
332 SYS(fail, "ip -net %s addr add %s/24 dev veth2", NS_TEST, ADDR_2);
333 SYS(fail, "ip -net %s link set dev veth2 up", NS_TEST);
334 if (SYS_NOFAIL("ip -net %s mptcp endpoint add %s %s", NS_TEST, ADDR_2, flags)) {
335 printf("'ip mptcp' not supported, skip this test.\n");
336 test__skip();
337 goto fail;
338 }
339
340 return 0;
341 fail:
342 return -1;
343 }
344
wait_for_new_subflows(int fd)345 static void wait_for_new_subflows(int fd)
346 {
347 socklen_t len;
348 u8 subflows;
349 int err, i;
350
351 len = sizeof(subflows);
352 /* Wait max 5 sec for new subflows to be created */
353 for (i = 0; i < 50; i++) {
354 err = getsockopt(fd, SOL_MPTCP, MPTCP_INFO, &subflows, &len);
355 if (!err && subflows > 0)
356 break;
357
358 usleep(100000); /* 0.1s */
359 }
360 }
361
run_subflow(void)362 static void run_subflow(void)
363 {
364 int server_fd, client_fd, err;
365 char new[TCP_CA_NAME_MAX];
366 char cc[TCP_CA_NAME_MAX];
367 unsigned int mark;
368 socklen_t len;
369
370 server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, 0);
371 if (!ASSERT_OK_FD(server_fd, "start_mptcp_server"))
372 return;
373
374 client_fd = connect_to_fd(server_fd, 0);
375 if (!ASSERT_OK_FD(client_fd, "connect_to_fd"))
376 goto close_server;
377
378 send_byte(client_fd);
379 wait_for_new_subflows(client_fd);
380
381 len = sizeof(mark);
382 err = getsockopt(client_fd, SOL_SOCKET, SO_MARK, &mark, &len);
383 if (ASSERT_OK(err, "getsockopt(client_fd, SO_MARK)"))
384 ASSERT_EQ(mark, 0, "mark");
385
386 len = sizeof(new);
387 err = getsockopt(client_fd, SOL_TCP, TCP_CONGESTION, new, &len);
388 if (ASSERT_OK(err, "getsockopt(client_fd, TCP_CONGESTION)")) {
389 get_msk_ca_name(cc);
390 ASSERT_STREQ(new, cc, "cc");
391 }
392
393 close(client_fd);
394 close_server:
395 close(server_fd);
396 }
397
test_subflow(void)398 static void test_subflow(void)
399 {
400 struct mptcp_subflow *skel;
401 struct netns_obj *netns;
402 int cgroup_fd;
403
404 cgroup_fd = test__join_cgroup("/mptcp_subflow");
405 if (!ASSERT_OK_FD(cgroup_fd, "join_cgroup: mptcp_subflow"))
406 return;
407
408 skel = mptcp_subflow__open_and_load();
409 if (!ASSERT_OK_PTR(skel, "skel_open_load: mptcp_subflow"))
410 goto close_cgroup;
411
412 skel->bss->pid = getpid();
413
414 skel->links.mptcp_subflow =
415 bpf_program__attach_cgroup(skel->progs.mptcp_subflow, cgroup_fd);
416 if (!ASSERT_OK_PTR(skel->links.mptcp_subflow, "attach mptcp_subflow"))
417 goto skel_destroy;
418
419 skel->links._getsockopt_subflow =
420 bpf_program__attach_cgroup(skel->progs._getsockopt_subflow, cgroup_fd);
421 if (!ASSERT_OK_PTR(skel->links._getsockopt_subflow, "attach _getsockopt_subflow"))
422 goto skel_destroy;
423
424 netns = netns_new(NS_TEST, true);
425 if (!ASSERT_OK_PTR(netns, "netns_new: mptcp_subflow"))
426 goto skel_destroy;
427
428 if (endpoint_init("subflow") < 0)
429 goto close_netns;
430
431 run_subflow();
432
433 close_netns:
434 netns_free(netns);
435 skel_destroy:
436 mptcp_subflow__destroy(skel);
437 close_cgroup:
438 close(cgroup_fd);
439 }
440
441 /* Test sockmap on MPTCP server handling non-mp-capable clients. */
test_sockmap_with_mptcp_fallback(struct mptcp_sockmap * skel)442 static void test_sockmap_with_mptcp_fallback(struct mptcp_sockmap *skel)
443 {
444 int listen_fd = -1, client_fd1 = -1, client_fd2 = -1;
445 int server_fd1 = -1, server_fd2 = -1, sent, recvd;
446 char snd[9] = "123456789";
447 char rcv[10];
448
449 /* start server with MPTCP enabled */
450 listen_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
451 if (!ASSERT_OK_FD(listen_fd, "sockmap-fb:start_mptcp_server"))
452 return;
453
454 skel->bss->trace_port = ntohs(get_socket_local_port(listen_fd));
455 skel->bss->sk_index = 0;
456 /* create client without MPTCP enabled */
457 client_fd1 = connect_to_fd_opts(listen_fd, NULL);
458 if (!ASSERT_OK_FD(client_fd1, "sockmap-fb:connect_to_fd"))
459 goto end;
460
461 server_fd1 = accept(listen_fd, NULL, 0);
462 skel->bss->sk_index = 1;
463 client_fd2 = connect_to_fd_opts(listen_fd, NULL);
464 if (!ASSERT_OK_FD(client_fd2, "sockmap-fb:connect_to_fd"))
465 goto end;
466
467 server_fd2 = accept(listen_fd, NULL, 0);
468 /* test normal redirect behavior: data sent by client_fd1 can be
469 * received by client_fd2
470 */
471 skel->bss->redirect_idx = 1;
472 sent = send(client_fd1, snd, sizeof(snd), 0);
473 if (!ASSERT_EQ(sent, sizeof(snd), "sockmap-fb:send(client_fd1)"))
474 goto end;
475
476 /* try to recv more bytes to avoid truncation check */
477 recvd = recv(client_fd2, rcv, sizeof(rcv), 0);
478 if (!ASSERT_EQ(recvd, sizeof(snd), "sockmap-fb:recv(client_fd2)"))
479 goto end;
480
481 end:
482 if (client_fd1 >= 0)
483 close(client_fd1);
484 if (client_fd2 >= 0)
485 close(client_fd2);
486 if (server_fd1 >= 0)
487 close(server_fd1);
488 if (server_fd2 >= 0)
489 close(server_fd2);
490 close(listen_fd);
491 }
492
493 /* Test sockmap rejection of MPTCP sockets - both server and client sides. */
test_sockmap_reject_mptcp(struct mptcp_sockmap * skel)494 static void test_sockmap_reject_mptcp(struct mptcp_sockmap *skel)
495 {
496 int listen_fd = -1, server_fd = -1, client_fd1 = -1;
497 int err, zero = 0;
498
499 /* start server with MPTCP enabled */
500 listen_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
501 if (!ASSERT_OK_FD(listen_fd, "start_mptcp_server"))
502 return;
503
504 skel->bss->trace_port = ntohs(get_socket_local_port(listen_fd));
505 skel->bss->sk_index = 0;
506 /* create client with MPTCP enabled */
507 client_fd1 = connect_to_fd(listen_fd, 0);
508 if (!ASSERT_OK_FD(client_fd1, "connect_to_fd client_fd1"))
509 goto end;
510
511 /* bpf_sock_map_update() called from sockops should reject MPTCP sk */
512 if (!ASSERT_EQ(skel->bss->helper_ret, -EOPNOTSUPP, "should reject"))
513 goto end;
514
515 server_fd = accept(listen_fd, NULL, 0);
516 err = bpf_map_update_elem(bpf_map__fd(skel->maps.sock_map),
517 &zero, &server_fd, BPF_NOEXIST);
518 if (!ASSERT_EQ(err, -EOPNOTSUPP, "server should be disallowed"))
519 goto end;
520
521 /* MPTCP client should also be disallowed */
522 err = bpf_map_update_elem(bpf_map__fd(skel->maps.sock_map),
523 &zero, &client_fd1, BPF_NOEXIST);
524 if (!ASSERT_EQ(err, -EOPNOTSUPP, "client should be disallowed"))
525 goto end;
526 end:
527 if (client_fd1 >= 0)
528 close(client_fd1);
529 if (server_fd >= 0)
530 close(server_fd);
531 close(listen_fd);
532 }
533
test_mptcp_sockmap(void)534 static void test_mptcp_sockmap(void)
535 {
536 struct mptcp_sockmap *skel;
537 struct netns_obj *netns;
538 int cgroup_fd, err;
539
540 cgroup_fd = test__join_cgroup("/mptcp_sockmap");
541 if (!ASSERT_OK_FD(cgroup_fd, "join_cgroup: mptcp_sockmap"))
542 return;
543
544 skel = mptcp_sockmap__open_and_load();
545 if (!ASSERT_OK_PTR(skel, "skel_open_load: mptcp_sockmap"))
546 goto close_cgroup;
547
548 skel->links.mptcp_sockmap_inject =
549 bpf_program__attach_cgroup(skel->progs.mptcp_sockmap_inject, cgroup_fd);
550 if (!ASSERT_OK_PTR(skel->links.mptcp_sockmap_inject, "attach sockmap"))
551 goto skel_destroy;
552
553 err = bpf_prog_attach(bpf_program__fd(skel->progs.mptcp_sockmap_redirect),
554 bpf_map__fd(skel->maps.sock_map),
555 BPF_SK_SKB_STREAM_VERDICT, 0);
556 if (!ASSERT_OK(err, "bpf_prog_attach stream verdict"))
557 goto skel_destroy;
558
559 netns = netns_new(NS_TEST, true);
560 if (!ASSERT_OK_PTR(netns, "netns_new: mptcp_sockmap"))
561 goto skel_destroy;
562
563 if (endpoint_init("subflow") < 0)
564 goto close_netns;
565
566 test_sockmap_with_mptcp_fallback(skel);
567 test_sockmap_reject_mptcp(skel);
568
569 close_netns:
570 netns_free(netns);
571 skel_destroy:
572 mptcp_sockmap__destroy(skel);
573 close_cgroup:
574 close(cgroup_fd);
575 }
576
test_mptcp(void)577 void test_mptcp(void)
578 {
579 if (test__start_subtest("base"))
580 test_base();
581 if (test__start_subtest("mptcpify"))
582 test_mptcpify();
583 if (test__start_subtest("subflow"))
584 test_subflow();
585 if (test__start_subtest("sockmap"))
586 test_mptcp_sockmap();
587 }
588