xref: /linux/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c (revision fcab107abe1ab5be9dbe874baa722372da8f4f73)
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2024 Meta
3 
4 #include <test_progs.h>
5 #include "network_helpers.h"
6 #include "sock_iter_batch.skel.h"
7 
8 #define TEST_NS "sock_iter_batch_netns"
9 
/* Test-wide sizes.  These are enumerators rather than "static const int"
 * so that they are integer constant expressions: they can size the
 * outputs[] array in do_test() without turning it into a VLA, and they can
 * appear in the static resume_tests[] initializers without relying on a
 * compiler extension.
 */
enum {
	/* Socket count for the "force a realloc" resume case, which later
	 * doubles the bucket.  Presumably mirrors the kernel iterator's
	 * initial batch size - TODO confirm against the kernel source.
	 */
	init_batch_size = 16,
	/* Number of SO_REUSEPORT sockets opened per bucket. */
	nr_soreuse = 4,
};
12 
13 struct iter_out {
14 	int idx;
15 	__u64 cookie;
16 } __packed;
17 
18 struct sock_count {
19 	__u64 cookie;
20 	int count;
21 };
22 
23 static int insert(__u64 cookie, struct sock_count counts[], int counts_len)
24 {
25 	int insert = -1;
26 	int i = 0;
27 
28 	for (; i < counts_len; i++) {
29 		if (!counts[i].cookie) {
30 			insert = i;
31 		} else if (counts[i].cookie == cookie) {
32 			insert = i;
33 			break;
34 		}
35 	}
36 	if (insert < 0)
37 		return insert;
38 
39 	counts[insert].cookie = cookie;
40 	counts[insert].count++;
41 
42 	return counts[insert].count;
43 }
44 
45 static int read_n(int iter_fd, int n, struct sock_count counts[],
46 		  int counts_len)
47 {
48 	struct iter_out out;
49 	int nread = 1;
50 	int i = 0;
51 
52 	for (; nread > 0 && (n < 0 || i < n); i++) {
53 		nread = read(iter_fd, &out, sizeof(out));
54 		if (!nread || !ASSERT_EQ(nread, sizeof(out), "nread"))
55 			break;
56 		ASSERT_GE(insert(out.cookie, counts, counts_len), 0, "insert");
57 	}
58 
59 	ASSERT_TRUE(n < 0 || i == n, "n < 0 || i == n");
60 
61 	return i;
62 }
63 
64 static __u64 socket_cookie(int fd)
65 {
66 	__u64 cookie;
67 	socklen_t cookie_len = sizeof(cookie);
68 
69 	if (!ASSERT_OK(getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie,
70 				  &cookie_len), "getsockopt(SO_COOKIE)"))
71 		return 0;
72 	return cookie;
73 }
74 
75 static bool was_seen(int fd, struct sock_count counts[], int counts_len)
76 {
77 	__u64 cookie = socket_cookie(fd);
78 	int i = 0;
79 
80 	for (; cookie && i < counts_len; i++)
81 		if (cookie == counts[i].cookie)
82 			return true;
83 
84 	return false;
85 }
86 
87 static int get_seen_socket(int *fds, struct sock_count counts[], int n)
88 {
89 	int i = 0;
90 
91 	for (; i < n; i++)
92 		if (was_seen(fds[i], counts, n))
93 			return i;
94 	return -1;
95 }
96 
97 static int get_nth_socket(int *fds, int fds_len, struct bpf_link *link, int n)
98 {
99 	int i, nread, iter_fd;
100 	int nth_sock_idx = -1;
101 	struct iter_out out;
102 
103 	iter_fd = bpf_iter_create(bpf_link__fd(link));
104 	if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create"))
105 		return -1;
106 
107 	for (; n >= 0; n--) {
108 		nread = read(iter_fd, &out, sizeof(out));
109 		if (!nread || !ASSERT_GE(nread, 1, "nread"))
110 			goto done;
111 	}
112 
113 	for (i = 0; i < fds_len && nth_sock_idx < 0; i++)
114 		if (fds[i] >= 0 && socket_cookie(fds[i]) == out.cookie)
115 			nth_sock_idx = i;
116 done:
117 	close(iter_fd);
118 	return nth_sock_idx;
119 }
120 
121 static int get_seen_count(int fd, struct sock_count counts[], int n)
122 {
123 	__u64 cookie = socket_cookie(fd);
124 	int count = 0;
125 	int i = 0;
126 
127 	for (; cookie && !count && i < n; i++)
128 		if (cookie == counts[i].cookie)
129 			count = counts[i].count;
130 
131 	return count;
132 }
133 
134 static void check_n_were_seen_once(int *fds, int fds_len, int n,
135 				   struct sock_count counts[], int counts_len)
136 {
137 	int seen_once = 0;
138 	int seen_cnt;
139 	int i = 0;
140 
141 	for (; i < fds_len; i++) {
142 		/* Skip any sockets that were closed or that weren't seen
143 		 * exactly once.
144 		 */
145 		if (fds[i] < 0)
146 			continue;
147 		seen_cnt = get_seen_count(fds[i], counts, counts_len);
148 		if (seen_cnt && ASSERT_EQ(seen_cnt, 1, "seen_cnt"))
149 			seen_once++;
150 	}
151 
152 	ASSERT_EQ(seen_once, n, "seen_once");
153 }
154 
155 static void remove_seen(int family, int sock_type, const char *addr, __u16 port,
156 			int *socks, int socks_len, struct sock_count *counts,
157 			int counts_len, struct bpf_link *link, int iter_fd)
158 {
159 	int close_idx;
160 
161 	/* Iterate through the first socks_len - 1 sockets. */
162 	read_n(iter_fd, socks_len - 1, counts, counts_len);
163 
164 	/* Make sure we saw socks_len - 1 sockets exactly once. */
165 	check_n_were_seen_once(socks, socks_len, socks_len - 1, counts,
166 			       counts_len);
167 
168 	/* Close a socket we've already seen to remove it from the bucket. */
169 	close_idx = get_seen_socket(socks, counts, counts_len);
170 	if (!ASSERT_GE(close_idx, 0, "close_idx"))
171 		return;
172 	close(socks[close_idx]);
173 	socks[close_idx] = -1;
174 
175 	/* Iterate through the rest of the sockets. */
176 	read_n(iter_fd, -1, counts, counts_len);
177 
178 	/* Make sure the last socket wasn't skipped and that there were no
179 	 * repeats.
180 	 */
181 	check_n_were_seen_once(socks, socks_len, socks_len - 1, counts,
182 			       counts_len);
183 }
184 
185 static void remove_unseen(int family, int sock_type, const char *addr,
186 			  __u16 port, int *socks, int socks_len,
187 			  struct sock_count *counts, int counts_len,
188 			  struct bpf_link *link, int iter_fd)
189 {
190 	int close_idx;
191 
192 	/* Iterate through the first socket. */
193 	read_n(iter_fd, 1, counts, counts_len);
194 
195 	/* Make sure we saw a socket from fds. */
196 	check_n_were_seen_once(socks, socks_len, 1, counts, counts_len);
197 
198 	/* Close what would be the next socket in the bucket to exercise the
199 	 * condition where we need to skip past the first cookie we remembered.
200 	 */
201 	close_idx = get_nth_socket(socks, socks_len, link, 1);
202 	if (!ASSERT_GE(close_idx, 0, "close_idx"))
203 		return;
204 	close(socks[close_idx]);
205 	socks[close_idx] = -1;
206 
207 	/* Iterate through the rest of the sockets. */
208 	read_n(iter_fd, -1, counts, counts_len);
209 
210 	/* Make sure the remaining sockets were seen exactly once and that we
211 	 * didn't repeat the socket that was already seen.
212 	 */
213 	check_n_were_seen_once(socks, socks_len, socks_len - 1, counts,
214 			       counts_len);
215 }
216 
217 static void remove_all(int family, int sock_type, const char *addr,
218 		       __u16 port, int *socks, int socks_len,
219 		       struct sock_count *counts, int counts_len,
220 		       struct bpf_link *link, int iter_fd)
221 {
222 	int close_idx, i;
223 
224 	/* Iterate through the first socket. */
225 	read_n(iter_fd, 1, counts, counts_len);
226 
227 	/* Make sure we saw a socket from fds. */
228 	check_n_were_seen_once(socks, socks_len, 1, counts, counts_len);
229 
230 	/* Close all remaining sockets to exhaust the list of saved cookies and
231 	 * exit without putting any sockets into the batch on the next read.
232 	 */
233 	for (i = 0; i < socks_len - 1; i++) {
234 		close_idx = get_nth_socket(socks, socks_len, link, 1);
235 		if (!ASSERT_GE(close_idx, 0, "close_idx"))
236 			return;
237 		close(socks[close_idx]);
238 		socks[close_idx] = -1;
239 	}
240 
241 	/* Make sure there are no more sockets returned */
242 	ASSERT_EQ(read_n(iter_fd, -1, counts, counts_len), 0, "read_n");
243 }
244 
245 static void add_some(int family, int sock_type, const char *addr, __u16 port,
246 		     int *socks, int socks_len, struct sock_count *counts,
247 		     int counts_len, struct bpf_link *link, int iter_fd)
248 {
249 	int *new_socks = NULL;
250 
251 	/* Iterate through the first socks_len - 1 sockets. */
252 	read_n(iter_fd, socks_len - 1, counts, counts_len);
253 
254 	/* Make sure we saw socks_len - 1 sockets exactly once. */
255 	check_n_were_seen_once(socks, socks_len, socks_len - 1, counts,
256 			       counts_len);
257 
258 	/* Double the number of sockets in the bucket. */
259 	new_socks = start_reuseport_server(family, sock_type, addr, port, 0,
260 					   socks_len);
261 	if (!ASSERT_OK_PTR(new_socks, "start_reuseport_server"))
262 		goto done;
263 
264 	/* Iterate through the rest of the sockets. */
265 	read_n(iter_fd, -1, counts, counts_len);
266 
267 	/* Make sure each of the original sockets was seen exactly once. */
268 	check_n_were_seen_once(socks, socks_len, socks_len, counts,
269 			       counts_len);
270 done:
271 	free_fds(new_socks, socks_len);
272 }
273 
274 static void force_realloc(int family, int sock_type, const char *addr,
275 			  __u16 port, int *socks, int socks_len,
276 			  struct sock_count *counts, int counts_len,
277 			  struct bpf_link *link, int iter_fd)
278 {
279 	int *new_socks = NULL;
280 
281 	/* Iterate through the first socket just to initialize the batch. */
282 	read_n(iter_fd, 1, counts, counts_len);
283 
284 	/* Double the number of sockets in the bucket to force a realloc on the
285 	 * next read.
286 	 */
287 	new_socks = start_reuseport_server(family, sock_type, addr, port, 0,
288 					   socks_len);
289 	if (!ASSERT_OK_PTR(new_socks, "start_reuseport_server"))
290 		goto done;
291 
292 	/* Iterate through the rest of the sockets. */
293 	read_n(iter_fd, -1, counts, counts_len);
294 
295 	/* Make sure each socket from the first set was seen exactly once. */
296 	check_n_were_seen_once(socks, socks_len, socks_len, counts,
297 			       counts_len);
298 done:
299 	free_fds(new_socks, socks_len);
300 }
301 
302 struct test_case {
303 	void (*test)(int family, int sock_type, const char *addr, __u16 port,
304 		     int *socks, int socks_len, struct sock_count *counts,
305 		     int counts_len, struct bpf_link *link, int iter_fd);
306 	const char *description;
307 	int init_socks;
308 	int max_socks;
309 	int sock_type;
310 	int family;
311 };
312 
313 static struct test_case resume_tests[] = {
314 	{
315 		.description = "udp: resume after removing a seen socket",
316 		.init_socks = nr_soreuse,
317 		.max_socks = nr_soreuse,
318 		.sock_type = SOCK_DGRAM,
319 		.family = AF_INET6,
320 		.test = remove_seen,
321 	},
322 	{
323 		.description = "udp: resume after removing one unseen socket",
324 		.init_socks = nr_soreuse,
325 		.max_socks = nr_soreuse,
326 		.sock_type = SOCK_DGRAM,
327 		.family = AF_INET6,
328 		.test = remove_unseen,
329 	},
330 	{
331 		.description = "udp: resume after removing all unseen sockets",
332 		.init_socks = nr_soreuse,
333 		.max_socks = nr_soreuse,
334 		.sock_type = SOCK_DGRAM,
335 		.family = AF_INET6,
336 		.test = remove_all,
337 	},
338 	{
339 		.description = "udp: resume after adding a few sockets",
340 		.init_socks = nr_soreuse,
341 		.max_socks = nr_soreuse,
342 		.sock_type = SOCK_DGRAM,
343 		/* Use AF_INET so that new sockets are added to the head of the
344 		 * bucket's list.
345 		 */
346 		.family = AF_INET,
347 		.test = add_some,
348 	},
349 	{
350 		.description = "udp: force a realloc to occur",
351 		.init_socks = init_batch_size,
352 		.max_socks = init_batch_size * 2,
353 		.sock_type = SOCK_DGRAM,
354 		/* Use AF_INET6 so that new sockets are added to the tail of the
355 		 * bucket's list, needing to be added to the next batch to force
356 		 * a realloc.
357 		 */
358 		.family = AF_INET6,
359 		.test = force_realloc,
360 	},
361 };
362 
363 static void do_resume_test(struct test_case *tc)
364 {
365 	struct sock_iter_batch *skel = NULL;
366 	static const __u16 port = 10001;
367 	struct bpf_link *link = NULL;
368 	struct sock_count *counts;
369 	int err, iter_fd = -1;
370 	const char *addr;
371 	int *fds = NULL;
372 	int local_port;
373 
374 	counts = calloc(tc->max_socks, sizeof(*counts));
375 	if (!ASSERT_OK_PTR(counts, "counts"))
376 		goto done;
377 	skel = sock_iter_batch__open();
378 	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
379 		goto done;
380 
381 	/* Prepare a bucket of sockets in the kernel hashtable */
382 	addr = tc->family == AF_INET6 ? "::1" : "127.0.0.1";
383 	fds = start_reuseport_server(tc->family, tc->sock_type, addr, port, 0,
384 				     tc->init_socks);
385 	if (!ASSERT_OK_PTR(fds, "start_reuseport_server"))
386 		goto done;
387 	local_port = get_socket_local_port(*fds);
388 	if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
389 		goto done;
390 	skel->rodata->ports[0] = ntohs(local_port);
391 	skel->rodata->sf = tc->family;
392 
393 	err = sock_iter_batch__load(skel);
394 	if (!ASSERT_OK(err, "sock_iter_batch__load"))
395 		goto done;
396 
397 	link = bpf_program__attach_iter(tc->sock_type == SOCK_STREAM ?
398 					skel->progs.iter_tcp_soreuse :
399 					skel->progs.iter_udp_soreuse,
400 					NULL);
401 	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
402 		goto done;
403 
404 	iter_fd = bpf_iter_create(bpf_link__fd(link));
405 	if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create"))
406 		goto done;
407 
408 	tc->test(tc->family, tc->sock_type, addr, port, fds, tc->init_socks,
409 		 counts, tc->max_socks, link, iter_fd);
410 done:
411 	free(counts);
412 	free_fds(fds, tc->init_socks);
413 	if (iter_fd >= 0)
414 		close(iter_fd);
415 	bpf_link__destroy(link);
416 	sock_iter_batch__destroy(skel);
417 }
418 
419 static void do_resume_tests(void)
420 {
421 	int i;
422 
423 	for (i = 0; i < ARRAY_SIZE(resume_tests); i++) {
424 		if (test__start_subtest(resume_tests[i].description)) {
425 			do_resume_test(&resume_tests[i]);
426 		}
427 	}
428 }
429 
430 static void do_test(int sock_type, bool onebyone)
431 {
432 	int err, i, nread, to_read, total_read, iter_fd = -1;
433 	struct iter_out outputs[nr_soreuse];
434 	struct bpf_link *link = NULL;
435 	struct sock_iter_batch *skel;
436 	int first_idx, second_idx;
437 	int *fds[2] = {};
438 
439 	skel = sock_iter_batch__open();
440 	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
441 		return;
442 
443 	/* Prepare 2 buckets of sockets in the kernel hashtable */
444 	for (i = 0; i < ARRAY_SIZE(fds); i++) {
445 		int local_port;
446 
447 		fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0,
448 						nr_soreuse);
449 		if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server"))
450 			goto done;
451 		local_port = get_socket_local_port(*fds[i]);
452 		if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
453 			goto done;
454 		skel->rodata->ports[i] = ntohs(local_port);
455 	}
456 	skel->rodata->sf = AF_INET6;
457 
458 	err = sock_iter_batch__load(skel);
459 	if (!ASSERT_OK(err, "sock_iter_batch__load"))
460 		goto done;
461 
462 	link = bpf_program__attach_iter(sock_type == SOCK_STREAM ?
463 					skel->progs.iter_tcp_soreuse :
464 					skel->progs.iter_udp_soreuse,
465 					NULL);
466 	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
467 		goto done;
468 
469 	iter_fd = bpf_iter_create(bpf_link__fd(link));
470 	if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
471 		goto done;
472 
473 	/* Test reading a bucket (either from fds[0] or fds[1]).
474 	 * Only read "nr_soreuse - 1" number of sockets
475 	 * from a bucket and leave one socket out from
476 	 * that bucket on purpose.
477 	 */
478 	to_read = (nr_soreuse - 1) * sizeof(*outputs);
479 	total_read = 0;
480 	first_idx = -1;
481 	do {
482 		nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read);
483 		if (nread <= 0 || nread % sizeof(*outputs))
484 			break;
485 		total_read += nread;
486 
487 		if (first_idx == -1)
488 			first_idx = outputs[0].idx;
489 		for (i = 0; i < nread / sizeof(*outputs); i++)
490 			ASSERT_EQ(outputs[i].idx, first_idx, "first_idx");
491 	} while (total_read < to_read);
492 	ASSERT_EQ(nread, onebyone ? sizeof(*outputs) : to_read, "nread");
493 	ASSERT_EQ(total_read, to_read, "total_read");
494 
495 	free_fds(fds[first_idx], nr_soreuse);
496 	fds[first_idx] = NULL;
497 
498 	/* Read the "whole" second bucket */
499 	to_read = nr_soreuse * sizeof(*outputs);
500 	total_read = 0;
501 	second_idx = !first_idx;
502 	do {
503 		nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read);
504 		if (nread <= 0 || nread % sizeof(*outputs))
505 			break;
506 		total_read += nread;
507 
508 		for (i = 0; i < nread / sizeof(*outputs); i++)
509 			ASSERT_EQ(outputs[i].idx, second_idx, "second_idx");
510 	} while (total_read <= to_read);
511 	ASSERT_EQ(nread, 0, "nread");
512 	/* Both so_reuseport ports should be in different buckets, so
513 	 * total_read must equal to the expected to_read.
514 	 *
515 	 * For a very unlikely case, both ports collide at the same bucket,
516 	 * the bucket offset (i.e. 3) will be skipped and it cannot
517 	 * expect the to_read number of bytes.
518 	 */
519 	if (skel->bss->bucket[0] != skel->bss->bucket[1])
520 		ASSERT_EQ(total_read, to_read, "total_read");
521 
522 done:
523 	for (i = 0; i < ARRAY_SIZE(fds); i++)
524 		free_fds(fds[i], nr_soreuse);
525 	if (iter_fd < 0)
526 		close(iter_fd);
527 	bpf_link__destroy(link);
528 	sock_iter_batch__destroy(skel);
529 }
530 
531 void test_sock_iter_batch(void)
532 {
533 	struct nstoken *nstoken = NULL;
534 
535 	SYS_NOFAIL("ip netns del " TEST_NS);
536 	SYS(done, "ip netns add %s", TEST_NS);
537 	SYS(done, "ip -net %s link set dev lo up", TEST_NS);
538 
539 	nstoken = open_netns(TEST_NS);
540 	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
541 		goto done;
542 
543 	if (test__start_subtest("tcp")) {
544 		do_test(SOCK_STREAM, true);
545 		do_test(SOCK_STREAM, false);
546 	}
547 	if (test__start_subtest("udp")) {
548 		do_test(SOCK_DGRAM, true);
549 		do_test(SOCK_DGRAM, false);
550 	}
551 	do_resume_tests();
552 	close_netns(nstoken);
553 
554 done:
555 	SYS_NOFAIL("ip netns del " TEST_NS);
556 }
557