1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020, Tessares SA. */ 3 /* Copyright (c) 2022, SUSE. */ 4 5 #include <linux/const.h> 6 #include <netinet/in.h> 7 #include <test_progs.h> 8 #include "cgroup_helpers.h" 9 #include "network_helpers.h" 10 #include "mptcp_sock.skel.h" 11 #include "mptcpify.skel.h" 12 13 #define NS_TEST "mptcp_ns" 14 15 #ifndef IPPROTO_MPTCP 16 #define IPPROTO_MPTCP 262 17 #endif 18 19 #ifndef SOL_MPTCP 20 #define SOL_MPTCP 284 21 #endif 22 #ifndef MPTCP_INFO 23 #define MPTCP_INFO 1 24 #endif 25 #ifndef MPTCP_INFO_FLAG_FALLBACK 26 #define MPTCP_INFO_FLAG_FALLBACK _BITUL(0) 27 #endif 28 #ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED 29 #define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1) 30 #endif 31 32 #ifndef TCP_CA_NAME_MAX 33 #define TCP_CA_NAME_MAX 16 34 #endif 35 36 struct __mptcp_info { 37 __u8 mptcpi_subflows; 38 __u8 mptcpi_add_addr_signal; 39 __u8 mptcpi_add_addr_accepted; 40 __u8 mptcpi_subflows_max; 41 __u8 mptcpi_add_addr_signal_max; 42 __u8 mptcpi_add_addr_accepted_max; 43 __u32 mptcpi_flags; 44 __u32 mptcpi_token; 45 __u64 mptcpi_write_seq; 46 __u64 mptcpi_snd_una; 47 __u64 mptcpi_rcv_nxt; 48 __u8 mptcpi_local_addr_used; 49 __u8 mptcpi_local_addr_max; 50 __u8 mptcpi_csum_enabled; 51 __u32 mptcpi_retransmits; 52 __u64 mptcpi_bytes_retrans; 53 __u64 mptcpi_bytes_sent; 54 __u64 mptcpi_bytes_received; 55 __u64 mptcpi_bytes_acked; 56 }; 57 58 struct mptcp_storage { 59 __u32 invoked; 60 __u32 is_mptcp; 61 struct sock *sk; 62 __u32 token; 63 struct sock *first; 64 char ca_name[TCP_CA_NAME_MAX]; 65 }; 66 67 static struct nstoken *create_netns(void) 68 { 69 SYS(fail, "ip netns add %s", NS_TEST); 70 SYS(fail, "ip -net %s link set dev lo up", NS_TEST); 71 72 return open_netns(NS_TEST); 73 fail: 74 return NULL; 75 } 76 77 static void cleanup_netns(struct nstoken *nstoken) 78 { 79 if (nstoken) 80 close_netns(nstoken); 81 82 SYS_NOFAIL("ip netns del %s", NS_TEST); 83 } 84 85 static int start_mptcp_server(int family, const char *addr_str, __u16 port, 86 int timeout_ms) 87 { 88 struct network_helper_opts opts = { 89 .timeout_ms = timeout_ms, 90 .proto = IPPROTO_MPTCP, 91 }; 92 93 return start_server_str(family, SOCK_STREAM, addr_str, port, &opts); 94 } 95 96 static int verify_tsk(int map_fd, int client_fd) 97 { 98 int err, cfd = client_fd; 99 struct mptcp_storage val; 100 101 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 102 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 103 return err; 104 105 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 106 err++; 107 108 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp")) 109 err++; 110 111 return err; 112 } 113 114 static void get_msk_ca_name(char ca_name[]) 115 { 116 size_t len; 117 int fd; 118 119 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY); 120 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control")) 121 return; 122 123 len = read(fd, ca_name, TCP_CA_NAME_MAX); 124 if (!ASSERT_GT(len, 0, "failed to read ca_name")) 125 goto err; 126 127 if (len > 0 && ca_name[len - 1] == '\n') 128 ca_name[len - 1] = '\0'; 129 130 err: 131 close(fd); 132 } 133 134 static int verify_msk(int map_fd, int client_fd, __u32 token) 135 { 136 char ca_name[TCP_CA_NAME_MAX]; 137 int err, cfd = client_fd; 138 struct mptcp_storage val; 139 140 if (!ASSERT_GT(token, 0, "invalid token")) 141 return -1; 142 143 get_msk_ca_name(ca_name); 144 145 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 146 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 147 return err; 148 149 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 150 err++; 151 152 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp")) 153 err++; 154 155 if (!ASSERT_EQ(val.token, token, "unexpected token")) 156 err++; 157 158 if (!ASSERT_EQ(val.first, val.sk, "unexpected first")) 159 err++; 160 161 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name")) 162 err++; 163 164 return err; 165 } 166 167 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp) 168 { 169 int client_fd, prog_fd, map_fd, err; 170 struct mptcp_sock *sock_skel; 171 172 sock_skel = mptcp_sock__open_and_load(); 173 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load")) 174 return libbpf_get_error(sock_skel); 175 176 err = mptcp_sock__attach(sock_skel); 177 if (!ASSERT_OK(err, "skel_attach")) 178 goto out; 179 180 prog_fd = bpf_program__fd(sock_skel->progs._sockops); 181 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map); 182 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 183 if (!ASSERT_OK(err, "bpf_prog_attach")) 184 goto out; 185 186 client_fd = connect_to_fd(server_fd, 0); 187 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 188 err = -EIO; 189 goto out; 190 } 191 192 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) : 193 verify_tsk(map_fd, client_fd); 194 195 close(client_fd); 196 197 out: 198 mptcp_sock__destroy(sock_skel); 199 return err; 200 } 201 202 static void test_base(void) 203 { 204 struct nstoken *nstoken = NULL; 205 int server_fd, cgroup_fd; 206 207 cgroup_fd = test__join_cgroup("/mptcp"); 208 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 209 return; 210 211 nstoken = create_netns(); 212 if (!ASSERT_OK_PTR(nstoken, "create_netns")) 213 goto fail; 214 215 /* without MPTCP */ 216 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 217 if (!ASSERT_GE(server_fd, 0, "start_server")) 218 goto with_mptcp; 219 220 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp"); 221 222 close(server_fd); 223 224 with_mptcp: 225 /* with MPTCP */ 226 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 227 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) 228 goto fail; 229 230 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp"); 231 232 close(server_fd); 233 234 fail: 235 cleanup_netns(nstoken); 236 close(cgroup_fd); 237 } 238 239 static void send_byte(int fd) 240 { 241 char b = 0x55; 242 243 ASSERT_EQ(write(fd, &b, sizeof(b)), 1, "send single byte"); 244 } 245 246 static int verify_mptcpify(int server_fd, int client_fd) 247 { 248 struct __mptcp_info info; 249 socklen_t optlen; 250 int protocol; 251 int err = 0; 252 253 optlen = sizeof(protocol); 254 if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen), 255 "getsockopt(SOL_PROTOCOL)")) 256 return -1; 257 258 if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP")) 259 err++; 260 261 optlen = sizeof(info); 262 if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen), 263 "getsockopt(MPTCP_INFO)")) 264 return -1; 265 266 if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags")) 267 err++; 268 if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK, 269 "MPTCP fallback")) 270 err++; 271 if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED, 272 "no remote key received")) 273 err++; 274 275 return err; 276 } 277 278 static int run_mptcpify(int cgroup_fd) 279 { 280 int server_fd, client_fd, err = 0; 281 struct mptcpify *mptcpify_skel; 282 283 mptcpify_skel = mptcpify__open_and_load(); 284 if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load")) 285 return libbpf_get_error(mptcpify_skel); 286 287 mptcpify_skel->bss->pid = getpid(); 288 289 err = mptcpify__attach(mptcpify_skel); 290 if (!ASSERT_OK(err, "skel_attach")) 291 goto out; 292 293 /* without MPTCP */ 294 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 295 if (!ASSERT_GE(server_fd, 0, "start_server")) { 296 err = -EIO; 297 goto out; 298 } 299 300 client_fd = connect_to_fd(server_fd, 0); 301 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 302 err = -EIO; 303 goto close_server; 304 } 305 306 send_byte(client_fd); 307 308 err = verify_mptcpify(server_fd, client_fd); 309 310 close(client_fd); 311 close_server: 312 close(server_fd); 313 out: 314 mptcpify__destroy(mptcpify_skel); 315 return err; 316 } 317 318 static void test_mptcpify(void) 319 { 320 struct nstoken *nstoken = NULL; 321 int cgroup_fd; 322 323 cgroup_fd = test__join_cgroup("/mptcpify"); 324 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 325 return; 326 327 nstoken = create_netns(); 328 if (!ASSERT_OK_PTR(nstoken, "create_netns")) 329 goto fail; 330 331 ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify"); 332 333 fail: 334 cleanup_netns(nstoken); 335 close(cgroup_fd); 336 } 337 338 void test_mptcp(void) 339 { 340 if (test__start_subtest("base")) 341 test_base(); 342 if (test__start_subtest("mptcpify")) 343 test_mptcpify(); 344 } 345