1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2024 Meta 3 4 #include <test_progs.h> 5 #include "network_helpers.h" 6 #include "sock_iter_batch.skel.h" 7 8 #define TEST_NS "sock_iter_batch_netns" 9 10 static const int init_batch_size = 16; 11 static const int nr_soreuse = 4; 12 13 struct iter_out { 14 int idx; 15 __u64 cookie; 16 } __packed; 17 18 struct sock_count { 19 __u64 cookie; 20 int count; 21 }; 22 23 static int insert(__u64 cookie, struct sock_count counts[], int counts_len) 24 { 25 int insert = -1; 26 int i = 0; 27 28 for (; i < counts_len; i++) { 29 if (!counts[i].cookie) { 30 insert = i; 31 } else if (counts[i].cookie == cookie) { 32 insert = i; 33 break; 34 } 35 } 36 if (insert < 0) 37 return insert; 38 39 counts[insert].cookie = cookie; 40 counts[insert].count++; 41 42 return counts[insert].count; 43 } 44 45 static int read_n(int iter_fd, int n, struct sock_count counts[], 46 int counts_len) 47 { 48 struct iter_out out; 49 int nread = 1; 50 int i = 0; 51 52 for (; nread > 0 && (n < 0 || i < n); i++) { 53 nread = read(iter_fd, &out, sizeof(out)); 54 if (!nread || !ASSERT_EQ(nread, sizeof(out), "nread")) 55 break; 56 ASSERT_GE(insert(out.cookie, counts, counts_len), 0, "insert"); 57 } 58 59 ASSERT_TRUE(n < 0 || i == n, "n < 0 || i == n"); 60 61 return i; 62 } 63 64 static __u64 socket_cookie(int fd) 65 { 66 __u64 cookie; 67 socklen_t cookie_len = sizeof(cookie); 68 69 if (!ASSERT_OK(getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie, 70 &cookie_len), "getsockopt(SO_COOKIE)")) 71 return 0; 72 return cookie; 73 } 74 75 static bool was_seen(int fd, struct sock_count counts[], int counts_len) 76 { 77 __u64 cookie = socket_cookie(fd); 78 int i = 0; 79 80 for (; cookie && i < counts_len; i++) 81 if (cookie == counts[i].cookie) 82 return true; 83 84 return false; 85 } 86 87 static int get_seen_socket(int *fds, struct sock_count counts[], int n) 88 { 89 int i = 0; 90 91 for (; i < n; i++) 92 if (was_seen(fds[i], counts, n)) 93 return i; 94 return -1; 95 } 96 97 static int get_nth_socket(int *fds, int fds_len, struct bpf_link *link, int n) 98 { 99 int i, nread, iter_fd; 100 int nth_sock_idx = -1; 101 struct iter_out out; 102 103 iter_fd = bpf_iter_create(bpf_link__fd(link)); 104 if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create")) 105 return -1; 106 107 for (; n >= 0; n--) { 108 nread = read(iter_fd, &out, sizeof(out)); 109 if (!nread || !ASSERT_GE(nread, 1, "nread")) 110 goto done; 111 } 112 113 for (i = 0; i < fds_len && nth_sock_idx < 0; i++) 114 if (fds[i] >= 0 && socket_cookie(fds[i]) == out.cookie) 115 nth_sock_idx = i; 116 done: 117 close(iter_fd); 118 return nth_sock_idx; 119 } 120 121 static int get_seen_count(int fd, struct sock_count counts[], int n) 122 { 123 __u64 cookie = socket_cookie(fd); 124 int count = 0; 125 int i = 0; 126 127 for (; cookie && !count && i < n; i++) 128 if (cookie == counts[i].cookie) 129 count = counts[i].count; 130 131 return count; 132 } 133 134 static void check_n_were_seen_once(int *fds, int fds_len, int n, 135 struct sock_count counts[], int counts_len) 136 { 137 int seen_once = 0; 138 int seen_cnt; 139 int i = 0; 140 141 for (; i < fds_len; i++) { 142 /* Skip any sockets that were closed or that weren't seen 143 * exactly once. 144 */ 145 if (fds[i] < 0) 146 continue; 147 seen_cnt = get_seen_count(fds[i], counts, counts_len); 148 if (seen_cnt && ASSERT_EQ(seen_cnt, 1, "seen_cnt")) 149 seen_once++; 150 } 151 152 ASSERT_EQ(seen_once, n, "seen_once"); 153 } 154 155 static void remove_seen(int family, int sock_type, const char *addr, __u16 port, 156 int *socks, int socks_len, struct sock_count *counts, 157 int counts_len, struct bpf_link *link, int iter_fd) 158 { 159 int close_idx; 160 161 /* Iterate through the first socks_len - 1 sockets. */ 162 read_n(iter_fd, socks_len - 1, counts, counts_len); 163 164 /* Make sure we saw socks_len - 1 sockets exactly once. */ 165 check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, 166 counts_len); 167 168 /* Close a socket we've already seen to remove it from the bucket. */ 169 close_idx = get_seen_socket(socks, counts, counts_len); 170 if (!ASSERT_GE(close_idx, 0, "close_idx")) 171 return; 172 close(socks[close_idx]); 173 socks[close_idx] = -1; 174 175 /* Iterate through the rest of the sockets. */ 176 read_n(iter_fd, -1, counts, counts_len); 177 178 /* Make sure the last socket wasn't skipped and that there were no 179 * repeats. 180 */ 181 check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, 182 counts_len); 183 } 184 185 static void remove_unseen(int family, int sock_type, const char *addr, 186 __u16 port, int *socks, int socks_len, 187 struct sock_count *counts, int counts_len, 188 struct bpf_link *link, int iter_fd) 189 { 190 int close_idx; 191 192 /* Iterate through the first socket. */ 193 read_n(iter_fd, 1, counts, counts_len); 194 195 /* Make sure we saw a socket from fds. */ 196 check_n_were_seen_once(socks, socks_len, 1, counts, counts_len); 197 198 /* Close what would be the next socket in the bucket to exercise the 199 * condition where we need to skip past the first cookie we remembered. 200 */ 201 close_idx = get_nth_socket(socks, socks_len, link, 1); 202 if (!ASSERT_GE(close_idx, 0, "close_idx")) 203 return; 204 close(socks[close_idx]); 205 socks[close_idx] = -1; 206 207 /* Iterate through the rest of the sockets. */ 208 read_n(iter_fd, -1, counts, counts_len); 209 210 /* Make sure the remaining sockets were seen exactly once and that we 211 * didn't repeat the socket that was already seen. 212 */ 213 check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, 214 counts_len); 215 } 216 217 static void remove_all(int family, int sock_type, const char *addr, 218 __u16 port, int *socks, int socks_len, 219 struct sock_count *counts, int counts_len, 220 struct bpf_link *link, int iter_fd) 221 { 222 int close_idx, i; 223 224 /* Iterate through the first socket. */ 225 read_n(iter_fd, 1, counts, counts_len); 226 227 /* Make sure we saw a socket from fds. */ 228 check_n_were_seen_once(socks, socks_len, 1, counts, counts_len); 229 230 /* Close all remaining sockets to exhaust the list of saved cookies and 231 * exit without putting any sockets into the batch on the next read. 232 */ 233 for (i = 0; i < socks_len - 1; i++) { 234 close_idx = get_nth_socket(socks, socks_len, link, 1); 235 if (!ASSERT_GE(close_idx, 0, "close_idx")) 236 return; 237 close(socks[close_idx]); 238 socks[close_idx] = -1; 239 } 240 241 /* Make sure there are no more sockets returned */ 242 ASSERT_EQ(read_n(iter_fd, -1, counts, counts_len), 0, "read_n"); 243 } 244 245 static void add_some(int family, int sock_type, const char *addr, __u16 port, 246 int *socks, int socks_len, struct sock_count *counts, 247 int counts_len, struct bpf_link *link, int iter_fd) 248 { 249 int *new_socks = NULL; 250 251 /* Iterate through the first socks_len - 1 sockets. */ 252 read_n(iter_fd, socks_len - 1, counts, counts_len); 253 254 /* Make sure we saw socks_len - 1 sockets exactly once. */ 255 check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, 256 counts_len); 257 258 /* Double the number of sockets in the bucket. */ 259 new_socks = start_reuseport_server(family, sock_type, addr, port, 0, 260 socks_len); 261 if (!ASSERT_OK_PTR(new_socks, "start_reuseport_server")) 262 goto done; 263 264 /* Iterate through the rest of the sockets. */ 265 read_n(iter_fd, -1, counts, counts_len); 266 267 /* Make sure each of the original sockets was seen exactly once. */ 268 check_n_were_seen_once(socks, socks_len, socks_len, counts, 269 counts_len); 270 done: 271 free_fds(new_socks, socks_len); 272 } 273 274 static void force_realloc(int family, int sock_type, const char *addr, 275 __u16 port, int *socks, int socks_len, 276 struct sock_count *counts, int counts_len, 277 struct bpf_link *link, int iter_fd) 278 { 279 int *new_socks = NULL; 280 281 /* Iterate through the first socket just to initialize the batch. */ 282 read_n(iter_fd, 1, counts, counts_len); 283 284 /* Double the number of sockets in the bucket to force a realloc on the 285 * next read. 286 */ 287 new_socks = start_reuseport_server(family, sock_type, addr, port, 0, 288 socks_len); 289 if (!ASSERT_OK_PTR(new_socks, "start_reuseport_server")) 290 goto done; 291 292 /* Iterate through the rest of the sockets. */ 293 read_n(iter_fd, -1, counts, counts_len); 294 295 /* Make sure each socket from the first set was seen exactly once. */ 296 check_n_were_seen_once(socks, socks_len, socks_len, counts, 297 counts_len); 298 done: 299 free_fds(new_socks, socks_len); 300 } 301 302 struct test_case { 303 void (*test)(int family, int sock_type, const char *addr, __u16 port, 304 int *socks, int socks_len, struct sock_count *counts, 305 int counts_len, struct bpf_link *link, int iter_fd); 306 const char *description; 307 int init_socks; 308 int max_socks; 309 int sock_type; 310 int family; 311 }; 312 313 static struct test_case resume_tests[] = { 314 { 315 .description = "udp: resume after removing a seen socket", 316 .init_socks = nr_soreuse, 317 .max_socks = nr_soreuse, 318 .sock_type = SOCK_DGRAM, 319 .family = AF_INET6, 320 .test = remove_seen, 321 }, 322 { 323 .description = "udp: resume after removing one unseen socket", 324 .init_socks = nr_soreuse, 325 .max_socks = nr_soreuse, 326 .sock_type = SOCK_DGRAM, 327 .family = AF_INET6, 328 .test = remove_unseen, 329 }, 330 { 331 .description = "udp: resume after removing all unseen sockets", 332 .init_socks = nr_soreuse, 333 .max_socks = nr_soreuse, 334 .sock_type = SOCK_DGRAM, 335 .family = AF_INET6, 336 .test = remove_all, 337 }, 338 { 339 .description = "udp: resume after adding a few sockets", 340 .init_socks = nr_soreuse, 341 .max_socks = nr_soreuse, 342 .sock_type = SOCK_DGRAM, 343 /* Use AF_INET so that new sockets are added to the head of the 344 * bucket's list. 345 */ 346 .family = AF_INET, 347 .test = add_some, 348 }, 349 { 350 .description = "udp: force a realloc to occur", 351 .init_socks = init_batch_size, 352 .max_socks = init_batch_size * 2, 353 .sock_type = SOCK_DGRAM, 354 /* Use AF_INET6 so that new sockets are added to the tail of the 355 * bucket's list, needing to be added to the next batch to force 356 * a realloc. 357 */ 358 .family = AF_INET6, 359 .test = force_realloc, 360 }, 361 }; 362 363 static void do_resume_test(struct test_case *tc) 364 { 365 struct sock_iter_batch *skel = NULL; 366 static const __u16 port = 10001; 367 struct bpf_link *link = NULL; 368 struct sock_count *counts; 369 int err, iter_fd = -1; 370 const char *addr; 371 int *fds = NULL; 372 int local_port; 373 374 counts = calloc(tc->max_socks, sizeof(*counts)); 375 if (!ASSERT_OK_PTR(counts, "counts")) 376 goto done; 377 skel = sock_iter_batch__open(); 378 if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open")) 379 goto done; 380 381 /* Prepare a bucket of sockets in the kernel hashtable */ 382 addr = tc->family == AF_INET6 ? "::1" : "127.0.0.1"; 383 fds = start_reuseport_server(tc->family, tc->sock_type, addr, port, 0, 384 tc->init_socks); 385 if (!ASSERT_OK_PTR(fds, "start_reuseport_server")) 386 goto done; 387 local_port = get_socket_local_port(*fds); 388 if (!ASSERT_GE(local_port, 0, "get_socket_local_port")) 389 goto done; 390 skel->rodata->ports[0] = ntohs(local_port); 391 skel->rodata->sf = tc->family; 392 393 err = sock_iter_batch__load(skel); 394 if (!ASSERT_OK(err, "sock_iter_batch__load")) 395 goto done; 396 397 link = bpf_program__attach_iter(tc->sock_type == SOCK_STREAM ? 398 skel->progs.iter_tcp_soreuse : 399 skel->progs.iter_udp_soreuse, 400 NULL); 401 if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) 402 goto done; 403 404 iter_fd = bpf_iter_create(bpf_link__fd(link)); 405 if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create")) 406 goto done; 407 408 tc->test(tc->family, tc->sock_type, addr, port, fds, tc->init_socks, 409 counts, tc->max_socks, link, iter_fd); 410 done: 411 free(counts); 412 free_fds(fds, tc->init_socks); 413 if (iter_fd >= 0) 414 close(iter_fd); 415 bpf_link__destroy(link); 416 sock_iter_batch__destroy(skel); 417 } 418 419 static void do_resume_tests(void) 420 { 421 int i; 422 423 for (i = 0; i < ARRAY_SIZE(resume_tests); i++) { 424 if (test__start_subtest(resume_tests[i].description)) { 425 do_resume_test(&resume_tests[i]); 426 } 427 } 428 } 429 430 static void do_test(int sock_type, bool onebyone) 431 { 432 int err, i, nread, to_read, total_read, iter_fd = -1; 433 struct iter_out outputs[nr_soreuse]; 434 struct bpf_link *link = NULL; 435 struct sock_iter_batch *skel; 436 int first_idx, second_idx; 437 int *fds[2] = {}; 438 439 skel = sock_iter_batch__open(); 440 if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open")) 441 return; 442 443 /* Prepare 2 buckets of sockets in the kernel hashtable */ 444 for (i = 0; i < ARRAY_SIZE(fds); i++) { 445 int local_port; 446 447 fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0, 448 nr_soreuse); 449 if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server")) 450 goto done; 451 local_port = get_socket_local_port(*fds[i]); 452 if (!ASSERT_GE(local_port, 0, "get_socket_local_port")) 453 goto done; 454 skel->rodata->ports[i] = ntohs(local_port); 455 } 456 skel->rodata->sf = AF_INET6; 457 458 err = sock_iter_batch__load(skel); 459 if (!ASSERT_OK(err, "sock_iter_batch__load")) 460 goto done; 461 462 link = bpf_program__attach_iter(sock_type == SOCK_STREAM ? 463 skel->progs.iter_tcp_soreuse : 464 skel->progs.iter_udp_soreuse, 465 NULL); 466 if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) 467 goto done; 468 469 iter_fd = bpf_iter_create(bpf_link__fd(link)); 470 if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create")) 471 goto done; 472 473 /* Test reading a bucket (either from fds[0] or fds[1]). 474 * Only read "nr_soreuse - 1" number of sockets 475 * from a bucket and leave one socket out from 476 * that bucket on purpose. 477 */ 478 to_read = (nr_soreuse - 1) * sizeof(*outputs); 479 total_read = 0; 480 first_idx = -1; 481 do { 482 nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read); 483 if (nread <= 0 || nread % sizeof(*outputs)) 484 break; 485 total_read += nread; 486 487 if (first_idx == -1) 488 first_idx = outputs[0].idx; 489 for (i = 0; i < nread / sizeof(*outputs); i++) 490 ASSERT_EQ(outputs[i].idx, first_idx, "first_idx"); 491 } while (total_read < to_read); 492 ASSERT_EQ(nread, onebyone ? sizeof(*outputs) : to_read, "nread"); 493 ASSERT_EQ(total_read, to_read, "total_read"); 494 495 free_fds(fds[first_idx], nr_soreuse); 496 fds[first_idx] = NULL; 497 498 /* Read the "whole" second bucket */ 499 to_read = nr_soreuse * sizeof(*outputs); 500 total_read = 0; 501 second_idx = !first_idx; 502 do { 503 nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read); 504 if (nread <= 0 || nread % sizeof(*outputs)) 505 break; 506 total_read += nread; 507 508 for (i = 0; i < nread / sizeof(*outputs); i++) 509 ASSERT_EQ(outputs[i].idx, second_idx, "second_idx"); 510 } while (total_read <= to_read); 511 ASSERT_EQ(nread, 0, "nread"); 512 /* Both so_reuseport ports should be in different buckets, so 513 * total_read must equal to the expected to_read. 514 * 515 * For a very unlikely case, both ports collide at the same bucket, 516 * the bucket offset (i.e. 3) will be skipped and it cannot 517 * expect the to_read number of bytes. 518 */ 519 if (skel->bss->bucket[0] != skel->bss->bucket[1]) 520 ASSERT_EQ(total_read, to_read, "total_read"); 521 522 done: 523 for (i = 0; i < ARRAY_SIZE(fds); i++) 524 free_fds(fds[i], nr_soreuse); 525 if (iter_fd < 0) 526 close(iter_fd); 527 bpf_link__destroy(link); 528 sock_iter_batch__destroy(skel); 529 } 530 531 void test_sock_iter_batch(void) 532 { 533 struct nstoken *nstoken = NULL; 534 535 SYS_NOFAIL("ip netns del " TEST_NS); 536 SYS(done, "ip netns add %s", TEST_NS); 537 SYS(done, "ip -net %s link set dev lo up", TEST_NS); 538 539 nstoken = open_netns(TEST_NS); 540 if (!ASSERT_OK_PTR(nstoken, "open_netns")) 541 goto done; 542 543 if (test__start_subtest("tcp")) { 544 do_test(SOCK_STREAM, true); 545 do_test(SOCK_STREAM, false); 546 } 547 if (test__start_subtest("udp")) { 548 do_test(SOCK_DGRAM, true); 549 do_test(SOCK_DGRAM, false); 550 } 551 do_resume_tests(); 552 close_netns(nstoken); 553 554 done: 555 SYS_NOFAIL("ip netns del " TEST_NS); 556 } 557