1 // SPDX-License-Identifier: GPL-2.0 2 #include <fcntl.h> 3 #include <pthread.h> 4 #include <sched.h> 5 #include <signal.h> 6 #include "aolib.h" 7 8 /* 9 * Can't be included in the header: it defines static variables which 10 * will be unique to every object. Let's include it only once here. 11 */ 12 #include "../../../kselftest.h" 13 14 /* Prevent overriding of one thread's output by another */ 15 static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER; 16 17 void __test_msg(const char *buf) 18 { 19 pthread_mutex_lock(&ksft_print_lock); 20 ksft_print_msg("%s", buf); 21 pthread_mutex_unlock(&ksft_print_lock); 22 } 23 void __test_ok(const char *buf) 24 { 25 pthread_mutex_lock(&ksft_print_lock); 26 ksft_test_result_pass("%s", buf); 27 pthread_mutex_unlock(&ksft_print_lock); 28 } 29 void __test_fail(const char *buf) 30 { 31 pthread_mutex_lock(&ksft_print_lock); 32 ksft_test_result_fail("%s", buf); 33 pthread_mutex_unlock(&ksft_print_lock); 34 } 35 void __test_xfail(const char *buf) 36 { 37 pthread_mutex_lock(&ksft_print_lock); 38 ksft_test_result_xfail("%s", buf); 39 pthread_mutex_unlock(&ksft_print_lock); 40 } 41 void __test_error(const char *buf) 42 { 43 pthread_mutex_lock(&ksft_print_lock); 44 ksft_test_result_error("%s", buf); 45 pthread_mutex_unlock(&ksft_print_lock); 46 } 47 void __test_skip(const char *buf) 48 { 49 pthread_mutex_lock(&ksft_print_lock); 50 ksft_test_result_skip("%s", buf); 51 pthread_mutex_unlock(&ksft_print_lock); 52 } 53 54 static volatile int failed; 55 static volatile int skipped; 56 57 void test_failed(void) 58 { 59 failed = 1; 60 } 61 62 static void test_exit(void) 63 { 64 if (failed) { 65 ksft_exit_fail(); 66 } else if (skipped) { 67 /* ksft_exit_skip() is different from ksft_exit_*() */ 68 ksft_print_cnts(); 69 exit(KSFT_SKIP); 70 } else { 71 ksft_exit_pass(); 72 } 73 } 74 75 struct dlist_t { 76 void (*destruct)(void); 77 struct dlist_t *next; 78 }; 79 static struct dlist_t *destructors_list; 80 81 void test_add_destructor(void (*d)(void)) 82 { 83 struct dlist_t *p; 84 85 p = malloc(sizeof(struct dlist_t)); 86 if (p == NULL) 87 test_error("malloc() failed"); 88 89 p->next = destructors_list; 90 p->destruct = d; 91 destructors_list = p; 92 } 93 94 static void test_destructor(void) __attribute__((destructor)); 95 static void test_destructor(void) 96 { 97 while (destructors_list) { 98 struct dlist_t *p = destructors_list->next; 99 100 destructors_list->destruct(); 101 free(destructors_list); 102 destructors_list = p; 103 } 104 test_exit(); 105 } 106 107 static void sig_int(int signo) 108 { 109 test_error("Caught SIGINT - exiting"); 110 } 111 112 int open_netns(void) 113 { 114 const char *netns_path = "/proc/thread-self/ns/net"; 115 int fd; 116 117 fd = open(netns_path, O_RDONLY); 118 if (fd < 0) 119 test_error("open(%s)", netns_path); 120 return fd; 121 } 122 123 int unshare_open_netns(void) 124 { 125 if (unshare(CLONE_NEWNET) != 0) 126 test_error("unshare()"); 127 128 return open_netns(); 129 } 130 131 void switch_ns(int fd) 132 { 133 if (setns(fd, CLONE_NEWNET)) 134 test_error("setns()"); 135 } 136 137 int switch_save_ns(int new_ns) 138 { 139 int ret = open_netns(); 140 141 switch_ns(new_ns); 142 return ret; 143 } 144 145 void switch_close_ns(int fd) 146 { 147 if (setns(fd, CLONE_NEWNET)) 148 test_error("setns()"); 149 close(fd); 150 } 151 152 static int nsfd_outside = -1; 153 static int nsfd_parent = -1; 154 static int nsfd_child = -1; 155 const char veth_name[] = "ktst-veth"; 156 157 static void init_namespaces(void) 158 { 159 nsfd_outside = open_netns(); 160 nsfd_parent = unshare_open_netns(); 161 nsfd_child = unshare_open_netns(); 162 } 163 164 static void link_init(const char *veth, int family, uint8_t prefix, 165 union tcp_addr addr, union tcp_addr dest) 166 { 167 if (link_set_up(veth)) 168 test_error("Failed to set link up"); 169 if (ip_addr_add(veth, family, addr, prefix)) 170 test_error("Failed to add ip address"); 171 if (ip_route_add(veth, family, addr, dest)) 172 test_error("Failed to add route"); 173 } 174 175 static unsigned int nr_threads = 1; 176 177 static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER; 178 static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER; 179 static volatile unsigned int stage_threads[2]; 180 static volatile unsigned int stage_nr; 181 182 /* synchronize all threads in the same stage */ 183 void synchronize_threads(void) 184 { 185 unsigned int q = stage_nr; 186 187 pthread_mutex_lock(&sync_lock); 188 stage_threads[q]++; 189 if (stage_threads[q] == nr_threads) { 190 stage_nr ^= 1; 191 stage_threads[stage_nr] = 0; 192 pthread_cond_signal(&sync_cond); 193 } 194 while (stage_threads[q] < nr_threads) 195 pthread_cond_wait(&sync_cond, &sync_lock); 196 pthread_mutex_unlock(&sync_lock); 197 } 198 199 __thread union tcp_addr this_ip_addr; 200 __thread union tcp_addr this_ip_dest; 201 int test_family; 202 203 struct new_pthread_arg { 204 thread_fn func; 205 union tcp_addr my_ip; 206 union tcp_addr dest_ip; 207 }; 208 static void *new_pthread_entry(void *arg) 209 { 210 struct new_pthread_arg *p = arg; 211 212 this_ip_addr = p->my_ip; 213 this_ip_dest = p->dest_ip; 214 p->func(NULL); /* shouldn't return */ 215 exit(KSFT_FAIL); 216 } 217 218 static void __test_skip_all(const char *msg) 219 { 220 ksft_set_plan(1); 221 ksft_print_header(); 222 skipped = 1; 223 test_skip("%s", msg); 224 exit(KSFT_SKIP); 225 } 226 227 void __test_init(unsigned int ntests, int family, unsigned int prefix, 228 union tcp_addr addr1, union tcp_addr addr2, 229 thread_fn peer1, thread_fn peer2) 230 { 231 struct sigaction sa = { 232 .sa_handler = sig_int, 233 .sa_flags = SA_RESTART, 234 }; 235 time_t seed = time(NULL); 236 237 sigemptyset(&sa.sa_mask); 238 if (sigaction(SIGINT, &sa, NULL)) 239 test_error("Can't set SIGINT handler"); 240 241 test_family = family; 242 if (!kernel_config_has(KCONFIG_NET_NS)) 243 __test_skip_all(tests_skip_reason[KCONFIG_NET_NS]); 244 if (!kernel_config_has(KCONFIG_VETH)) 245 __test_skip_all(tests_skip_reason[KCONFIG_VETH]); 246 if (!kernel_config_has(KCONFIG_TCP_AO)) 247 __test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]); 248 249 ksft_set_plan(ntests); 250 test_print("rand seed %u", (unsigned int)seed); 251 srand(seed); 252 253 ksft_print_header(); 254 init_namespaces(); 255 test_init_ftrace(nsfd_parent, nsfd_child); 256 257 if (add_veth(veth_name, nsfd_parent, nsfd_child)) 258 test_error("Failed to add veth"); 259 260 switch_ns(nsfd_child); 261 link_init(veth_name, family, prefix, addr2, addr1); 262 if (peer2) { 263 struct new_pthread_arg targ; 264 pthread_t t; 265 266 targ.my_ip = addr2; 267 targ.dest_ip = addr1; 268 targ.func = peer2; 269 nr_threads++; 270 if (pthread_create(&t, NULL, new_pthread_entry, &targ)) 271 test_error("Failed to create pthread"); 272 } 273 switch_ns(nsfd_parent); 274 link_init(veth_name, family, prefix, addr1, addr2); 275 276 this_ip_addr = addr1; 277 this_ip_dest = addr2; 278 peer1(NULL); 279 if (failed) 280 exit(KSFT_FAIL); 281 else 282 exit(KSFT_PASS); 283 } 284 285 /* /proc/sys/net/core/optmem_max artifically limits the amount of memory 286 * that can be allocated with sock_kmalloc() on each socket in the system. 287 * It is not virtualized in v6.7, so it has to written outside test 288 * namespaces. To be nice a test will revert optmem back to the old value. 289 * Keeping it simple without any file lock, which means the tests that 290 * need to set/increase optmem value shouldn't run in parallel. 291 * Also, not re-entrant. 292 * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max") 293 * it is per-namespace, keeping logic for non-virtualized optmem_max 294 * for v6.7, which supports TCP-AO. 295 */ 296 static const char *optmem_file = "/proc/sys/net/core/optmem_max"; 297 static size_t saved_optmem; 298 static int optmem_ns = -1; 299 300 static bool is_optmem_namespaced(void) 301 { 302 if (optmem_ns == -1) { 303 int old_ns = switch_save_ns(nsfd_child); 304 305 optmem_ns = !access(optmem_file, F_OK); 306 switch_close_ns(old_ns); 307 } 308 return !!optmem_ns; 309 } 310 311 size_t test_get_optmem(void) 312 { 313 int old_ns = 0; 314 FILE *foptmem; 315 size_t ret; 316 317 if (!is_optmem_namespaced()) 318 old_ns = switch_save_ns(nsfd_outside); 319 foptmem = fopen(optmem_file, "r"); 320 if (!foptmem) 321 test_error("failed to open %s", optmem_file); 322 323 if (fscanf(foptmem, "%zu", &ret) != 1) 324 test_error("can't read from %s", optmem_file); 325 fclose(foptmem); 326 if (!is_optmem_namespaced()) 327 switch_close_ns(old_ns); 328 return ret; 329 } 330 331 static void __test_set_optmem(size_t new, size_t *old) 332 { 333 int old_ns = 0; 334 FILE *foptmem; 335 336 if (old != NULL) 337 *old = test_get_optmem(); 338 339 if (!is_optmem_namespaced()) 340 old_ns = switch_save_ns(nsfd_outside); 341 foptmem = fopen(optmem_file, "w"); 342 if (!foptmem) 343 test_error("failed to open %s", optmem_file); 344 345 if (fprintf(foptmem, "%zu", new) <= 0) 346 test_error("can't write %zu to %s", new, optmem_file); 347 fclose(foptmem); 348 if (!is_optmem_namespaced()) 349 switch_close_ns(old_ns); 350 } 351 352 static void test_revert_optmem(void) 353 { 354 if (saved_optmem == 0) 355 return; 356 357 __test_set_optmem(saved_optmem, NULL); 358 } 359 360 void test_set_optmem(size_t value) 361 { 362 if (saved_optmem == 0) { 363 __test_set_optmem(value, &saved_optmem); 364 test_add_destructor(test_revert_optmem); 365 } else { 366 __test_set_optmem(value, NULL); 367 } 368 } 369