xref: /linux/tools/testing/selftests/net/tcp_ao/lib/setup.c (revision 9410645520e9b820069761f3450ef6661418e279)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <fcntl.h>
3 #include <pthread.h>
4 #include <sched.h>
5 #include <signal.h>
6 #include "aolib.h"
7 
8 /*
9  * Can't be included in the header: it defines static variables which
10  * will be unique to every object. Let's include it only once here.
11  */
12 #include "../../../kselftest.h"
13 
/*
 * Serializes every kselftest print helper below, so that output produced by
 * one thread is not interleaved with (or overridden by) another's.
 */
static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER;

/* Print an already formatted message with ksft_print_msg(), under the lock. */
void __test_msg(const char *buf)
{
	pthread_mutex_lock(&ksft_print_lock);
	ksft_print_msg("%s", buf);
	pthread_mutex_unlock(&ksft_print_lock);
}
__test_ok(const char * buf)23 void __test_ok(const char *buf)
24 {
25 	pthread_mutex_lock(&ksft_print_lock);
26 	ksft_test_result_pass("%s", buf);
27 	pthread_mutex_unlock(&ksft_print_lock);
28 }
__test_fail(const char * buf)29 void __test_fail(const char *buf)
30 {
31 	pthread_mutex_lock(&ksft_print_lock);
32 	ksft_test_result_fail("%s", buf);
33 	pthread_mutex_unlock(&ksft_print_lock);
34 }
__test_xfail(const char * buf)35 void __test_xfail(const char *buf)
36 {
37 	pthread_mutex_lock(&ksft_print_lock);
38 	ksft_test_result_xfail("%s", buf);
39 	pthread_mutex_unlock(&ksft_print_lock);
40 }
__test_error(const char * buf)41 void __test_error(const char *buf)
42 {
43 	pthread_mutex_lock(&ksft_print_lock);
44 	ksft_test_result_error("%s", buf);
45 	pthread_mutex_unlock(&ksft_print_lock);
46 }
__test_skip(const char * buf)47 void __test_skip(const char *buf)
48 {
49 	pthread_mutex_lock(&ksft_print_lock);
50 	ksft_test_result_skip("%s", buf);
51 	pthread_mutex_unlock(&ksft_print_lock);
52 }
53 
/* Outcome flags consulted when the process exits; both start out clear. */
static volatile int failed;
static volatile int skipped;

/* Record that the test has failed; the flag is never cleared again here. */
void test_failed(void)
{
	failed = 1;
}
61 
test_exit(void)62 static void test_exit(void)
63 {
64 	if (failed) {
65 		ksft_exit_fail();
66 	} else if (skipped) {
67 		/* ksft_exit_skip() is different from ksft_exit_*() */
68 		ksft_print_cnts();
69 		exit(KSFT_SKIP);
70 	} else {
71 		ksft_exit_pass();
72 	}
73 }
74 
/* One registered cleanup callback; nodes form a singly linked list. */
struct dlist_t {
	void (*destruct)(void);	/* callback to invoke at teardown */
	struct dlist_t *next;	/* next node, NULL at the end of the list */
};
/* Head of the list of registered cleanup callbacks, NULL when empty. */
static struct dlist_t *destructors_list;
80 
test_add_destructor(void (* d)(void))81 void test_add_destructor(void (*d)(void))
82 {
83 	struct dlist_t *p;
84 
85 	p = malloc(sizeof(struct dlist_t));
86 	if (p == NULL)
87 		test_error("malloc() failed");
88 
89 	p->next = destructors_list;
90 	p->destruct = d;
91 	destructors_list = p;
92 }
93 
94 static void test_destructor(void) __attribute__((destructor));
test_destructor(void)95 static void test_destructor(void)
96 {
97 	while (destructors_list) {
98 		struct dlist_t *p = destructors_list->next;
99 
100 		destructors_list->destruct();
101 		free(destructors_list);
102 		destructors_list = p;
103 	}
104 	test_exit();
105 }
106 
/* SIGINT handler installed by __test_init(): reports an error via test_error(). */
static void sig_int(int signo)
{
	(void)signo; /* the signal number is not needed */
	test_error("Caught SIGINT - exiting");
}
111 
/*
 * Open the calling thread's network namespace file read-only.
 *
 * Returns the file descriptor; a failed open() is reported with
 * test_error().
 */
