1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2020, Tessares SA. */
3 /* Copyright (c) 2022, SUSE. */
4
5 #include <linux/const.h>
6 #include <netinet/in.h>
7 #include <test_progs.h>
8 #include "cgroup_helpers.h"
9 #include "network_helpers.h"
10 #include "mptcp_sock.skel.h"
11 #include "mptcpify.skel.h"
12
/* Name of the throw-away network namespace each subtest runs in. */
#define NS_TEST "mptcp_ns"

/*
 * Fallback definitions, used only when the system headers do not already
 * provide them. The values are handed to the kernel (socket protocol,
 * getsockopt() level/option) or used to decode kernel-reported flags, so
 * they must match the running kernel's definitions exactly.
 */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif

#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif
#ifndef MPTCP_INFO
#define MPTCP_INFO 1
#endif
/* Bits of mptcpi_flags as returned by getsockopt(SOL_MPTCP, MPTCP_INFO). */
#ifndef MPTCP_INFO_FLAG_FALLBACK
#define MPTCP_INFO_FLAG_FALLBACK _BITUL(0)
#endif
#ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED
#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1)
#endif

/*
 * Size of a congestion-control name buffer: sizes mptcp_storage.ca_name
 * and bounds the read of the tcp_congestion_control sysctl.
 */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX 16
#endif
35
/*
 * Local mirror of the kernel's struct mptcp_info, filled in by
 * getsockopt(SOL_MPTCP, MPTCP_INFO) in verify_mptcpify(). Field order and
 * types must stay in sync with the kernel layout because the buffer is
 * passed to the kernel as-is. Only mptcpi_flags is inspected in this file.
 */
struct __mptcp_info {
	__u8	mptcpi_subflows;
	__u8	mptcpi_add_addr_signal;
	__u8	mptcpi_add_addr_accepted;
	__u8	mptcpi_subflows_max;
	__u8	mptcpi_add_addr_signal_max;
	__u8	mptcpi_add_addr_accepted_max;
	__u32	mptcpi_flags;		/* MPTCP_INFO_FLAG_* bits */
	__u32	mptcpi_token;
	__u64	mptcpi_write_seq;
	__u64	mptcpi_snd_una;
	__u64	mptcpi_rcv_nxt;
	__u8	mptcpi_local_addr_used;
	__u8	mptcpi_local_addr_max;
	__u8	mptcpi_csum_enabled;
	__u32	mptcpi_retransmits;
	__u64	mptcpi_bytes_retrans;
	__u64	mptcpi_bytes_sent;
	__u64	mptcpi_bytes_received;
	__u64	mptcpi_bytes_acked;
};
57
/*
 * Value type of the skeleton's socket_storage_map, as written by the
 * mptcp_sock BPF program and looked up by client fd in verify_tsk() and
 * verify_msk(). The layout must match the BPF-side definition exactly
 * (NOTE(review): assumed to be in progs/mptcp_sock.c — confirm).
 */
struct mptcp_storage {
	__u32 invoked;		/* expected to be 1 for the checked socket */
	__u32 is_mptcp;		/* expected 1 for MPTCP, 0 for plain TCP */
	struct sock *sk;	/* opaque kernel pointer, only compared to 'first' */
	__u32 token;		/* compared with the skeleton's bss->token */
	struct sock *first;	/* opaque kernel pointer, expected to equal 'sk' */
	char ca_name[TCP_CA_NAME_MAX];	/* compared with the sysctl value */
};
66
create_netns(void)67 static struct nstoken *create_netns(void)
68 {
69 SYS(fail, "ip netns add %s", NS_TEST);
70 SYS(fail, "ip -net %s link set dev lo up", NS_TEST);
71
72 return open_netns(NS_TEST);
73 fail:
74 return NULL;
75 }
76
cleanup_netns(struct nstoken * nstoken)77 static void cleanup_netns(struct nstoken *nstoken)
78 {
79 if (nstoken)
80 close_netns(nstoken);
81
82 SYS_NOFAIL("ip netns del %s", NS_TEST);
83 }
84
/*
 * Start a SOCK_STREAM server using protocol IPPROTO_MPTCP on
 * addr_str:port, via start_server_str() with the given timeout.
 *
 * Returns whatever start_server_str() returns; callers treat a negative
 * value as failure.
 */
