1 // SPDX-License-Identifier: GPL-2.0 2 #include <alloca.h> 3 #include <fcntl.h> 4 #include <inttypes.h> 5 #include <string.h> 6 #include "../../../../../include/linux/kernel.h" 7 #include "../../../../../include/linux/stringify.h" 8 #include "aolib.h" 9 10 const unsigned int test_server_port = 7010; 11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz) 12 { 13 int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); 14 long flags; 15 16 if (sk < 0) 17 test_error("socket()"); 18 19 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name, 20 strlen(veth_name) + 1); 21 if (err < 0) 22 test_error("setsockopt(SO_BINDTODEVICE)"); 23 24 if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0) 25 test_error("bind()"); 26 27 flags = fcntl(sk, F_GETFL); 28 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) 29 test_error("fcntl()"); 30 31 if (listen(sk, backlog)) 32 test_error("listen()"); 33 34 return sk; 35 } 36 37 int test_wait_fd(int sk, time_t sec, bool write) 38 { 39 struct timeval tv = { .tv_sec = sec }; 40 struct timeval *ptv = NULL; 41 fd_set fds, efds; 42 int ret; 43 socklen_t slen = sizeof(ret); 44 45 FD_ZERO(&fds); 46 FD_SET(sk, &fds); 47 FD_ZERO(&efds); 48 FD_SET(sk, &efds); 49 50 if (sec) 51 ptv = &tv; 52 53 errno = 0; 54 if (write) 55 ret = select(sk + 1, NULL, &fds, &efds, ptv); 56 else 57 ret = select(sk + 1, &fds, NULL, &efds, ptv); 58 if (ret < 0) 59 return -errno; 60 if (ret == 0) { 61 errno = ETIMEDOUT; 62 return -ETIMEDOUT; 63 } 64 65 if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen) || ret) 66 return -ret; 67 return 0; 68 } 69 70 int __test_connect_socket(int sk, const char *device, 71 void *addr, size_t addr_sz, time_t timeout) 72 { 73 long flags; 74 int err; 75 76 if (device != NULL) { 77 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device, 78 strlen(device) + 1); 79 if (err < 0) 80 test_error("setsockopt(SO_BINDTODEVICE, %s)", device); 81 } 82 83 if (!timeout) { 84 err = connect(sk, addr, addr_sz); 85 if (err) { 86 err = -errno; 87 goto out; 88 } 89 return 0; 90 } 91 92 flags = fcntl(sk, F_GETFL); 93 if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) 94 test_error("fcntl()"); 95 96 if (connect(sk, addr, addr_sz) < 0) { 97 if (errno != EINPROGRESS) { 98 err = -errno; 99 goto out; 100 } 101 if (timeout < 0) 102 return sk; 103 err = test_wait_fd(sk, timeout, 1); 104 if (err) 105 goto out; 106 } 107 return sk; 108 109 out: 110 close(sk); 111 return err; 112 } 113 114 int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix, 115 int vrf, const char *password) 116 { 117 size_t pwd_len = strlen(password); 118 struct tcp_md5sig md5sig = {}; 119 120 md5sig.tcpm_keylen = pwd_len; 121 memcpy(md5sig.tcpm_key, password, pwd_len); 122 md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX; 123 md5sig.tcpm_prefixlen = prefix; 124 if (vrf >= 0) { 125 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX; 126 md5sig.tcpm_ifindex = (uint8_t)vrf; 127 } 128 memcpy(&md5sig.tcpm_addr, addr, addr_sz); 129 130 errno = 0; 131 return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT, 132 &md5sig, sizeof(md5sig)); 133 } 134 135 136 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg, 137 void *addr, size_t addr_sz, bool set_current, bool set_rnext, 138 uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid, 139 uint8_t maclen, uint8_t keyflags, 140 uint8_t keylen, const char *key) 141 { 142 memset(ao, 0, sizeof(struct tcp_ao_add)); 143 144 ao->set_current = !!set_current; 145 ao->set_rnext = !!set_rnext; 146 ao->prefix = prefix; 147 ao->sndid = sndid; 148 ao->rcvid = rcvid; 149 ao->maclen = maclen; 150 ao->keyflags = keyflags; 151 ao->keylen = keylen; 152 ao->ifindex = vrf; 153 154 memcpy(&ao->addr, addr, addr_sz); 155 156 if (strlen(alg) > 64) 157 return -ENOBUFS; 158 strncpy(ao->alg_name, alg, 64); 159 160 memcpy(ao->key, key, 161 (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen); 162 return 0; 163 } 164 165 static int test_get_ao_keys_nr(int sk) 166 { 167 struct tcp_ao_getsockopt tmp = {}; 168 socklen_t tmp_sz = sizeof(tmp); 169 int ret; 170 171 tmp.nkeys = 1; 172 tmp.get_all = 1; 173 174 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); 175 if (ret) 176 return -errno; 177 return (int)tmp.nkeys; 178 } 179 180 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out, 181 void *addr, size_t addr_sz, uint8_t prefix, 182 uint8_t sndid, uint8_t rcvid) 183 { 184 struct tcp_ao_getsockopt tmp = {}; 185 socklen_t tmp_sz = sizeof(tmp); 186 int ret; 187 188 memcpy(&tmp.addr, addr, addr_sz); 189 tmp.prefix = prefix; 190 tmp.sndid = sndid; 191 tmp.rcvid = rcvid; 192 tmp.nkeys = 1; 193 194 ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); 195 if (ret) 196 return ret; 197 if (tmp.nkeys != 1) 198 return -E2BIG; 199 *out = tmp; 200 return 0; 201 } 202 203 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out) 204 { 205 socklen_t sz = sizeof(*out); 206 207 out->reserved = 0; 208 out->reserved2 = 0; 209 if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz)) 210 return -errno; 211 if (sz != sizeof(*out)) 212 return -EMSGSIZE; 213 return 0; 214 } 215 216 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in) 217 { 218 socklen_t sz = sizeof(*in); 219 220 in->reserved = 0; 221 in->reserved2 = 0; 222 if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz)) 223 return -errno; 224 return 0; 225 } 226 227 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a, 228 const struct tcp_ao_getsockopt *b) 229 { 230 bool is_kdf_aes_128_cmac = false; 231 bool is_cmac_aes = false; 232 233 if (!strcmp("cmac(aes128)", a->alg_name)) { 234 is_kdf_aes_128_cmac = (a->keylen != 16); 235 is_cmac_aes = true; 236 } 237 238 #define __cmp_ao(member) \ 239 do { \ 240 if (b->member != a->member) { \ 241 test_fail("getsockopt(): " __stringify(member) " %u != %u", \ 242 b->member, a->member); \ 243 return -1; \ 244 } \ 245 } while(0) 246 __cmp_ao(sndid); 247 __cmp_ao(rcvid); 248 __cmp_ao(prefix); 249 __cmp_ao(keyflags); 250 __cmp_ao(ifindex); 251 if (a->maclen) { 252 __cmp_ao(maclen); 253 } else if (b->maclen != 12) { 254 test_fail("getsockopt(): expected default maclen 12, but it's %u", 255 b->maclen); 256 return -1; 257 } 258 if (!is_kdf_aes_128_cmac) { 259 __cmp_ao(keylen); 260 } else if (b->keylen != 16) { 261 test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u", 262 b->keylen); 263 return -1; 264 } 265 #undef __cmp_ao 266 if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) { 267 test_fail("getsockopt(): returned key is different `%s' != `%s'", 268 b->key, a->key); 269 return -1; 270 } 271 if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) { 272 test_fail("getsockopt(): returned address is different"); 273 return -1; 274 } 275 if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) { 276 test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name); 277 return -1; 278 } 279 if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) { 280 test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name); 281 return -1; 282 } 283 /* For a established key rotation test don't add a key with 284 * set_current = 1, as it's likely to change by peer's request; 285 * rather use setsockopt(TCP_AO_INFO) 286 */ 287 if (a->set_current != b->is_current) { 288 test_fail("getsockopt(): returned key is not Current_key"); 289 return -1; 290 } 291 if (a->set_rnext != b->is_rnext) { 292 test_fail("getsockopt(): returned key is not RNext_key"); 293 return -1; 294 } 295 296 return 0; 297 } 298 299 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a, 300 const struct tcp_ao_info_opt *b) 301 { 302 /* No check for ::current_key, as it may change by the peer */ 303 if (a->ao_required != b->ao_required) { 304 test_fail("getsockopt(): returned ao doesn't have ao_required"); 305 return -1; 306 } 307 if (a->accept_icmps != b->accept_icmps) { 308 test_fail("getsockopt(): returned ao doesn't accept ICMPs"); 309 return -1; 310 } 311 if (a->set_rnext && a->rnext != b->rnext) { 312 test_fail("getsockopt(): RNext KeyID has changed"); 313 return -1; 314 } 315 #define __cmp_cnt(member) \ 316 do { \ 317 if (b->member != a->member) { \ 318 test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \ 319 b->member, a->member); \ 320 return -1; \ 321 } \ 322 } while(0) 323 if (a->set_counters) { 324 __cmp_cnt(pkt_good); 325 __cmp_cnt(pkt_bad); 326 __cmp_cnt(pkt_key_not_found); 327 __cmp_cnt(pkt_ao_required); 328 __cmp_cnt(pkt_dropped_icmp); 329 } 330 #undef __cmp_cnt 331 return 0; 332 } 333 334 int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out) 335 { 336 struct tcp_ao_getsockopt *key_dump; 337 socklen_t key_dump_sz = sizeof(*key_dump); 338 struct tcp_ao_info_opt info = {}; 339 bool c1, c2, c3, c4, c5; 340 struct netstat *ns; 341 int err, nr_keys; 342 343 memset(out, 0, sizeof(*out)); 344 345 /* per-netns */ 346 ns = netstat_read(); 347 out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1); 348 out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2); 349 out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3); 350 out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4); 351 out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5); 352 netstat_free(ns); 353 if (c1 || c2 || c3 || c4 || c5) 354 return -EOPNOTSUPP; 355 356 err = test_get_ao_info(sk, &info); 357 if (err) 358 return err; 359 360 /* per-socket */ 361 out->ao_info_pkt_good = info.pkt_good; 362 out->ao_info_pkt_bad = info.pkt_bad; 363 out->ao_info_pkt_key_not_found = info.pkt_key_not_found; 364 out->ao_info_pkt_ao_required = info.pkt_ao_required; 365 out->ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp; 366 367 /* per-key */ 368 nr_keys = test_get_ao_keys_nr(sk); 369 if (nr_keys < 0) 370 return nr_keys; 371 if (nr_keys == 0) 372 test_error("test_get_ao_keys_nr() == 0"); 373 out->nr_keys = (size_t)nr_keys; 374 key_dump = calloc(nr_keys, key_dump_sz); 375 if (!key_dump) 376 return -errno; 377 378 key_dump[0].nkeys = nr_keys; 379 key_dump[0].get_all = 1; 380 key_dump[0].get_all = 1; 381 err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, 382 key_dump, &key_dump_sz); 383 if (err) { 384 free(key_dump); 385 return -errno; 386 } 387 388 out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0])); 389 if (!out->key_cnts) { 390 free(key_dump); 391 return -errno; 392 } 393 394 while (nr_keys--) { 395 out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid; 396 out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid; 397 out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good; 398 out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad; 399 } 400 free(key_dump); 401 402 return 0; 403 } 404 405 int __test_tcp_ao_counters_cmp(const char *tst_name, 406 struct tcp_ao_counters *before, 407 struct tcp_ao_counters *after, 408 test_cnt expected) 409 { 410 #define __cmp_ao(cnt, expecting_inc) \ 411 do { \ 412 if (before->cnt > after->cnt) { \ 413 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \ 414 tst_name ?: "", before->cnt, after->cnt); \ 415 return -1; \ 416 } \ 417 if ((before->cnt != after->cnt) != (expecting_inc)) { \ 418 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \ 419 tst_name ?: "", (expecting_inc) ? "" : "not ", \ 420 before->cnt, after->cnt); \ 421 return -1; \ 422 } \ 423 } while(0) 424 425 errno = 0; 426 /* per-netns */ 427 __cmp_ao(netns_ao_good, !!(expected & TEST_CNT_NS_GOOD)); 428 __cmp_ao(netns_ao_bad, !!(expected & TEST_CNT_NS_BAD)); 429 __cmp_ao(netns_ao_key_not_found, 430 !!(expected & TEST_CNT_NS_KEY_NOT_FOUND)); 431 __cmp_ao(netns_ao_required, !!(expected & TEST_CNT_NS_AO_REQUIRED)); 432 __cmp_ao(netns_ao_dropped_icmp, 433 !!(expected & TEST_CNT_NS_DROPPED_ICMP)); 434 /* per-socket */ 435 __cmp_ao(ao_info_pkt_good, !!(expected & TEST_CNT_SOCK_GOOD)); 436 __cmp_ao(ao_info_pkt_bad, !!(expected & TEST_CNT_SOCK_BAD)); 437 __cmp_ao(ao_info_pkt_key_not_found, 438 !!(expected & TEST_CNT_SOCK_KEY_NOT_FOUND)); 439 __cmp_ao(ao_info_pkt_ao_required, !!(expected & TEST_CNT_SOCK_AO_REQUIRED)); 440 __cmp_ao(ao_info_pkt_dropped_icmp, 441 !!(expected & TEST_CNT_SOCK_DROPPED_ICMP)); 442 return 0; 443 #undef __cmp_ao 444 } 445 446 int test_tcp_ao_key_counters_cmp(const char *tst_name, 447 struct tcp_ao_counters *before, 448 struct tcp_ao_counters *after, 449 test_cnt expected, 450 int sndid, int rcvid) 451 { 452 size_t i; 453 #define __cmp_ao(i, cnt, expecting_inc) \ 454 do { \ 455 if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) { \ 456 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \ 457 tst_name ?: "", before->key_cnts[i].cnt, \ 458 after->key_cnts[i].cnt, \ 459 before->key_cnts[i].sndid, \ 460 before->key_cnts[i].rcvid); \ 461 return -1; \ 462 } \ 463 if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) { \ 464 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \ 465 tst_name ?: "", (expecting_inc) ? "" : "not ",\ 466 before->key_cnts[i].cnt, \ 467 after->key_cnts[i].cnt, \ 468 before->key_cnts[i].sndid, \ 469 before->key_cnts[i].rcvid); \ 470 return -1; \ 471 } \ 472 } while(0) 473 474 if (before->nr_keys != after->nr_keys) { 475 test_fail("%s: Keys changed on the socket %zu != %zu", 476 tst_name, before->nr_keys, after->nr_keys); 477 return -1; 478 } 479 480 /* per-key */ 481 i = before->nr_keys; 482 while (i--) { 483 if (sndid >= 0 && before->key_cnts[i].sndid != sndid) 484 continue; 485 if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid) 486 continue; 487 __cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD)); 488 __cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD)); 489 } 490 return 0; 491 #undef __cmp_ao 492 } 493 494 void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts) 495 { 496 free(cnts->key_cnts); 497 } 498 499 #define TEST_BUF_SIZE 4096 500 ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec) 501 { 502 ssize_t total = 0; 503 504 do { 505 char buf[TEST_BUF_SIZE]; 506 ssize_t bytes, sent; 507 int ret; 508 509 ret = test_wait_fd(sk, timeout_sec, 0); 510 if (ret) 511 return ret; 512 513 bytes = recv(sk, buf, sizeof(buf), 0); 514 515 if (bytes < 0) 516 test_error("recv(): %zd", bytes); 517 if (bytes == 0) 518 break; 519 520 ret = test_wait_fd(sk, timeout_sec, 1); 521 if (ret) 522 return ret; 523 524 sent = send(sk, buf, bytes, 0); 525 if (sent == 0) 526 break; 527 if (sent != bytes) 528 test_error("send()"); 529 total += bytes; 530 } while (!quota || total < quota); 531 532 return total; 533 } 534 535 ssize_t test_client_loop(int sk, char *buf, size_t buf_sz, 536 const size_t msg_len, time_t timeout_sec) 537 { 538 char msg[msg_len]; 539 int nodelay = 1; 540 size_t i; 541 542 if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay))) 543 test_error("setsockopt(TCP_NODELAY)"); 544 545 for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) { 546 size_t sent, bytes = min(msg_len, buf_sz - i); 547 int ret; 548 549 ret = test_wait_fd(sk, timeout_sec, 1); 550 if (ret) 551 return ret; 552 553 sent = send(sk, buf + i, bytes, 0); 554 if (sent == 0) 555 break; 556 if (sent != bytes) 557 test_error("send()"); 558 559 bytes = 0; 560 do { 561 ssize_t got; 562 563 ret = test_wait_fd(sk, timeout_sec, 0); 564 if (ret) 565 return ret; 566 567 got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0); 568 if (got <= 0) 569 return i; 570 bytes += got; 571 } while (bytes < sent); 572 if (bytes > sent) 573 test_error("recv(): %zd > %zd", bytes, sent); 574 if (memcmp(buf + i, msg, bytes) != 0) { 575 test_fail("received message differs"); 576 return -1; 577 } 578 } 579 return i; 580 } 581 582 int test_client_verify(int sk, const size_t msg_len, const size_t nr, 583 time_t timeout_sec) 584 { 585 size_t buf_sz = msg_len * nr; 586 char *buf = alloca(buf_sz); 587 588 randomize_buffer(buf, buf_sz); 589 if (test_client_loop(sk, buf, buf_sz, msg_len, timeout_sec) != buf_sz) 590 return -1; 591 return 0; 592 } 593