int open_netns(void)
{
	const char *netns_path = "/proc/thread-self/ns/net";
	int nsfd = open(netns_path, O_RDONLY);

	if (nsfd < 0)
		test_error("open(%s)", netns_path);
	return nsfd;
}
122 
/*
 * Unshare the network namespace (CLONE_NEWNET) and return a file descriptor
 * obtained with open_netns() afterwards. A failed unshare() is reported with
 * test_error().
 */
int unshare_open_netns(void)
{
	int err = unshare(CLONE_NEWNET);

	if (err != 0)
		test_error("unshare()");

	return open_netns();
}
130 
/*
 * Enter the network namespace referred to by @fd; @fd is left open.
 * A failed setns() is reported with test_error().
 */
void switch_ns(int fd)
{
	int err = setns(fd, CLONE_NEWNET);

	if (err != 0)
		test_error("setns()");
}
136 
/*
 * Enter the namespace @new_ns and return a descriptor of the namespace that
 * was current before the switch. The returned descriptor is opened here and
 * is not closed by this function (switch_close_ns() both restores and
 * closes one).
 */
int switch_save_ns(int new_ns)
{
	int prev_ns = open_netns();

	switch_ns(new_ns);
	return prev_ns;
}
144 
/*
 * Enter the network namespace referred to by @fd and then close @fd.
 * A failed setns() is reported with test_error().
 */
void switch_close_ns(int fd)
{
	if (setns(fd, CLONE_NEWNET) != 0)
		test_error("setns()");
	close(fd);
}
151 
/* Network namespace descriptors, -1 until init_namespaces() has run */
static int nsfd_outside	= -1;	/* namespace the test was started in */
static int nsfd_parent	= -1;	/* first unshared namespace */
static int nsfd_child	= -1;	/* second unshared namespace */
const char veth_name[]	= "ktst-veth";

/*
 * Save a descriptor of the starting namespace, then unshare twice, keeping
 * a descriptor of each new namespace. The caller is left in the namespace
 * created last (nsfd_child).
 */
static void init_namespaces(void)
{
	/* The order matters: each call captures the then-current namespace */
	nsfd_outside	= open_netns();
	nsfd_parent	= unshare_open_netns();
	nsfd_child	= unshare_open_netns();
}
163 
/*
 * Configure the interface @veth: set the link up, add @addr with @prefix to
 * it, then add a route to @dest. Each step that returns non-zero is
 * reported with test_error().
 */
static void link_init(const char *veth, int family, uint8_t prefix,
		      union tcp_addr addr, union tcp_addr dest)
{
	if (link_set_up(veth) != 0)
		test_error("Failed to set link up");

	if (ip_addr_add(veth, family, addr, prefix) != 0)
		test_error("Failed to add ip address");

	if (ip_route_add(veth, family, addr, dest) != 0)
		test_error("Failed to add route");
}
174 
/* Number of threads taking part in synchronize_threads() */
static unsigned int nr_threads = 1;

/*
 * Barrier state, protected by sync_lock. Two arrival counters are used
 * alternately: stage_nr selects the counter of the stage that is currently
 * being filled, the other one still holds the completed count of the
 * previous stage until the current stage completes.
 */
static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER;
static volatile unsigned int stage_threads[2];
static volatile unsigned int stage_nr;

/*
 * Synchronize all threads in the same stage: returns once nr_threads
 * threads have called it for the current stage.
 */