static int start_mptcp_server(int family, const char *addr_str, __u16 port,
			      int timeout_ms)
{
	return start_server_str(family, SOCK_STREAM, addr_str, port,
				&(struct network_helper_opts) {
					.proto = IPPROTO_MPTCP,
					.timeout_ms = timeout_ms,
				});
}
95
verify_tsk(int map_fd,int client_fd)96 static int verify_tsk(int map_fd, int client_fd)
97 {
98 int err, cfd = client_fd;
99 struct mptcp_storage val;
100
101 err = bpf_map_lookup_elem(map_fd, &cfd, &val);
102 if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
103 return err;
104
105 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
106 err++;
107
108 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
109 err++;
110
111 return err;
112 }
113
get_msk_ca_name(char ca_name[])114 static void get_msk_ca_name(char ca_name[])
115 {
116 size_t len;
117 int fd;
118
119 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
120 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
121 return;
122
123 len = read(fd, ca_name, TCP_CA_NAME_MAX);
124 if (!ASSERT_GT(len, 0, "failed to read ca_name"))
125 goto err;
126
127 if (len > 0 && ca_name[len - 1] == '\n')
128 ca_name[len - 1] = '\0';
129
130 err:
131 close(fd);
132 }
133
/*
 * Check the storage entry the sockops program recorded for an MPTCP
 * client socket. The checks are:
 *   - the program ran exactly once and saw an MPTCP socket;
 *   - the recorded token equals @token, which must be non-zero;
 *   - the recorded 'first' pointer equals the recorded 'sk' pointer;
 *   - the recorded congestion-control name matches the
 *     tcp_congestion_control sysctl.
 *
 * Returns -1 for a zero token, the bpf_map_lookup_elem() error if no entry
 * exists for client_fd, otherwise the number of failed checks.
 */
static int verify_msk(int map_fd, int client_fd, __u32 token)
{
	/*
	 * Zeroed so that a failed sysctl read compares as "" instead of
	 * uninitialized stack bytes.
	 */
	char ca_name[TCP_CA_NAME_MAX] = {};
	int err, cfd = client_fd;
	struct mptcp_storage val;

	if (!ASSERT_GT(token, 0, "invalid token"))
		return -1;

	get_msk_ca_name(ca_name);

	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
		return err;

	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
		err++;

	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
		err++;

	if (!ASSERT_EQ(val.token, token, "unexpected token"))
		err++;

	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
		err++;

	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
		err++;

	return err;
}
166
run_test(int cgroup_fd,int server_fd,bool is_mptcp)167 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
168 {
169 int client_fd, prog_fd, map_fd, err;
170 struct mptcp_sock *sock_skel;
171
172 sock_skel = mptcp_sock__open_and_load();
173 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
174 return libbpf_get_error(sock_skel);
175
176 err = mptcp_sock__attach(sock_skel);
177 if (!ASSERT_OK(err, "skel_attach"))
178 goto out;
179
180 prog_fd = bpf_program__fd(sock_skel->progs._sockops);
181 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
182 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
183 if (!ASSERT_OK(err, "bpf_prog_attach"))
184 goto out;
185
186 client_fd = connect_to_fd(server_fd, 0);
187 if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
188 err = -EIO;
189 goto out;
190 }
191
192 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
193 verify_tsk(map_fd, client_fd);
194
195 close(client_fd);
196
197 out:
198 mptcp_sock__destroy(sock_skel);
199 return err;
200 }
201
/*
 * "base" subtest. It runs inside the /mptcp cgroup and a fresh NS_TEST
 * namespace, and calls run_test() twice: once against a plain TCP server
 * and once against an MPTCP server. If the TCP server cannot be started,
 * only the TCP half is skipped; the MPTCP half still runs.
 */
