1 // SPDX-License-Identifier: GPL-2.0 2 /* Author: Dmitry Safonov <dima@arista.com> */ 3 #include <inttypes.h> 4 #include "../../../../include/linux/kernel.h" 5 #include "aolib.h" 6 7 const size_t quota = 1000; 8 /* 9 * Backlog == 0 means 1 connection in queue, see: 10 * commit 64a146513f8f ("[NET]: Revert incorrect accept queue...") 11 */ 12 const unsigned int backlog; 13 14 static void netstats_check(struct netstat *before, struct netstat *after, 15 char *msg) 16 { 17 uint64_t before_cnt, after_cnt; 18 19 before_cnt = netstat_get(before, "TCPAORequired", NULL); 20 after_cnt = netstat_get(after, "TCPAORequired", NULL); 21 if (after_cnt > before_cnt) 22 test_fail("Segments without AO sign (%s): %" PRIu64 " => %" PRIu64, 23 msg, before_cnt, after_cnt); 24 else 25 test_ok("No segments without AO sign (%s)", msg); 26 27 before_cnt = netstat_get(before, "TCPAOGood", NULL); 28 after_cnt = netstat_get(after, "TCPAOGood", NULL); 29 if (after_cnt <= before_cnt) 30 test_fail("Signed AO segments (%s): %" PRIu64 " => %" PRIu64, 31 msg, before_cnt, after_cnt); 32 else 33 test_ok("Signed AO segments (%s): %" PRIu64 " => %" PRIu64, 34 msg, before_cnt, after_cnt); 35 36 before_cnt = netstat_get(before, "TCPAOBad", NULL); 37 after_cnt = netstat_get(after, "TCPAOBad", NULL); 38 if (after_cnt > before_cnt) 39 test_fail("Segments with bad AO sign (%s): %" PRIu64 " => %" PRIu64, 40 msg, before_cnt, after_cnt); 41 else 42 test_ok("No segments with bad AO sign (%s)", msg); 43 } 44 45 /* 46 * Another way to send RST, but not through tcp_v{4,6}_send_reset() 47 * is tcp_send_active_reset(), that is not in reply to inbound segment, 48 * but rather active send. It uses tcp_transmit_skb(), so that should 49 * work, but as it also sends RST - nice that it can be covered as well. 50 */ 51 static void close_forced(int sk) 52 { 53 struct linger sl; 54 55 sl.l_onoff = 1; 56 sl.l_linger = 0; 57 if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &sl, sizeof(sl))) 58 test_error("setsockopt(SO_LINGER)"); 59 close(sk); 60 } 61 62 static int test_wait_for_exception(int sk, time_t sec) 63 { 64 struct timeval tv = { .tv_sec = sec }; 65 struct timeval *ptv = NULL; 66 fd_set efds; 67 int ret; 68 69 FD_ZERO(&efds); 70 FD_SET(sk, &efds); 71 72 if (sec) 73 ptv = &tv; 74 75 errno = 0; 76 ret = select(sk + 1, NULL, NULL, &efds, ptv); 77 if (ret < 0) 78 return -errno; 79 return ret ? sk : 0; 80 } 81 82 static void test_server_active_rst(unsigned int port) 83 { 84 struct tcp_ao_counters cnt1, cnt2; 85 ssize_t bytes; 86 int sk, lsk; 87 88 lsk = test_listen_socket(this_ip_addr, port, backlog); 89 if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) 90 test_error("setsockopt(TCP_AO_ADD_KEY)"); 91 if (test_get_tcp_ao_counters(lsk, &cnt1)) 92 test_error("test_get_tcp_ao_counters()"); 93 94 synchronize_threads(); /* 1: MKT added */ 95 if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) 96 test_error("test_wait_fd()"); 97 98 sk = accept(lsk, NULL, NULL); 99 if (sk < 0) 100 test_error("accept()"); 101 102 synchronize_threads(); /* 2: connection accept()ed, another queued */ 103 if (test_get_tcp_ao_counters(lsk, &cnt2)) 104 test_error("test_get_tcp_ao_counters()"); 105 106 synchronize_threads(); /* 3: close listen socket */ 107 close(lsk); 108 bytes = test_server_run(sk, quota, 0); 109 if (bytes != quota) 110 test_error("servered only %zd bytes", bytes); 111 else 112 test_ok("servered %zd bytes", bytes); 113 114 synchronize_threads(); /* 4: finishing up */ 115 close_forced(sk); 116 117 synchronize_threads(); /* 5: closed active sk */ 118 119 synchronize_threads(); /* 6: counters checks */ 120 if (test_tcp_ao_counters_cmp("active RST server", &cnt1, &cnt2, TEST_CNT_GOOD)) 121 test_fail("MKT counters (server) have not only good packets"); 122 else 123 test_ok("MKT counters are good on server"); 124 } 125 126 static void test_server_passive_rst(unsigned int port) 127 { 128 struct tcp_ao_counters ao1, ao2; 129 int sk, lsk; 130 ssize_t bytes; 131 132 lsk = test_listen_socket(this_ip_addr, port, 1); 133 134 if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) 135 test_error("setsockopt(TCP_AO_ADD_KEY)"); 136 137 synchronize_threads(); /* 1: MKT added => connect() */ 138 if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) 139 test_error("test_wait_fd()"); 140 141 sk = accept(lsk, NULL, NULL); 142 if (sk < 0) 143 test_error("accept()"); 144 145 synchronize_threads(); /* 2: accepted => send data */ 146 close(lsk); 147 if (test_get_tcp_ao_counters(sk, &ao1)) 148 test_error("test_get_tcp_ao_counters()"); 149 150 bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); 151 if (bytes != quota) { 152 if (bytes > 0) 153 test_fail("server served: %zd", bytes); 154 else 155 test_fail("server returned %zd", bytes); 156 } 157 158 synchronize_threads(); /* 3: chekpoint/restore the connection */ 159 if (test_get_tcp_ao_counters(sk, &ao2)) 160 test_error("test_get_tcp_ao_counters()"); 161 162 synchronize_threads(); /* 4: terminate server + send more on client */ 163 bytes = test_server_run(sk, quota, TEST_RETRANSMIT_SEC); 164 close(sk); 165 test_tcp_ao_counters_cmp("passive RST server", &ao1, &ao2, TEST_CNT_GOOD); 166 167 synchronize_threads(); /* 5: verified => closed */ 168 close(sk); 169 } 170 171 static void *server_fn(void *arg) 172 { 173 struct netstat *ns_before, *ns_after; 174 unsigned int port = test_server_port; 175 176 ns_before = netstat_read(); 177 178 test_server_active_rst(port++); 179 test_server_passive_rst(port++); 180 181 ns_after = netstat_read(); 182 netstats_check(ns_before, ns_after, "server"); 183 netstat_free(ns_after); 184 netstat_free(ns_before); 185 synchronize_threads(); /* exit */ 186 187 synchronize_threads(); /* don't race to exit() - client exits */ 188 return NULL; 189 } 190 191 static int test_wait_fds(int sk[], size_t nr, bool is_writable[], 192 ssize_t wait_for, time_t sec) 193 { 194 struct timeval tv = { .tv_sec = sec }; 195 struct timeval *ptv = NULL; 196 fd_set left; 197 size_t i; 198 int ret; 199 200 FD_ZERO(&left); 201 for (i = 0; i < nr; i++) { 202 FD_SET(sk[i], &left); 203 if (is_writable) 204 is_writable[i] = false; 205 } 206 207 if (sec) 208 ptv = &tv; 209 210 do { 211 bool is_empty = true; 212 fd_set fds, efds; 213 int nfd = 0; 214 215 FD_ZERO(&fds); 216 FD_ZERO(&efds); 217 for (i = 0; i < nr; i++) { 218 if (!FD_ISSET(sk[i], &left)) 219 continue; 220 221 if (sk[i] > nfd) 222 nfd = sk[i]; 223 224 FD_SET(sk[i], &fds); 225 FD_SET(sk[i], &efds); 226 is_empty = false; 227 } 228 if (is_empty) 229 return -ENOENT; 230 231 errno = 0; 232 ret = select(nfd + 1, NULL, &fds, &efds, ptv); 233 if (ret < 0) 234 return -errno; 235 if (!ret) 236 return -ETIMEDOUT; 237 for (i = 0; i < nr; i++) { 238 if (FD_ISSET(sk[i], &fds)) { 239 if (is_writable) 240 is_writable[i] = true; 241 FD_CLR(sk[i], &left); 242 wait_for--; 243 continue; 244 } 245 if (FD_ISSET(sk[i], &efds)) { 246 FD_CLR(sk[i], &left); 247 wait_for--; 248 } 249 } 250 } while (wait_for > 0); 251 252 return 0; 253 } 254 255 static void test_client_active_rst(unsigned int port) 256 { 257 /* one in queue, another accept()ed */ 258 unsigned int wait_for = backlog + 2; 259 int i, sk[3], err; 260 bool is_writable[ARRAY_SIZE(sk)] = {false}; 261 unsigned int last = ARRAY_SIZE(sk) - 1; 262 263 for (i = 0; i < ARRAY_SIZE(sk); i++) { 264 sk[i] = socket(test_family, SOCK_STREAM, IPPROTO_TCP); 265 if (sk[i] < 0) 266 test_error("socket()"); 267 if (test_add_key(sk[i], DEFAULT_TEST_PASSWORD, 268 this_ip_dest, -1, 100, 100)) 269 test_error("setsockopt(TCP_AO_ADD_KEY)"); 270 } 271 272 synchronize_threads(); /* 1: MKT added */ 273 for (i = 0; i < last; i++) { 274 err = _test_connect_socket(sk[i], this_ip_dest, port, 275 (i == 0) ? TEST_TIMEOUT_SEC : -1); 276 277 if (err < 0) 278 test_error("failed to connect()"); 279 } 280 281 synchronize_threads(); /* 2: connection accept()ed, another queued */ 282 err = test_wait_fds(sk, last, is_writable, wait_for, TEST_TIMEOUT_SEC); 283 if (err < 0) 284 test_error("test_wait_fds(): %d", err); 285 286 synchronize_threads(); /* 3: close listen socket */ 287 if (test_client_verify(sk[0], 100, quota / 100, TEST_TIMEOUT_SEC)) 288 test_fail("Failed to send data on connected socket"); 289 else 290 test_ok("Verified established tcp connection"); 291 292 synchronize_threads(); /* 4: finishing up */ 293 err = _test_connect_socket(sk[last], this_ip_dest, port, -1); 294 if (err < 0) 295 test_error("failed to connect()"); 296 297 synchronize_threads(); /* 5: closed active sk */ 298 err = test_wait_fds(sk, ARRAY_SIZE(sk), NULL, 299 wait_for, TEST_TIMEOUT_SEC); 300 if (err < 0) 301 test_error("select(): %d", err); 302 303 for (i = 0; i < ARRAY_SIZE(sk); i++) { 304 socklen_t slen = sizeof(err); 305 306 if (getsockopt(sk[i], SOL_SOCKET, SO_ERROR, &err, &slen)) 307 test_error("getsockopt()"); 308 if (is_writable[i] && err != ECONNRESET) { 309 test_fail("sk[%d] = %d, err = %d, connection wasn't reset", 310 i, sk[i], err); 311 } else { 312 test_ok("sk[%d] = %d%s", i, sk[i], 313 is_writable[i] ? ", connection was reset" : ""); 314 } 315 } 316 synchronize_threads(); /* 6: counters checks */ 317 } 318 319 static void test_client_passive_rst(unsigned int port) 320 { 321 struct tcp_ao_counters ao1, ao2; 322 struct tcp_ao_repair ao_img; 323 struct tcp_sock_state img; 324 sockaddr_af saddr; 325 int sk, err; 326 socklen_t slen = sizeof(err); 327 328 sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); 329 if (sk < 0) 330 test_error("socket()"); 331 332 if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) 333 test_error("setsockopt(TCP_AO_ADD_KEY)"); 334 335 synchronize_threads(); /* 1: MKT added => connect() */ 336 if (test_connect_socket(sk, this_ip_dest, port) <= 0) 337 test_error("failed to connect()"); 338 339 synchronize_threads(); /* 2: accepted => send data */ 340 if (test_client_verify(sk, 100, quota / 100, TEST_TIMEOUT_SEC)) 341 test_fail("Failed to send data on connected socket"); 342 else 343 test_ok("Verified established tcp connection"); 344 345 synchronize_threads(); /* 3: chekpoint/restore the connection */ 346 test_enable_repair(sk); 347 test_sock_checkpoint(sk, &img, &saddr); 348 test_ao_checkpoint(sk, &ao_img); 349 test_kill_sk(sk); 350 351 img.out.seq += quota; 352 353 sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); 354 if (sk < 0) 355 test_error("socket()"); 356 357 test_enable_repair(sk); 358 test_sock_restore(sk, &img, &saddr, this_ip_dest, port); 359 if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100)) 360 test_error("setsockopt(TCP_AO_ADD_KEY)"); 361 test_ao_restore(sk, &ao_img); 362 363 if (test_get_tcp_ao_counters(sk, &ao1)) 364 test_error("test_get_tcp_ao_counters()"); 365 366 test_disable_repair(sk); 367 test_sock_state_free(&img); 368 369 synchronize_threads(); /* 4: terminate server + send more on client */ 370 if (test_client_verify(sk, 100, quota / 100, 2 * TEST_TIMEOUT_SEC)) 371 test_ok("client connection broken post-seq-adjust"); 372 else 373 test_fail("client connection still works post-seq-adjust"); 374 375 test_wait_for_exception(sk, TEST_TIMEOUT_SEC); 376 377 if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &err, &slen)) 378 test_error("getsockopt()"); 379 if (err != ECONNRESET && err != EPIPE) 380 test_fail("client connection was not reset: %d", err); 381 else 382 test_ok("client connection was reset"); 383 384 if (test_get_tcp_ao_counters(sk, &ao2)) 385 test_error("test_get_tcp_ao_counters()"); 386 387 synchronize_threads(); /* 5: verified => closed */ 388 close(sk); 389 test_tcp_ao_counters_cmp("client passive RST", &ao1, &ao2, TEST_CNT_GOOD); 390 } 391 392 static void *client_fn(void *arg) 393 { 394 struct netstat *ns_before, *ns_after; 395 unsigned int port = test_server_port; 396 397 ns_before = netstat_read(); 398 399 test_client_active_rst(port++); 400 test_client_passive_rst(port++); 401 402 ns_after = netstat_read(); 403 netstats_check(ns_before, ns_after, "client"); 404 netstat_free(ns_after); 405 netstat_free(ns_before); 406 407 synchronize_threads(); /* exit */ 408 return NULL; 409 } 410 411 int main(int argc, char *argv[]) 412 { 413 test_init(15, server_fn, client_fn); 414 return 0; 415 } 416