void synchronize_threads(void)
{
	unsigned int q;

	pthread_mutex_lock(&sync_lock);
	/* Sample the stage under the lock, together with the counters */
	q = stage_nr;
	stage_threads[q]++;
	if (stage_threads[q] == nr_threads) {
		/* Last arrival: open the next stage and reset its counter */
		stage_nr ^= 1;
		stage_threads[stage_nr] = 0;
		/*
		 * Every earlier arrival is blocked on sync_cond, so all of
		 * them must be woken, not just one.
		 */
		pthread_cond_broadcast(&sync_cond);
	}
	/* Loop to cope with spurious wakeups */
	while (stage_threads[q] < nr_threads)
		pthread_cond_wait(&sync_cond, &sync_lock);
	pthread_mutex_unlock(&sync_lock);
}
198 
199 __thread union tcp_addr this_ip_addr;
200 __thread union tcp_addr this_ip_dest;
201 int test_family;
202 
203 struct new_pthread_arg {
204 	thread_fn	func;
205 	union tcp_addr	my_ip;
206 	union tcp_addr	dest_ip;
207 };
new_pthread_entry(void * arg)208 static void *new_pthread_entry(void *arg)
209 {
210 	struct new_pthread_arg *p = arg;
211 
212 	this_ip_addr = p->my_ip;
213 	this_ip_dest = p->dest_ip;
214 	p->func(NULL); /* shouldn't return */
215 	exit(KSFT_FAIL);
216 }
217 
__test_skip_all(const char * msg)218 static void __test_skip_all(const char *msg)
219 {
220 	ksft_set_plan(1);
221 	ksft_print_header();
222 	skipped = 1;
223 	test_skip("%s", msg);
224 	exit(KSFT_SKIP);
225 }
226 
/*
 * Common entry point of a test; it does not return to the caller.
 *
 * Installs the SIGINT handler and skips the whole test if
 * kernel_config_has() reports KCONFIG_NET_NS, KCONFIG_VETH or
 * KCONFIG_TCP_AO as missing. Otherwise it creates the namespaces, connects
 * nsfd_parent and nsfd_child with a veth pair and runs the peers:
 *  - @peer2, if not NULL, in a new thread that is created while the caller
 *    is switched to nsfd_child; it gets @addr2 as its own address and
 *    @addr1 as destination;
 *  - @peer1 in the calling thread after switching to nsfd_parent, with
 *    @addr1 as its own address and @addr2 as destination.
 *
 * @ntests: number of results announced with ksft_set_plan()
 * @family: address family; stored in test_family and used for link setup
 * @prefix: prefix length of the addresses added to the veth interfaces
 *
 * If @peer1 returns, the process exits with KSFT_FAIL when a failure has
 * been recorded and with KSFT_PASS otherwise.
 */
void __test_init(unsigned int ntests, int family, unsigned int prefix,
		 union tcp_addr addr1, union tcp_addr addr2,
		 thread_fn peer1, thread_fn peer2)
{
	/*
	 * Handed to the peer2 thread by address. It has static storage so
	 * that it is still valid whenever the new thread gets to read it:
	 * an automatic variable scoped to the "if (peer2)" block below
	 * would cease to exist as soon as that block is left.
	 */
	static struct new_pthread_arg targ;
	struct sigaction sa = {
		.sa_handler = sig_int,
		.sa_flags = SA_RESTART,
	};
	time_t seed = time(NULL);

	sigemptyset(&sa.sa_mask);
	if (sigaction(SIGINT, &sa, NULL))
		test_error("Can't set SIGINT handler");

	test_family = family;
	if (!kernel_config_has(KCONFIG_NET_NS))
		__test_skip_all(tests_skip_reason[KCONFIG_NET_NS]);
	if (!kernel_config_has(KCONFIG_VETH))
		__test_skip_all(tests_skip_reason[KCONFIG_VETH]);
	if (!kernel_config_has(KCONFIG_TCP_AO))
		__test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]);

	ksft_set_plan(ntests);
	/* Print the seed, so that a run can be reproduced */
	test_print("rand seed %u", (unsigned int)seed);
	srand(seed);

	ksft_print_header();
	init_namespaces();
	test_init_ftrace(nsfd_parent, nsfd_child);

	if (add_veth(veth_name, nsfd_parent, nsfd_child))
		test_error("Failed to add veth");

	switch_ns(nsfd_child);
	link_init(veth_name, family, prefix, addr2, addr1);
	if (peer2) {
		pthread_t t;

		targ.my_ip = addr2;
		targ.dest_ip = addr1;
		targ.func = peer2;
		/* Account for the new thread before it can reach the barrier */
		nr_threads++;
		if (pthread_create(&t, NULL, new_pthread_entry, &targ))
			test_error("Failed to create pthread");
	}
	switch_ns(nsfd_parent);
	link_init(veth_name, family, prefix, addr1, addr2);

	this_ip_addr = addr1;
	this_ip_dest = addr2;
	peer1(NULL);
	if (failed)
		exit(KSFT_FAIL);
	else
		exit(KSFT_PASS);
}
284 
/* /proc/sys/net/core/optmem_max artificially limits the amount of memory
 * that can be allocated with sock_kmalloc() on each socket in the system.
 * It is not virtualized in v6.7, so it has to be written outside test
 * namespaces. To be nice a test will revert optmem back to the old value.
 * Keeping it simple without any file lock, which means the tests that
 * need to set/increase optmem value shouldn't run in parallel.
 * Also, not re-entrant.
 * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max")
 * it is per-namespace, keeping logic for non-virtualized optmem_max
 * for v6.7, which supports TCP-AO.
 */