static void test_base(void)
{
	struct nstoken *nstoken;
	int server_fd, cgroup_fd;

	cgroup_fd = test__join_cgroup("/mptcp");
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
		return;

	nstoken = create_netns();
	if (!ASSERT_OK_PTR(nstoken, "create_netns"))
		goto cleanup;

	/* without MPTCP */
	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
	if (ASSERT_GE(server_fd, 0, "start_server")) {
		ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
		close(server_fd);
	}

	/* with MPTCP */
	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
	if (ASSERT_GE(server_fd, 0, "start_mptcp_server")) {
		ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
		close(server_fd);
	}

cleanup:
	cleanup_netns(nstoken);
	close(cgroup_fd);
}
238
/* Write the single byte 0x55 to fd and assert that exactly one byte went out. */
static void send_byte(int fd)
{
	const char payload = 0x55;
	ssize_t written = write(fd, &payload, sizeof(payload));

	ASSERT_EQ(written, 1, "send single byte");
}
245
verify_mptcpify(int server_fd,int client_fd)246 static int verify_mptcpify(int server_fd, int client_fd)
247 {
248 struct __mptcp_info info;
249 socklen_t optlen;
250 int protocol;
251 int err = 0;
252
253 optlen = sizeof(protocol);
254 if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen),
255 "getsockopt(SOL_PROTOCOL)"))
256 return -1;
257
258 if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP"))
259 err++;
260
261 optlen = sizeof(info);
262 if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen),
263 "getsockopt(MPTCP_INFO)"))
264 return -1;
265
266 if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags"))
267 err++;
268 if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK,
269 "MPTCP fallback"))
270 err++;
271 if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED,
272 "no remote key received"))
273 err++;
274
275 return err;
276 }
277
run_mptcpify(int cgroup_fd)278 static int run_mptcpify(int cgroup_fd)
279 {
280 int server_fd, client_fd, err = 0;
281 struct mptcpify *mptcpify_skel;
282
283 mptcpify_skel = mptcpify__open_and_load();
284 if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load"))
285 return libbpf_get_error(mptcpify_skel);
286
287 mptcpify_skel->bss->pid = getpid();
288
289 err = mptcpify__attach(mptcpify_skel);
290 if (!ASSERT_OK(err, "skel_attach"))
291 goto out;
292
293 /* without MPTCP */
294 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
295 if (!ASSERT_GE(server_fd, 0, "start_server")) {
296 err = -EIO;
297 goto out;
298 }
299
300 client_fd = connect_to_fd(server_fd, 0);
301 if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
302 err = -EIO;
303 goto close_server;
304 }
305
306 send_byte(client_fd);
307
308 err = verify_mptcpify(server_fd, client_fd);
309
310 close(client_fd);
311 close_server:
312 close(server_fd);
313 out:
314 mptcpify__destroy(mptcpify_skel);
315 return err;
316 }
317
/*
 * "mptcpify" subtest: join the /mptcpify cgroup, enter a fresh NS_TEST
 * namespace and run run_mptcpify(). Afterwards, tear down the namespace
 * and close the cgroup fd.
 */
static void test_mptcpify(void)
{
	struct nstoken *nstoken;
	int cgroup_fd;

	cgroup_fd = test__join_cgroup("/mptcpify");
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
		return;

	nstoken = create_netns();
	if (ASSERT_OK_PTR(nstoken, "create_netns"))
		ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify");

	cleanup_netns(nstoken);
	close(cgroup_fd);
}
337
test_mptcp(void)338 void test_mptcp(void)
339 {
340 if (test__start_subtest("base"))
341 test_base();
342 if (test__start_subtest("mptcpify"))
343 test_mptcpify();
344 }
345