xref: /linux/tools/testing/selftests/bpf/prog_tests/sk_bypass_prot_mem.c (revision 24f171c7e145f43b9f187578e89b0982ce87e54c)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright 2025 Google LLC */
3 
4 #include <test_progs.h>
5 #include "sk_bypass_prot_mem.skel.h"
6 #include "network_helpers.h"
7 
/* Number of 4096-byte pages the sockets are expected to consume in total. */
#define NR_PAGES	32
/* Number of sockets created per check; they are set up in pairs. */
#define NR_SOCKETS	2
/* Bytes sent on each socket so that all sockets together cover NR_PAGES. */
#define BUF_TOTAL	((NR_PAGES * 4096) / NR_SOCKETS)
/* Payload size of a single send() call. */
#define BUF_SINGLE	1024
/* Number of send() calls issued on each socket. */
#define NR_SEND		(BUF_TOTAL / BUF_SINGLE)
13 
/*
 * Description of one subtest: which kind of socket to exercise and the
 * protocol-specific callbacks used to drive it.
 */
struct test_case {
	/* Subtest name handed to test__start_subtest(). */
	char name[8];
	/* Address family passed to start_server(). */
	int family;
	/* Socket type passed to start_server() and socket(). */
	int type;
	/* Fills fds[0..nr_fds) with sockets; returns 0 on success. */
	int (*create_sockets)(struct test_case *tc, int fds[], int nr_fds);
	/* Returns the memory_allocated value read from the skeleton, or -1. */
	long (*get_memory_allocated)(struct test_case *tc, struct sk_bypass_prot_mem *skel);
};
21 
22 static int tcp_create_sockets(struct test_case *test_case, int sk[], int len)
23 {
24 	int server, i, err = 0;
25 
26 	server = start_server(test_case->family, test_case->type, NULL, 0, 0);
27 	if (!ASSERT_GE(server, 0, "start_server_str"))
28 		return server;
29 
30 	/* Keep for-loop so we can change NR_SOCKETS easily. */
31 	for (i = 0; i < len; i += 2) {
32 		sk[i] = connect_to_fd(server, 0);
33 		if (sk[i] < 0) {
34 			ASSERT_GE(sk[i], 0, "connect_to_fd");
35 			err = sk[i];
36 			break;
37 		}
38 
39 		sk[i + 1] = accept(server, NULL, NULL);
40 		if (sk[i + 1] < 0) {
41 			ASSERT_GE(sk[i + 1], 0, "accept");
42 			err = sk[i + 1];
43 			break;
44 		}
45 	}
46 
47 	close(server);
48 
49 	return err;
50 }
51 
52 static int udp_create_sockets(struct test_case *test_case, int sk[], int len)
53 {
54 	int i, j, err, rcvbuf = BUF_TOTAL;
55 
56 	/* Keep for-loop so we can change NR_SOCKETS easily. */
57 	for (i = 0; i < len; i += 2) {
58 		sk[i] = start_server(test_case->family, test_case->type, NULL, 0, 0);
59 		if (sk[i] < 0) {
60 			ASSERT_GE(sk[i], 0, "start_server");
61 			return sk[i];
62 		}
63 
64 		sk[i + 1] = connect_to_fd(sk[i], 0);
65 		if (sk[i + 1] < 0) {
66 			ASSERT_GE(sk[i + 1], 0, "connect_to_fd");
67 			return sk[i + 1];
68 		}
69 
70 		err = connect_fd_to_fd(sk[i], sk[i + 1], 0);
71 		if (err) {
72 			ASSERT_EQ(err, 0, "connect_fd_to_fd");
73 			return err;
74 		}
75 
76 		for (j = 0; j < 2; j++) {
77 			err = setsockopt(sk[i + j], SOL_SOCKET, SO_RCVBUF, &rcvbuf, sizeof(int));
78 			if (err) {
79 				ASSERT_EQ(err, 0, "setsockopt(SO_RCVBUF)");
80 				return err;
81 			}
82 		}
83 	}
84 
85 	return 0;
86 }
87 
88 static long get_memory_allocated(struct test_case *test_case,
89 				 bool *activated, long *memory_allocated)
90 {
91 	int sk;
92 
93 	*activated = true;
94 
95 	/* AF_INET and AF_INET6 share the same memory_allocated.
96 	 * tcp_init_sock() is called by AF_INET and AF_INET6,
97 	 * but udp_lib_init_sock() is inline.
98 	 */
99 	sk = socket(AF_INET, test_case->type, 0);
100 	if (!ASSERT_GE(sk, 0, "get_memory_allocated"))
101 		return -1;
102 
103 	close(sk);
104 
105 	return *memory_allocated;
106 }
107 
108 static long tcp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel)
109 {
110 	return get_memory_allocated(test_case,
111 				    &skel->bss->tcp_activated,
112 				    &skel->bss->tcp_memory_allocated);
113 }
114 
115 static long udp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel)
116 {
117 	return get_memory_allocated(test_case,
118 				    &skel->bss->udp_activated,
119 				    &skel->bss->udp_memory_allocated);
120 }
121 
122 static int check_bypass(struct test_case *test_case,
123 			struct sk_bypass_prot_mem *skel, bool bypass)
124 {
125 	char buf[BUF_SINGLE] = {};
126 	long memory_allocated[2];
127 	int sk[NR_SOCKETS];
128 	int err, i, j;
129 
130 	for (i = 0; i < ARRAY_SIZE(sk); i++)
131 		sk[i] = -1;
132 
133 	err = test_case->create_sockets(test_case, sk, ARRAY_SIZE(sk));
134 	if (err)
135 		goto close;
136 
137 	memory_allocated[0] = test_case->get_memory_allocated(test_case, skel);
138 
139 	/* allocate pages >= NR_PAGES */
140 	for (i = 0; i < ARRAY_SIZE(sk); i++) {
141 		for (j = 0; j < NR_SEND; j++) {
142 			int bytes = send(sk[i], buf, sizeof(buf), 0);
143 
144 			/* Avoid too noisy logs when something failed. */
145 			if (bytes != sizeof(buf)) {
146 				ASSERT_EQ(bytes, sizeof(buf), "send");
147 				if (bytes < 0) {
148 					err = bytes;
149 					goto drain;
150 				}
151 			}
152 		}
153 	}
154 
155 	memory_allocated[1] = test_case->get_memory_allocated(test_case, skel);
156 
157 	if (bypass)
158 		ASSERT_LE(memory_allocated[1], memory_allocated[0] + 10, "bypass");
159 	else
160 		ASSERT_GT(memory_allocated[1], memory_allocated[0] + NR_PAGES, "no bypass");
161 
162 drain:
163 	if (test_case->type == SOCK_DGRAM) {
164 		/* UDP starts purging sk->sk_receive_queue after one RCU
165 		 * grace period, then udp_memory_allocated goes down,
166 		 * so drain the queue before close().
167 		 */
168 		for (i = 0; i < ARRAY_SIZE(sk); i++) {
169 			for (j = 0; j < NR_SEND; j++) {
170 				int bytes = recv(sk[i], buf, 1, MSG_DONTWAIT | MSG_TRUNC);
171 
172 				if (bytes == sizeof(buf))
173 					continue;
174 				if (bytes != -1 || errno != EAGAIN)
175 					PRINT_FAIL("bytes: %d, errno: %s\n", bytes, strerror(errno));
176 				break;
177 			}
178 		}
179 	}
180 
181 close:
182 	for (i = 0; i < ARRAY_SIZE(sk); i++) {
183 		if (sk[i] < 0)
184 			break;
185 
186 		close(sk[i]);
187 	}
188 
189 	return err;
190 }
191 
192 static void run_test(struct test_case *test_case)
193 {
194 	struct sk_bypass_prot_mem *skel;
195 	struct nstoken *nstoken;
196 	int cgroup, err;
197 
198 	skel = sk_bypass_prot_mem__open_and_load();
199 	if (!ASSERT_OK_PTR(skel, "open_and_load"))
200 		return;
201 
202 	skel->bss->nr_cpus = libbpf_num_possible_cpus();
203 
204 	err = sk_bypass_prot_mem__attach(skel);
205 	if (!ASSERT_OK(err, "attach"))
206 		goto destroy_skel;
207 
208 	cgroup = test__join_cgroup("/sk_bypass_prot_mem");
209 	if (!ASSERT_GE(cgroup, 0, "join_cgroup"))
210 		goto destroy_skel;
211 
212 	err = make_netns("sk_bypass_prot_mem");
213 	if (!ASSERT_EQ(err, 0, "make_netns"))
214 		goto close_cgroup;
215 
216 	nstoken = open_netns("sk_bypass_prot_mem");
217 	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
218 		goto remove_netns;
219 
220 	err = check_bypass(test_case, skel, false);
221 	if (!ASSERT_EQ(err, 0, "test_bypass(false)"))
222 		goto close_netns;
223 
224 	err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "1");
225 	if (!ASSERT_EQ(err, 0, "write_sysctl(1)"))
226 		goto close_netns;
227 
228 	err = check_bypass(test_case, skel, true);
229 	if (!ASSERT_EQ(err, 0, "test_bypass(true by sysctl)"))
230 		goto close_netns;
231 
232 	err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "0");
233 	if (!ASSERT_EQ(err, 0, "write_sysctl(0)"))
234 		goto close_netns;
235 
236 	skel->links.sock_create = bpf_program__attach_cgroup(skel->progs.sock_create, cgroup);
237 	if (!ASSERT_OK_PTR(skel->links.sock_create, "attach_cgroup(sock_create)"))
238 		goto close_netns;
239 
240 	err = check_bypass(test_case, skel, true);
241 	ASSERT_EQ(err, 0, "test_bypass(true by bpf)");
242 
243 close_netns:
244 	close_netns(nstoken);
245 remove_netns:
246 	remove_netns("sk_bypass_prot_mem");
247 close_cgroup:
248 	close(cgroup);
249 destroy_skel:
250 	sk_bypass_prot_mem__destroy(skel);
251 }
252 
253 static struct test_case test_cases[] = {
254 	{
255 		.name = "TCP  ",
256 		.family = AF_INET,
257 		.type = SOCK_STREAM,
258 		.create_sockets = tcp_create_sockets,
259 		.get_memory_allocated = tcp_get_memory_allocated,
260 	},
261 	{
262 		.name = "UDP  ",
263 		.family = AF_INET,
264 		.type = SOCK_DGRAM,
265 		.create_sockets = udp_create_sockets,
266 		.get_memory_allocated = udp_get_memory_allocated,
267 	},
268 	{
269 		.name = "TCPv6",
270 		.family = AF_INET6,
271 		.type = SOCK_STREAM,
272 		.create_sockets = tcp_create_sockets,
273 		.get_memory_allocated = tcp_get_memory_allocated,
274 	},
275 	{
276 		.name = "UDPv6",
277 		.family = AF_INET6,
278 		.type = SOCK_DGRAM,
279 		.create_sockets = udp_create_sockets,
280 		.get_memory_allocated = udp_get_memory_allocated,
281 	},
282 };
283 
284 void serial_test_sk_bypass_prot_mem(void)
285 {
286 	int i;
287 
288 	for (i = 0; i < ARRAY_SIZE(test_cases); i++) {
289 		if (test__start_subtest(test_cases[i].name))
290 			run_test(&test_cases[i]);
291 	}
292 }
293