296 static const char *optmem_file = "/proc/sys/net/core/optmem_max";
297 static size_t saved_optmem;
298 static int optmem_ns = -1;
299 
is_optmem_namespaced(void)300 static bool is_optmem_namespaced(void)
301 {
302 	if (optmem_ns == -1) {
303 		int old_ns = switch_save_ns(nsfd_child);
304 
305 		optmem_ns = !access(optmem_file, F_OK);
306 		switch_close_ns(old_ns);
307 	}
308 	return !!optmem_ns;
309 }
310 
test_get_optmem(void)311 size_t test_get_optmem(void)
312 {
313 	int old_ns = 0;
314 	FILE *foptmem;
315 	size_t ret;
316 
317 	if (!is_optmem_namespaced())
318 		old_ns = switch_save_ns(nsfd_outside);
319 	foptmem = fopen(optmem_file, "r");
320 	if (!foptmem)
321 		test_error("failed to open %s", optmem_file);
322 
323 	if (fscanf(foptmem, "%zu", &ret) != 1)
324 		test_error("can't read from %s", optmem_file);
325 	fclose(foptmem);
326 	if (!is_optmem_namespaced())
327 		switch_close_ns(old_ns);
328 	return ret;
329 }
330 
__test_set_optmem(size_t new,size_t * old)331 static void __test_set_optmem(size_t new, size_t *old)
332 {
333 	int old_ns = 0;
334 	FILE *foptmem;
335 
336 	if (old != NULL)
337 		*old = test_get_optmem();
338 
339 	if (!is_optmem_namespaced())
340 		old_ns = switch_save_ns(nsfd_outside);
341 	foptmem = fopen(optmem_file, "w");
342 	if (!foptmem)
343 		test_error("failed to open %s", optmem_file);
344 
345 	if (fprintf(foptmem, "%zu", new) <= 0)
346 		test_error("can't write %zu to %s", new, optmem_file);
347 	fclose(foptmem);
348 	if (!is_optmem_namespaced())
349 		switch_close_ns(old_ns);
350 }
351 
test_revert_optmem(void)352 static void test_revert_optmem(void)
353 {
354 	if (saved_optmem == 0)
355 		return;
356 
357 	__test_set_optmem(saved_optmem, NULL);
358 }
359 
test_set_optmem(size_t value)360 void test_set_optmem(size_t value)
361 {
362 	if (saved_optmem == 0) {
363 		__test_set_optmem(value, &saved_optmem);
364 		test_add_destructor(test_revert_optmem);
365 	} else {
366 		__test_set_optmem(value, NULL);
367 	}
368 }
369