1 // SPDX-License-Identifier: GPL-2.0 2 #include <fcntl.h> 3 #include <pthread.h> 4 #include <sched.h> 5 #include <signal.h> 6 #include "aolib.h" 7 8 /* 9 * Can't be included in the header: it defines static variables which 10 * will be unique to every object. Let's include it only once here. 11 */ 12 #include "../../../kselftest.h" 13 14 /* Prevent overriding of one thread's output by another */ 15 static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER; 16 17 void __test_msg(const char *buf) 18 { 19 pthread_mutex_lock(&ksft_print_lock); 20 ksft_print_msg("%s", buf); 21 pthread_mutex_unlock(&ksft_print_lock); 22 } 23 void __test_ok(const char *buf) 24 { 25 pthread_mutex_lock(&ksft_print_lock); 26 ksft_test_result_pass("%s", buf); 27 pthread_mutex_unlock(&ksft_print_lock); 28 } 29 void __test_fail(const char *buf) 30 { 31 pthread_mutex_lock(&ksft_print_lock); 32 ksft_test_result_fail("%s", buf); 33 pthread_mutex_unlock(&ksft_print_lock); 34 } 35 void __test_xfail(const char *buf) 36 { 37 pthread_mutex_lock(&ksft_print_lock); 38 ksft_test_result_xfail("%s", buf); 39 pthread_mutex_unlock(&ksft_print_lock); 40 } 41 void __test_error(const char *buf) 42 { 43 pthread_mutex_lock(&ksft_print_lock); 44 ksft_test_result_error("%s", buf); 45 pthread_mutex_unlock(&ksft_print_lock); 46 } 47 void __test_skip(const char *buf) 48 { 49 pthread_mutex_lock(&ksft_print_lock); 50 ksft_test_result_skip("%s", buf); 51 pthread_mutex_unlock(&ksft_print_lock); 52 } 53 54 static volatile int failed; 55 static volatile int skipped; 56 57 void test_failed(void) 58 { 59 failed = 1; 60 } 61 62 static void test_exit(void) 63 { 64 if (failed) { 65 ksft_exit_fail(); 66 } else if (skipped) { 67 /* ksft_exit_skip() is different from ksft_exit_*() */ 68 ksft_print_cnts(); 69 exit(KSFT_SKIP); 70 } else { 71 ksft_exit_pass(); 72 } 73 } 74 75 struct dlist_t { 76 void (*destruct)(void); 77 struct dlist_t *next; 78 }; 79 static struct dlist_t *destructors_list; 80 81 void test_add_destructor(void (*d)(void)) 82 { 83 struct dlist_t *p; 84 85 p = malloc(sizeof(struct dlist_t)); 86 if (p == NULL) 87 test_error("malloc() failed"); 88 89 p->next = destructors_list; 90 p->destruct = d; 91 destructors_list = p; 92 } 93 94 static void test_destructor(void) __attribute__((destructor)); 95 static void test_destructor(void) 96 { 97 while (destructors_list) { 98 struct dlist_t *p = destructors_list->next; 99 100 destructors_list->destruct(); 101 free(destructors_list); 102 destructors_list = p; 103 } 104 test_exit(); 105 } 106 107 static void sig_int(int signo) 108 { 109 test_error("Caught SIGINT - exiting"); 110 } 111 112 int open_netns(void) 113 { 114 const char *netns_path = "/proc/self/ns/net"; 115 int fd; 116 117 fd = open(netns_path, O_RDONLY); 118 if (fd < 0) 119 test_error("open(%s)", netns_path); 120 return fd; 121 } 122 123 int unshare_open_netns(void) 124 { 125 if (unshare(CLONE_NEWNET) != 0) 126 test_error("unshare()"); 127 128 return open_netns(); 129 } 130 131 void switch_ns(int fd) 132 { 133 if (setns(fd, CLONE_NEWNET)) 134 test_error("setns()"); 135 } 136 137 int switch_save_ns(int new_ns) 138 { 139 int ret = open_netns(); 140 141 switch_ns(new_ns); 142 return ret; 143 } 144 145 static int nsfd_outside = -1; 146 static int nsfd_parent = -1; 147 static int nsfd_child = -1; 148 const char veth_name[] = "ktst-veth"; 149 150 static void init_namespaces(void) 151 { 152 nsfd_outside = open_netns(); 153 nsfd_parent = unshare_open_netns(); 154 nsfd_child = unshare_open_netns(); 155 } 156 157 static void link_init(const char *veth, int family, uint8_t prefix, 158 union tcp_addr addr, union tcp_addr dest) 159 { 160 if (link_set_up(veth)) 161 test_error("Failed to set link up"); 162 if (ip_addr_add(veth, family, addr, prefix)) 163 test_error("Failed to add ip address"); 164 if (ip_route_add(veth, family, addr, dest)) 165 test_error("Failed to add route"); 166 } 167 168 static unsigned int nr_threads = 1; 169 170 static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER; 171 static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER; 172 static volatile unsigned int stage_threads[2]; 173 static volatile unsigned int stage_nr; 174 175 /* synchronize all threads in the same stage */ 176 void synchronize_threads(void) 177 { 178 unsigned int q = stage_nr; 179 180 pthread_mutex_lock(&sync_lock); 181 stage_threads[q]++; 182 if (stage_threads[q] == nr_threads) { 183 stage_nr ^= 1; 184 stage_threads[stage_nr] = 0; 185 pthread_cond_signal(&sync_cond); 186 } 187 while (stage_threads[q] < nr_threads) 188 pthread_cond_wait(&sync_cond, &sync_lock); 189 pthread_mutex_unlock(&sync_lock); 190 } 191 192 __thread union tcp_addr this_ip_addr; 193 __thread union tcp_addr this_ip_dest; 194 int test_family; 195 196 struct new_pthread_arg { 197 thread_fn func; 198 union tcp_addr my_ip; 199 union tcp_addr dest_ip; 200 }; 201 static void *new_pthread_entry(void *arg) 202 { 203 struct new_pthread_arg *p = arg; 204 205 this_ip_addr = p->my_ip; 206 this_ip_dest = p->dest_ip; 207 p->func(NULL); /* shouldn't return */ 208 exit(KSFT_FAIL); 209 } 210 211 static void __test_skip_all(const char *msg) 212 { 213 ksft_set_plan(1); 214 ksft_print_header(); 215 skipped = 1; 216 test_skip("%s", msg); 217 exit(KSFT_SKIP); 218 } 219 220 void __test_init(unsigned int ntests, int family, unsigned int prefix, 221 union tcp_addr addr1, union tcp_addr addr2, 222 thread_fn peer1, thread_fn peer2) 223 { 224 struct sigaction sa = { 225 .sa_handler = sig_int, 226 .sa_flags = SA_RESTART, 227 }; 228 time_t seed = time(NULL); 229 230 sigemptyset(&sa.sa_mask); 231 if (sigaction(SIGINT, &sa, NULL)) 232 test_error("Can't set SIGINT handler"); 233 234 test_family = family; 235 if (!kernel_config_has(KCONFIG_NET_NS)) 236 __test_skip_all(tests_skip_reason[KCONFIG_NET_NS]); 237 if (!kernel_config_has(KCONFIG_VETH)) 238 __test_skip_all(tests_skip_reason[KCONFIG_VETH]); 239 if (!kernel_config_has(KCONFIG_TCP_AO)) 240 __test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]); 241 242 ksft_set_plan(ntests); 243 test_print("rand seed %u", (unsigned int)seed); 244 srand(seed); 245 246 247 ksft_print_header(); 248 init_namespaces(); 249 250 if (add_veth(veth_name, nsfd_parent, nsfd_child)) 251 test_error("Failed to add veth"); 252 253 switch_ns(nsfd_child); 254 link_init(veth_name, family, prefix, addr2, addr1); 255 if (peer2) { 256 struct new_pthread_arg targ; 257 pthread_t t; 258 259 targ.my_ip = addr2; 260 targ.dest_ip = addr1; 261 targ.func = peer2; 262 nr_threads++; 263 if (pthread_create(&t, NULL, new_pthread_entry, &targ)) 264 test_error("Failed to create pthread"); 265 } 266 switch_ns(nsfd_parent); 267 link_init(veth_name, family, prefix, addr1, addr2); 268 269 this_ip_addr = addr1; 270 this_ip_dest = addr2; 271 peer1(NULL); 272 if (failed) 273 exit(KSFT_FAIL); 274 else 275 exit(KSFT_PASS); 276 } 277 278 /* /proc/sys/net/core/optmem_max artifically limits the amount of memory 279 * that can be allocated with sock_kmalloc() on each socket in the system. 280 * It is not virtualized in v6.7, so it has to written outside test 281 * namespaces. To be nice a test will revert optmem back to the old value. 282 * Keeping it simple without any file lock, which means the tests that 283 * need to set/increase optmem value shouldn't run in parallel. 284 * Also, not re-entrant. 285 * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max") 286 * it is per-namespace, keeping logic for non-virtualized optmem_max 287 * for v6.7, which supports TCP-AO. 288 */ 289 static const char *optmem_file = "/proc/sys/net/core/optmem_max"; 290 static size_t saved_optmem; 291 static int optmem_ns = -1; 292 293 static bool is_optmem_namespaced(void) 294 { 295 if (optmem_ns == -1) { 296 int old_ns = switch_save_ns(nsfd_child); 297 298 optmem_ns = !access(optmem_file, F_OK); 299 switch_ns(old_ns); 300 } 301 return !!optmem_ns; 302 } 303 304 size_t test_get_optmem(void) 305 { 306 int old_ns = 0; 307 FILE *foptmem; 308 size_t ret; 309 310 if (!is_optmem_namespaced()) 311 old_ns = switch_save_ns(nsfd_outside); 312 foptmem = fopen(optmem_file, "r"); 313 if (!foptmem) 314 test_error("failed to open %s", optmem_file); 315 316 if (fscanf(foptmem, "%zu", &ret) != 1) 317 test_error("can't read from %s", optmem_file); 318 fclose(foptmem); 319 if (!is_optmem_namespaced()) 320 switch_ns(old_ns); 321 return ret; 322 } 323 324 static void __test_set_optmem(size_t new, size_t *old) 325 { 326 int old_ns = 0; 327 FILE *foptmem; 328 329 if (old != NULL) 330 *old = test_get_optmem(); 331 332 if (!is_optmem_namespaced()) 333 old_ns = switch_save_ns(nsfd_outside); 334 foptmem = fopen(optmem_file, "w"); 335 if (!foptmem) 336 test_error("failed to open %s", optmem_file); 337 338 if (fprintf(foptmem, "%zu", new) <= 0) 339 test_error("can't write %zu to %s", new, optmem_file); 340 fclose(foptmem); 341 if (!is_optmem_namespaced()) 342 switch_ns(old_ns); 343 } 344 345 static void test_revert_optmem(void) 346 { 347 if (saved_optmem == 0) 348 return; 349 350 __test_set_optmem(saved_optmem, NULL); 351 } 352 353 void test_set_optmem(size_t value) 354 { 355 if (saved_optmem == 0) { 356 __test_set_optmem(value, &saved_optmem); 357 test_add_destructor(test_revert_optmem); 358 } else { 359 __test_set_optmem(value, NULL); 360 } 361 } 362