1 // SPDX-License-Identifier: GPL-2.0 2 #include <alloca.h> 3 #include <fcntl.h> 4 #include <inttypes.h> 5 #include <string.h> 6 #include "../../../../../include/linux/kernel.h" 7 #include "../../../../../include/linux/stringify.h" 8 #include "aolib.h" 9 10 const unsigned int test_server_port = 7010; 11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz) 12 { 13 int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); 14 long flags; 15 16 if (sk < 0) 17 test_error("socket()"); 18 19 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name, 20 strlen(veth_name) + 1); 21 if (err < 0) 22 test_error("setsockopt(SO_BINDTODEVICE)"); 23 24 if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0) 25 test_error("bind()"); 26 27 flags = fcntl(sk, F_GETFL); 28 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) 29 test_error("fcntl()"); 30 31 if (listen(sk, backlog)) 32 test_error("listen()"); 33 34 return sk; 35 } 36 37 int test_wait_fd(int sk, time_t sec, bool write) 38 { 39 struct timeval tv = { .tv_sec = sec }; 40 struct timeval *ptv = NULL; 41 fd_set fds, efds; 42 int ret; 43 socklen_t slen = sizeof(ret); 44 45 FD_ZERO(&fds); 46 FD_SET(sk, &fds); 47 FD_ZERO(&efds); 48 FD_SET(sk, &efds); 49 50 if (sec) 51 ptv = &tv; 52 53 errno = 0; 54 if (write) 55 ret = select(sk + 1, NULL, &fds, &efds, ptv); 56 else 57 ret = select(sk + 1, &fds, NULL, &efds, ptv); 58 if (ret < 0) 59 return -errno; 60 if (ret == 0) { 61 errno = ETIMEDOUT; 62 return -ETIMEDOUT; 63 } 64 65 if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen)) 66 return -errno; 67 if (ret) 68 return -ret; 69 return 0; 70 } 71 72 int __test_connect_socket(int sk, const char *device, 73 void *addr, size_t addr_sz, time_t timeout) 74 { 75 long flags; 76 int err; 77 78 if (device != NULL) { 79 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device, 80 strlen(device) + 1); 81 if (err < 0) 82 test_error("setsockopt(SO_BINDTODEVICE, %s)", device); 83 } 84 85 if (!timeout) { 86 err = connect(sk, addr, addr_sz); 87 if (err) { 88 err = -errno; 89 goto out; 90 } 91 return 0; 92 } 93 94 flags = fcntl(sk, F_GETFL); 95 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) 96 test_error("fcntl()"); 97 98 if (connect(sk, addr, addr_sz) < 0) { 99 if (errno != EINPROGRESS) { 100 err = -errno; 101 goto out; 102 } 103 if (timeout < 0) 104 return sk; 105 err = test_wait_fd(sk, timeout, 1); 106 if (err) 107 goto out; 108 } 109 return sk; 110 111 out: 112 close(sk); 113 return err; 114 } 115 116 int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix, 117 int vrf, const char *password) 118 { 119 size_t pwd_len = strlen(password); 120 struct tcp_md5sig md5sig = {}; 121 122 md5sig.tcpm_keylen = pwd_len; 123 memcpy(md5sig.tcpm_key, password, pwd_len); 124 md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX; 125 md5sig.tcpm_prefixlen = prefix; 126 if (vrf >= 0) { 127 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX; 128 md5sig.tcpm_ifindex = (uint8_t)vrf; 129 } 130 memcpy(&md5sig.tcpm_addr, addr, addr_sz); 131 132 errno = 0; 133 return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT, 134 &md5sig, sizeof(md5sig)); 135 } 136 137 138 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg, 139 void *addr, size_t addr_sz, bool set_current, bool set_rnext, 140 uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid, 141 uint8_t maclen, uint8_t keyflags, 142 uint8_t keylen, const char *key) 143 { 144 memset(ao, 0, sizeof(struct tcp_ao_add)); 145 146 ao->set_current = !!set_current; 147 ao->set_rnext = !!set_rnext; 148 ao->prefix = prefix; 149 ao->sndid = sndid; 150 ao->rcvid = rcvid; 151 ao->maclen = maclen; 152 ao->keyflags = keyflags; 153 ao->keylen = keylen; 154 ao->ifindex = vrf; 155 156 memcpy(&ao->addr, addr, addr_sz); 157 158 if (strlen(alg) > 64) 159 return -ENOBUFS; 160 strncpy(ao->alg_name, alg, 64); 161 162 memcpy(ao->key, key, 163 (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen); 164 return 0; 165 } 166 167 static int test_get_ao_keys_nr(int sk) 168 { 169 struct tcp_ao_getsockopt tmp = {}; 170 socklen_t tmp_sz = sizeof(tmp); 171 int ret; 172 173 tmp.nkeys = 1; 174 tmp.get_all = 1; 175 176 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); 177 if (ret) 178 return -errno; 179 return (int)tmp.nkeys; 180 } 181 182 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out, 183 void *addr, size_t addr_sz, uint8_t prefix, 184 uint8_t sndid, uint8_t rcvid) 185 { 186 struct tcp_ao_getsockopt tmp = {}; 187 socklen_t tmp_sz = sizeof(tmp); 188 int ret; 189 190 memcpy(&tmp.addr, addr, addr_sz); 191 tmp.prefix = prefix; 192 tmp.sndid = sndid; 193 tmp.rcvid = rcvid; 194 tmp.nkeys = 1; 195 196 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); 197 if (ret) 198 return ret; 199 if (tmp.nkeys != 1) 200 return -E2BIG; 201 *out = tmp; 202 return 0; 203 } 204 205 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out) 206 { 207 socklen_t sz = sizeof(*out); 208 209 out->reserved = 0; 210 out->reserved2 = 0; 211 if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz)) 212 return -errno; 213 if (sz != sizeof(*out)) 214 return -EMSGSIZE; 215 return 0; 216 } 217 218 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in) 219 { 220 socklen_t sz = sizeof(*in); 221 222 in->reserved = 0; 223 in->reserved2 = 0; 224 if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz)) 225 return -errno; 226 return 0; 227 } 228 229 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a, 230 const struct tcp_ao_getsockopt *b) 231 { 232 bool is_kdf_aes_128_cmac = false; 233 bool is_cmac_aes = false; 234 235 if (!strcmp("cmac(aes128)", a->alg_name)) { 236 is_kdf_aes_128_cmac = (a->keylen != 16); 237 is_cmac_aes = true; 238 } 239 240 #define __cmp_ao(member) \ 241 do { \ 242 if (b->member != a->member) { \ 243 test_fail("getsockopt(): " __stringify(member) " %u != %u", \ 244 b->member, a->member); \ 245 return -1; \ 246 } \ 247 } while(0) 248 __cmp_ao(sndid); 249 __cmp_ao(rcvid); 250 __cmp_ao(prefix); 251 __cmp_ao(keyflags); 252 __cmp_ao(ifindex); 253 if (a->maclen) { 254 __cmp_ao(maclen); 255 } else if (b->maclen != 12) { 256 test_fail("getsockopt(): expected default maclen 12, but it's %u", 257 b->maclen); 258 return -1; 259 } 260 if (!is_kdf_aes_128_cmac) { 261 __cmp_ao(keylen); 262 } else if (b->keylen != 16) { 263 test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u", 264 b->keylen); 265 return -1; 266 } 267 #undef __cmp_ao 268 if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) { 269 test_fail("getsockopt(): returned key is different `%s' != `%s'", 270 b->key, a->key); 271 return -1; 272 } 273 if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) { 274 test_fail("getsockopt(): returned address is different"); 275 return -1; 276 } 277 if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) { 278 test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name); 279 return -1; 280 } 281 if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) { 282 test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name); 283 return -1; 284 } 285 /* For a established key rotation test don't add a key with 286 * set_current = 1, as it's likely to change by peer's request; 287 * rather use setsockopt(TCP_AO_INFO) 288 */ 289 if (a->set_current != b->is_current) { 290 test_fail("getsockopt(): returned key is not Current_key"); 291 return -1; 292 } 293 if (a->set_rnext != b->is_rnext) { 294 test_fail("getsockopt(): returned key is not RNext_key"); 295 return -1; 296 } 297 298 return 0; 299 } 300 301 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a, 302 const struct tcp_ao_info_opt *b) 303 { 304 /* No check for ::current_key, as it may change by the peer */ 305 if (a->ao_required != b->ao_required) { 306 test_fail("getsockopt(): returned ao doesn't have ao_required"); 307 return -1; 308 } 309 if (a->accept_icmps != b->accept_icmps) { 310 test_fail("getsockopt(): returned ao doesn't accept ICMPs"); 311 return -1; 312 } 313 if (a->set_rnext && a->rnext != b->rnext) { 314 test_fail("getsockopt(): RNext KeyID has changed"); 315 return -1; 316 } 317 #define __cmp_cnt(member) \ 318 do { \ 319 if (b->member != a->member) { \ 320 test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \ 321 b->member, a->member); \ 322 return -1; \ 323 } \ 324 } while(0) 325 if (a->set_counters) { 326 __cmp_cnt(pkt_good); 327 __cmp_cnt(pkt_bad); 328 __cmp_cnt(pkt_key_not_found); 329 __cmp_cnt(pkt_ao_required); 330 __cmp_cnt(pkt_dropped_icmp); 331 } 332 #undef __cmp_cnt 333 return 0; 334 } 335 336 int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out) 337 { 338 struct tcp_ao_getsockopt *key_dump; 339 socklen_t key_dump_sz = sizeof(*key_dump); 340 struct tcp_ao_info_opt info = {}; 341 bool c1, c2, c3, c4, c5; 342 struct netstat *ns; 343 int err, nr_keys; 344 345 memset(out, 0, sizeof(*out)); 346 347 /* per-netns */ 348 ns = netstat_read(); 349 out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1); 350 out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2); 351 out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3); 352 out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4); 353 out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5); 354 netstat_free(ns); 355 if (c1 || c2 || c3 || c4 || c5) 356 return -EOPNOTSUPP; 357 358 err = test_get_ao_info(sk, &info); 359 if (err) 360 return err; 361 362 /* per-socket */ 363 out->ao_info_pkt_good = info.pkt_good; 364 out->ao_info_pkt_bad = info.pkt_bad; 365 out->ao_info_pkt_key_not_found = info.pkt_key_not_found; 366 out->ao_info_pkt_ao_required = info.pkt_ao_required; 367 out->ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp; 368 369 /* per-key */ 370 nr_keys = test_get_ao_keys_nr(sk); 371 if (nr_keys < 0) 372 return nr_keys; 373 if (nr_keys == 0) 374 test_error("test_get_ao_keys_nr() == 0"); 375 out->nr_keys = (size_t)nr_keys; 376 key_dump = calloc(nr_keys, key_dump_sz); 377 if (!key_dump) 378 return -errno; 379 380 key_dump[0].nkeys = nr_keys; 381 key_dump[0].get_all = 1; 382 key_dump[0].get_all = 1; 383 err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, 384 key_dump, &key_dump_sz); 385 if (err) { 386 free(key_dump); 387 return -errno; 388 } 389 390 out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0])); 391 if (!out->key_cnts) { 392 free(key_dump); 393 return -errno; 394 } 395 396 while (nr_keys--) { 397 out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid; 398 out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid; 399 out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good; 400 out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad; 401 } 402 free(key_dump); 403 404 return 0; 405 } 406 407 int __test_tcp_ao_counters_cmp(const char *tst_name, 408 struct tcp_ao_counters *before, 409 struct tcp_ao_counters *after, 410 test_cnt expected) 411 { 412 #define __cmp_ao(cnt, expecting_inc) \ 413 do { \ 414 if (before->cnt > after->cnt) { \ 415 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \ 416 tst_name ?: "", before->cnt, after->cnt); \ 417 return -1; \ 418 } \ 419 if ((before->cnt != after->cnt) != (expecting_inc)) { \ 420 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \ 421 tst_name ?: "", (expecting_inc) ? "" : "not ", \ 422 before->cnt, after->cnt); \ 423 return -1; \ 424 } \ 425 } while(0) 426 427 errno = 0; 428 /* per-netns */ 429 __cmp_ao(netns_ao_good, !!(expected & TEST_CNT_NS_GOOD)); 430 __cmp_ao(netns_ao_bad, !!(expected & TEST_CNT_NS_BAD)); 431 __cmp_ao(netns_ao_key_not_found, 432 !!(expected & TEST_CNT_NS_KEY_NOT_FOUND)); 433 __cmp_ao(netns_ao_required, !!(expected & TEST_CNT_NS_AO_REQUIRED)); 434 __cmp_ao(netns_ao_dropped_icmp, 435 !!(expected & TEST_CNT_NS_DROPPED_ICMP)); 436 /* per-socket */ 437 __cmp_ao(ao_info_pkt_good, !!(expected & TEST_CNT_SOCK_GOOD)); 438 __cmp_ao(ao_info_pkt_bad, !!(expected & TEST_CNT_SOCK_BAD)); 439 __cmp_ao(ao_info_pkt_key_not_found, 440 !!(expected & TEST_CNT_SOCK_KEY_NOT_FOUND)); 441 __cmp_ao(ao_info_pkt_ao_required, !!(expected & TEST_CNT_SOCK_AO_REQUIRED)); 442 __cmp_ao(ao_info_pkt_dropped_icmp, 443 !!(expected & TEST_CNT_SOCK_DROPPED_ICMP)); 444 return 0; 445 #undef __cmp_ao 446 } 447 448 int test_tcp_ao_key_counters_cmp(const char *tst_name, 449 struct tcp_ao_counters *before, 450 struct tcp_ao_counters *after, 451 test_cnt expected, 452 int sndid, int rcvid) 453 { 454 size_t i; 455 #define __cmp_ao(i, cnt, expecting_inc) \ 456 do { \ 457 if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) { \ 458 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \ 459 tst_name ?: "", before->key_cnts[i].cnt, \ 460 after->key_cnts[i].cnt, \ 461 before->key_cnts[i].sndid, \ 462 before->key_cnts[i].rcvid); \ 463 return -1; \ 464 } \ 465 if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) { \ 466 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \ 467 tst_name ?: "", (expecting_inc) ? "" : "not ",\ 468 before->key_cnts[i].cnt, \ 469 after->key_cnts[i].cnt, \ 470 before->key_cnts[i].sndid, \ 471 before->key_cnts[i].rcvid); \ 472 return -1; \ 473 } \ 474 } while(0) 475 476 if (before->nr_keys != after->nr_keys) { 477 test_fail("%s: Keys changed on the socket %zu != %zu", 478 tst_name, before->nr_keys, after->nr_keys); 479 return -1; 480 } 481 482 /* per-key */ 483 i = before->nr_keys; 484 while (i--) { 485 if (sndid >= 0 && before->key_cnts[i].sndid != sndid) 486 continue; 487 if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid) 488 continue; 489 __cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD)); 490 __cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD)); 491 } 492 return 0; 493 #undef __cmp_ao 494 } 495 496 void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts) 497 { 498 free(cnts->key_cnts); 499 } 500 501 #define TEST_BUF_SIZE 4096 502 ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec) 503 { 504 ssize_t total = 0; 505 506 do { 507 char buf[TEST_BUF_SIZE]; 508 ssize_t bytes, sent; 509 int ret; 510 511 ret = test_wait_fd(sk, timeout_sec, 0); 512 if (ret) 513 return ret; 514 515 bytes = recv(sk, buf, sizeof(buf), 0); 516 517 if (bytes < 0) 518 test_error("recv(): %zd", bytes); 519 if (bytes == 0) 520 break; 521 522 ret = test_wait_fd(sk, timeout_sec, 1); 523 if (ret) 524 return ret; 525 526 sent = send(sk, buf, bytes, 0); 527 if (sent == 0) 528 break; 529 if (sent != bytes) 530 test_error("send()"); 531 total += bytes; 532 } while (!quota || total < quota); 533 534 return total; 535 } 536 537 ssize_t test_client_loop(int sk, char *buf, size_t buf_sz, 538 const size_t msg_len, time_t timeout_sec) 539 { 540 char msg[msg_len]; 541 int nodelay = 1; 542 size_t i; 543 544 if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay))) 545 test_error("setsockopt(TCP_NODELAY)"); 546 547 for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) { 548 size_t sent, bytes = min(msg_len, buf_sz - i); 549 int ret; 550 551 ret = test_wait_fd(sk, timeout_sec, 1); 552 if (ret) 553 return ret; 554 555 sent = send(sk, buf + i, bytes, 0); 556 if (sent == 0) 557 break; 558 if (sent != bytes) 559 test_error("send()"); 560 561 bytes = 0; 562 do { 563 ssize_t got; 564 565 ret = test_wait_fd(sk, timeout_sec, 0); 566 if (ret) 567 return ret; 568 569 got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0); 570 if (got <= 0) 571 return i; 572 bytes += got; 573 } while (bytes < sent); 574 if (bytes > sent) 575 test_error("recv(): %zd > %zd", bytes, sent); 576 if (memcmp(buf + i, msg, bytes) != 0) { 577 test_fail("received message differs"); 578 return -1; 579 } 580 } 581 return i; 582 } 583 584 int test_client_verify(int sk, const size_t msg_len, const size_t nr, 585 time_t timeout_sec) 586 { 587 size_t buf_sz = msg_len * nr; 588 char *buf = alloca(buf_sz); 589 ssize_t ret; 590 591 randomize_buffer(buf, buf_sz); 592 ret = test_client_loop(sk, buf, buf_sz, msg_len, timeout_sec); 593 if (ret < 0) 594 return (int)ret; 595 return ret != buf_sz ? -1 : 0; 596 } 597