1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * vsock test utilities 4 * 5 * Copyright (C) 2017 Red Hat, Inc. 6 * 7 * Author: Stefan Hajnoczi <stefanha@redhat.com> 8 */ 9 10 #include <errno.h> 11 #include <stdio.h> 12 #include <stdint.h> 13 #include <stdlib.h> 14 #include <string.h> 15 #include <signal.h> 16 #include <unistd.h> 17 #include <assert.h> 18 #include <sys/epoll.h> 19 #include <sys/mman.h> 20 21 #include "timeout.h" 22 #include "control.h" 23 #include "util.h" 24 25 /* Install signal handlers */ 26 void init_signals(void) 27 { 28 struct sigaction act = { 29 .sa_handler = sigalrm, 30 }; 31 32 sigaction(SIGALRM, &act, NULL); 33 signal(SIGPIPE, SIG_IGN); 34 } 35 36 static unsigned int parse_uint(const char *str, const char *err_str) 37 { 38 char *endptr = NULL; 39 unsigned long n; 40 41 errno = 0; 42 n = strtoul(str, &endptr, 10); 43 if (errno || *endptr != '\0') { 44 fprintf(stderr, "malformed %s \"%s\"\n", err_str, str); 45 exit(EXIT_FAILURE); 46 } 47 return n; 48 } 49 50 /* Parse a CID in string representation */ 51 unsigned int parse_cid(const char *str) 52 { 53 return parse_uint(str, "CID"); 54 } 55 56 /* Parse a port in string representation */ 57 unsigned int parse_port(const char *str) 58 { 59 return parse_uint(str, "port"); 60 } 61 62 /* Wait for the remote to close the connection */ 63 void vsock_wait_remote_close(int fd) 64 { 65 struct epoll_event ev; 66 int epollfd, nfds; 67 68 epollfd = epoll_create1(0); 69 if (epollfd == -1) { 70 perror("epoll_create1"); 71 exit(EXIT_FAILURE); 72 } 73 74 ev.events = EPOLLRDHUP | EPOLLHUP; 75 ev.data.fd = fd; 76 if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) { 77 perror("epoll_ctl"); 78 exit(EXIT_FAILURE); 79 } 80 81 nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000); 82 if (nfds == -1) { 83 perror("epoll_wait"); 84 exit(EXIT_FAILURE); 85 } 86 87 if (nfds == 0) { 88 fprintf(stderr, "epoll_wait timed out\n"); 89 exit(EXIT_FAILURE); 90 } 91 92 assert(nfds == 1); 93 assert(ev.events & (EPOLLRDHUP | EPOLLHUP)); 94 assert(ev.data.fd == fd); 95 96 close(epollfd); 97 } 98 99 /* Create socket <type>, bind to <cid, port> and return the file descriptor. */ 100 int vsock_bind(unsigned int cid, unsigned int port, int type) 101 { 102 struct sockaddr_vm sa = { 103 .svm_family = AF_VSOCK, 104 .svm_cid = cid, 105 .svm_port = port, 106 }; 107 int fd; 108 109 fd = socket(AF_VSOCK, type, 0); 110 if (fd < 0) { 111 perror("socket"); 112 exit(EXIT_FAILURE); 113 } 114 115 if (bind(fd, (struct sockaddr *)&sa, sizeof(sa))) { 116 perror("bind"); 117 exit(EXIT_FAILURE); 118 } 119 120 return fd; 121 } 122 123 int vsock_connect_fd(int fd, unsigned int cid, unsigned int port) 124 { 125 struct sockaddr_vm sa = { 126 .svm_family = AF_VSOCK, 127 .svm_cid = cid, 128 .svm_port = port, 129 }; 130 int ret; 131 132 timeout_begin(TIMEOUT); 133 do { 134 ret = connect(fd, (struct sockaddr *)&sa, sizeof(sa)); 135 timeout_check("connect"); 136 } while (ret < 0 && errno == EINTR); 137 timeout_end(); 138 139 return ret; 140 } 141 142 /* Bind to <bind_port>, connect to <cid, port> and return the file descriptor. */ 143 int vsock_bind_connect(unsigned int cid, unsigned int port, unsigned int bind_port, int type) 144 { 145 int client_fd; 146 147 client_fd = vsock_bind(VMADDR_CID_ANY, bind_port, type); 148 149 if (vsock_connect_fd(client_fd, cid, port)) { 150 perror("connect"); 151 exit(EXIT_FAILURE); 152 } 153 154 return client_fd; 155 } 156 157 /* Connect to <cid, port> and return the file descriptor. */ 158 int vsock_connect(unsigned int cid, unsigned int port, int type) 159 { 160 int fd; 161 162 control_expectln("LISTENING"); 163 164 fd = socket(AF_VSOCK, type, 0); 165 if (fd < 0) { 166 perror("socket"); 167 exit(EXIT_FAILURE); 168 } 169 170 if (vsock_connect_fd(fd, cid, port)) { 171 int old_errno = errno; 172 173 close(fd); 174 fd = -1; 175 errno = old_errno; 176 } 177 178 return fd; 179 } 180 181 int vsock_stream_connect(unsigned int cid, unsigned int port) 182 { 183 return vsock_connect(cid, port, SOCK_STREAM); 184 } 185 186 int vsock_seqpacket_connect(unsigned int cid, unsigned int port) 187 { 188 return vsock_connect(cid, port, SOCK_SEQPACKET); 189 } 190 191 /* Listen on <cid, port> and return the file descriptor. */ 192 static int vsock_listen(unsigned int cid, unsigned int port, int type) 193 { 194 int fd; 195 196 fd = vsock_bind(cid, port, type); 197 198 if (listen(fd, 1) < 0) { 199 perror("listen"); 200 exit(EXIT_FAILURE); 201 } 202 203 return fd; 204 } 205 206 /* Listen on <cid, port> and return the first incoming connection. The remote 207 * address is stored to clientaddrp. clientaddrp may be NULL. 208 */ 209 int vsock_accept(unsigned int cid, unsigned int port, 210 struct sockaddr_vm *clientaddrp, int type) 211 { 212 union { 213 struct sockaddr sa; 214 struct sockaddr_vm svm; 215 } clientaddr; 216 socklen_t clientaddr_len = sizeof(clientaddr.svm); 217 int fd, client_fd, old_errno; 218 219 fd = vsock_listen(cid, port, type); 220 221 control_writeln("LISTENING"); 222 223 timeout_begin(TIMEOUT); 224 do { 225 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len); 226 timeout_check("accept"); 227 } while (client_fd < 0 && errno == EINTR); 228 timeout_end(); 229 230 old_errno = errno; 231 close(fd); 232 errno = old_errno; 233 234 if (client_fd < 0) 235 return client_fd; 236 237 if (clientaddr_len != sizeof(clientaddr.svm)) { 238 fprintf(stderr, "unexpected addrlen from accept(2), %zu\n", 239 (size_t)clientaddr_len); 240 exit(EXIT_FAILURE); 241 } 242 if (clientaddr.sa.sa_family != AF_VSOCK) { 243 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n", 244 clientaddr.sa.sa_family); 245 exit(EXIT_FAILURE); 246 } 247 248 if (clientaddrp) 249 *clientaddrp = clientaddr.svm; 250 return client_fd; 251 } 252 253 int vsock_stream_accept(unsigned int cid, unsigned int port, 254 struct sockaddr_vm *clientaddrp) 255 { 256 return vsock_accept(cid, port, clientaddrp, SOCK_STREAM); 257 } 258 259 int vsock_stream_listen(unsigned int cid, unsigned int port) 260 { 261 return vsock_listen(cid, port, SOCK_STREAM); 262 } 263 264 int vsock_seqpacket_accept(unsigned int cid, unsigned int port, 265 struct sockaddr_vm *clientaddrp) 266 { 267 return vsock_accept(cid, port, clientaddrp, SOCK_SEQPACKET); 268 } 269 270 /* Transmit bytes from a buffer and check the return value. 271 * 272 * expected_ret: 273 * <0 Negative errno (for testing errors) 274 * 0 End-of-file 275 * >0 Success (bytes successfully written) 276 */ 277 void send_buf(int fd, const void *buf, size_t len, int flags, 278 ssize_t expected_ret) 279 { 280 ssize_t nwritten = 0; 281 ssize_t ret; 282 283 timeout_begin(TIMEOUT); 284 do { 285 ret = send(fd, buf + nwritten, len - nwritten, flags); 286 timeout_check("send"); 287 288 if (ret == 0 || (ret < 0 && errno != EINTR)) 289 break; 290 291 nwritten += ret; 292 } while (nwritten < len); 293 timeout_end(); 294 295 if (expected_ret < 0) { 296 if (ret != -1) { 297 fprintf(stderr, "bogus send(2) return value %zd (expected %zd)\n", 298 ret, expected_ret); 299 exit(EXIT_FAILURE); 300 } 301 if (errno != -expected_ret) { 302 perror("send"); 303 exit(EXIT_FAILURE); 304 } 305 return; 306 } 307 308 if (ret < 0) { 309 perror("send"); 310 exit(EXIT_FAILURE); 311 } 312 313 if (nwritten != expected_ret) { 314 if (ret == 0) 315 fprintf(stderr, "unexpected EOF while sending bytes\n"); 316 317 fprintf(stderr, "bogus send(2) bytes written %zd (expected %zd)\n", 318 nwritten, expected_ret); 319 exit(EXIT_FAILURE); 320 } 321 } 322 323 /* Receive bytes in a buffer and check the return value. 324 * 325 * expected_ret: 326 * <0 Negative errno (for testing errors) 327 * 0 End-of-file 328 * >0 Success (bytes successfully read) 329 */ 330 void recv_buf(int fd, void *buf, size_t len, int flags, ssize_t expected_ret) 331 { 332 ssize_t nread = 0; 333 ssize_t ret; 334 335 timeout_begin(TIMEOUT); 336 do { 337 ret = recv(fd, buf + nread, len - nread, flags); 338 timeout_check("recv"); 339 340 if (ret == 0 || (ret < 0 && errno != EINTR)) 341 break; 342 343 nread += ret; 344 } while (nread < len); 345 timeout_end(); 346 347 if (expected_ret < 0) { 348 if (ret != -1) { 349 fprintf(stderr, "bogus recv(2) return value %zd (expected %zd)\n", 350 ret, expected_ret); 351 exit(EXIT_FAILURE); 352 } 353 if (errno != -expected_ret) { 354 perror("recv"); 355 exit(EXIT_FAILURE); 356 } 357 return; 358 } 359 360 if (ret < 0) { 361 perror("recv"); 362 exit(EXIT_FAILURE); 363 } 364 365 if (nread != expected_ret) { 366 if (ret == 0) 367 fprintf(stderr, "unexpected EOF while receiving bytes\n"); 368 369 fprintf(stderr, "bogus recv(2) bytes read %zd (expected %zd)\n", 370 nread, expected_ret); 371 exit(EXIT_FAILURE); 372 } 373 } 374 375 /* Transmit one byte and check the return value. 376 * 377 * expected_ret: 378 * <0 Negative errno (for testing errors) 379 * 0 End-of-file 380 * 1 Success 381 */ 382 void send_byte(int fd, int expected_ret, int flags) 383 { 384 static const uint8_t byte = 'A'; 385 386 send_buf(fd, &byte, sizeof(byte), flags, expected_ret); 387 } 388 389 /* Receive one byte and check the return value. 390 * 391 * expected_ret: 392 * <0 Negative errno (for testing errors) 393 * 0 End-of-file 394 * 1 Success 395 */ 396 void recv_byte(int fd, int expected_ret, int flags) 397 { 398 uint8_t byte; 399 400 recv_buf(fd, &byte, sizeof(byte), flags, expected_ret); 401 402 if (byte != 'A') { 403 fprintf(stderr, "unexpected byte read 0x%02x\n", byte); 404 exit(EXIT_FAILURE); 405 } 406 } 407 408 /* Run test cases. The program terminates if a failure occurs. */ 409 void run_tests(const struct test_case *test_cases, 410 const struct test_opts *opts) 411 { 412 int i; 413 414 for (i = 0; test_cases[i].name; i++) { 415 void (*run)(const struct test_opts *opts); 416 char *line; 417 418 printf("%d - %s...", i, test_cases[i].name); 419 fflush(stdout); 420 421 /* Full barrier before executing the next test. This 422 * ensures that client and server are executing the 423 * same test case. In particular, it means whoever is 424 * faster will not see the peer still executing the 425 * last test. This is important because port numbers 426 * can be used by multiple test cases. 427 */ 428 if (test_cases[i].skip) 429 control_writeln("SKIP"); 430 else 431 control_writeln("NEXT"); 432 433 line = control_readln(); 434 if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) { 435 436 printf("skipped\n"); 437 438 free(line); 439 continue; 440 } 441 442 control_cmpln(line, "NEXT", true); 443 free(line); 444 445 if (opts->mode == TEST_MODE_CLIENT) 446 run = test_cases[i].run_client; 447 else 448 run = test_cases[i].run_server; 449 450 if (run) 451 run(opts); 452 453 printf("ok\n"); 454 } 455 } 456 457 void list_tests(const struct test_case *test_cases) 458 { 459 int i; 460 461 printf("ID\tTest name\n"); 462 463 for (i = 0; test_cases[i].name; i++) 464 printf("%d\t%s\n", i, test_cases[i].name); 465 466 exit(EXIT_FAILURE); 467 } 468 469 static unsigned long parse_test_id(const char *test_id_str, size_t test_cases_len) 470 { 471 unsigned long test_id; 472 char *endptr = NULL; 473 474 errno = 0; 475 test_id = strtoul(test_id_str, &endptr, 10); 476 if (errno || *endptr != '\0') { 477 fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str); 478 exit(EXIT_FAILURE); 479 } 480 481 if (test_id >= test_cases_len) { 482 fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n", 483 test_id, test_cases_len - 1); 484 exit(EXIT_FAILURE); 485 } 486 487 return test_id; 488 } 489 490 void skip_test(struct test_case *test_cases, size_t test_cases_len, 491 const char *test_id_str) 492 { 493 unsigned long test_id = parse_test_id(test_id_str, test_cases_len); 494 test_cases[test_id].skip = true; 495 } 496 497 void pick_test(struct test_case *test_cases, size_t test_cases_len, 498 const char *test_id_str) 499 { 500 static bool skip_all = true; 501 unsigned long test_id; 502 503 if (skip_all) { 504 unsigned long i; 505 506 for (i = 0; i < test_cases_len; ++i) 507 test_cases[i].skip = true; 508 509 skip_all = false; 510 } 511 512 test_id = parse_test_id(test_id_str, test_cases_len); 513 test_cases[test_id].skip = false; 514 } 515 516 unsigned long hash_djb2(const void *data, size_t len) 517 { 518 unsigned long hash = 5381; 519 int i = 0; 520 521 while (i < len) { 522 hash = ((hash << 5) + hash) + ((unsigned char *)data)[i]; 523 i++; 524 } 525 526 return hash; 527 } 528 529 size_t iovec_bytes(const struct iovec *iov, size_t iovnum) 530 { 531 size_t bytes; 532 int i; 533 534 for (bytes = 0, i = 0; i < iovnum; i++) 535 bytes += iov[i].iov_len; 536 537 return bytes; 538 } 539 540 unsigned long iovec_hash_djb2(const struct iovec *iov, size_t iovnum) 541 { 542 unsigned long hash; 543 size_t iov_bytes; 544 size_t offs; 545 void *tmp; 546 int i; 547 548 iov_bytes = iovec_bytes(iov, iovnum); 549 550 tmp = malloc(iov_bytes); 551 if (!tmp) { 552 perror("malloc"); 553 exit(EXIT_FAILURE); 554 } 555 556 for (offs = 0, i = 0; i < iovnum; i++) { 557 memcpy(tmp + offs, iov[i].iov_base, iov[i].iov_len); 558 offs += iov[i].iov_len; 559 } 560 561 hash = hash_djb2(tmp, iov_bytes); 562 free(tmp); 563 564 return hash; 565 } 566 567 /* Allocates and returns new 'struct iovec *' according pattern 568 * in the 'test_iovec'. For each element in the 'test_iovec' it 569 * allocates new element in the resulting 'iovec'. 'iov_len' 570 * of the new element is copied from 'test_iovec'. 'iov_base' is 571 * allocated depending on the 'iov_base' of 'test_iovec': 572 * 573 * 'iov_base' == NULL -> valid buf: mmap('iov_len'). 574 * 575 * 'iov_base' == MAP_FAILED -> invalid buf: 576 * mmap('iov_len'), then munmap('iov_len'). 577 * 'iov_base' still contains result of 578 * mmap(). 579 * 580 * 'iov_base' == number -> unaligned valid buf: 581 * mmap('iov_len') + number. 582 * 583 * 'iovnum' is number of elements in 'test_iovec'. 584 * 585 * Returns new 'iovec' or calls 'exit()' on error. 586 */ 587 struct iovec *alloc_test_iovec(const struct iovec *test_iovec, int iovnum) 588 { 589 struct iovec *iovec; 590 int i; 591 592 iovec = malloc(sizeof(*iovec) * iovnum); 593 if (!iovec) { 594 perror("malloc"); 595 exit(EXIT_FAILURE); 596 } 597 598 for (i = 0; i < iovnum; i++) { 599 iovec[i].iov_len = test_iovec[i].iov_len; 600 601 iovec[i].iov_base = mmap(NULL, iovec[i].iov_len, 602 PROT_READ | PROT_WRITE, 603 MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE, 604 -1, 0); 605 if (iovec[i].iov_base == MAP_FAILED) { 606 perror("mmap"); 607 exit(EXIT_FAILURE); 608 } 609 610 if (test_iovec[i].iov_base != MAP_FAILED) 611 iovec[i].iov_base += (uintptr_t)test_iovec[i].iov_base; 612 } 613 614 /* Unmap "invalid" elements. */ 615 for (i = 0; i < iovnum; i++) { 616 if (test_iovec[i].iov_base == MAP_FAILED) { 617 if (munmap(iovec[i].iov_base, iovec[i].iov_len)) { 618 perror("munmap"); 619 exit(EXIT_FAILURE); 620 } 621 } 622 } 623 624 for (i = 0; i < iovnum; i++) { 625 int j; 626 627 if (test_iovec[i].iov_base == MAP_FAILED) 628 continue; 629 630 for (j = 0; j < iovec[i].iov_len; j++) 631 ((uint8_t *)iovec[i].iov_base)[j] = rand() & 0xff; 632 } 633 634 return iovec; 635 } 636 637 /* Frees 'iovec *', previously allocated by 'alloc_test_iovec()'. 638 * On error calls 'exit()'. 639 */ 640 void free_test_iovec(const struct iovec *test_iovec, 641 struct iovec *iovec, int iovnum) 642 { 643 int i; 644 645 for (i = 0; i < iovnum; i++) { 646 if (test_iovec[i].iov_base != MAP_FAILED) { 647 if (test_iovec[i].iov_base) 648 iovec[i].iov_base -= (uintptr_t)test_iovec[i].iov_base; 649 650 if (munmap(iovec[i].iov_base, iovec[i].iov_len)) { 651 perror("munmap"); 652 exit(EXIT_FAILURE); 653 } 654 } 655 } 656 657 free(iovec); 658 } 659 660 /* Set "unsigned long long" socket option and check that it's indeed set */ 661 void setsockopt_ull_check(int fd, int level, int optname, 662 unsigned long long val, char const *errmsg) 663 { 664 unsigned long long chkval; 665 socklen_t chklen; 666 int err; 667 668 err = setsockopt(fd, level, optname, &val, sizeof(val)); 669 if (err) { 670 fprintf(stderr, "setsockopt err: %s (%d)\n", 671 strerror(errno), errno); 672 goto fail; 673 } 674 675 chkval = ~val; /* just make storage != val */ 676 chklen = sizeof(chkval); 677 678 err = getsockopt(fd, level, optname, &chkval, &chklen); 679 if (err) { 680 fprintf(stderr, "getsockopt err: %s (%d)\n", 681 strerror(errno), errno); 682 goto fail; 683 } 684 685 if (chklen != sizeof(chkval)) { 686 fprintf(stderr, "size mismatch: set %zu got %d\n", sizeof(val), 687 chklen); 688 goto fail; 689 } 690 691 if (chkval != val) { 692 fprintf(stderr, "value mismatch: set %llu got %llu\n", val, 693 chkval); 694 goto fail; 695 } 696 return; 697 fail: 698 fprintf(stderr, "%s val %llu\n", errmsg, val); 699 exit(EXIT_FAILURE); 700 ; 701 } 702 703 /* Set "int" socket option and check that it's indeed set */ 704 void setsockopt_int_check(int fd, int level, int optname, int val, 705 char const *errmsg) 706 { 707 int chkval; 708 socklen_t chklen; 709 int err; 710 711 err = setsockopt(fd, level, optname, &val, sizeof(val)); 712 if (err) { 713 fprintf(stderr, "setsockopt err: %s (%d)\n", 714 strerror(errno), errno); 715 goto fail; 716 } 717 718 chkval = ~val; /* just make storage != val */ 719 chklen = sizeof(chkval); 720 721 err = getsockopt(fd, level, optname, &chkval, &chklen); 722 if (err) { 723 fprintf(stderr, "getsockopt err: %s (%d)\n", 724 strerror(errno), errno); 725 goto fail; 726 } 727 728 if (chklen != sizeof(chkval)) { 729 fprintf(stderr, "size mismatch: set %zu got %d\n", sizeof(val), 730 chklen); 731 goto fail; 732 } 733 734 if (chkval != val) { 735 fprintf(stderr, "value mismatch: set %d got %d\n", val, chkval); 736 goto fail; 737 } 738 return; 739 fail: 740 fprintf(stderr, "%s val %d\n", errmsg, val); 741 exit(EXIT_FAILURE); 742 } 743 744 static void mem_invert(unsigned char *mem, size_t size) 745 { 746 size_t i; 747 748 for (i = 0; i < size; i++) 749 mem[i] = ~mem[i]; 750 } 751 752 /* Set "timeval" socket option and check that it's indeed set */ 753 void setsockopt_timeval_check(int fd, int level, int optname, 754 struct timeval val, char const *errmsg) 755 { 756 struct timeval chkval; 757 socklen_t chklen; 758 int err; 759 760 err = setsockopt(fd, level, optname, &val, sizeof(val)); 761 if (err) { 762 fprintf(stderr, "setsockopt err: %s (%d)\n", 763 strerror(errno), errno); 764 goto fail; 765 } 766 767 /* just make storage != val */ 768 chkval = val; 769 mem_invert((unsigned char *)&chkval, sizeof(chkval)); 770 chklen = sizeof(chkval); 771 772 err = getsockopt(fd, level, optname, &chkval, &chklen); 773 if (err) { 774 fprintf(stderr, "getsockopt err: %s (%d)\n", 775 strerror(errno), errno); 776 goto fail; 777 } 778 779 if (chklen != sizeof(chkval)) { 780 fprintf(stderr, "size mismatch: set %zu got %d\n", sizeof(val), 781 chklen); 782 goto fail; 783 } 784 785 if (memcmp(&chkval, &val, sizeof(val)) != 0) { 786 fprintf(stderr, "value mismatch: set %ld:%ld got %ld:%ld\n", 787 val.tv_sec, val.tv_usec, chkval.tv_sec, chkval.tv_usec); 788 goto fail; 789 } 790 return; 791 fail: 792 fprintf(stderr, "%s val %ld:%ld\n", errmsg, val.tv_sec, val.tv_usec); 793 exit(EXIT_FAILURE); 794 } 795 796 void enable_so_zerocopy_check(int fd) 797 { 798 setsockopt_int_check(fd, SOL_SOCKET, SO_ZEROCOPY, 1, 799 "setsockopt SO_ZEROCOPY"); 800 } 801