1 // SPDX-License-Identifier: GPL-2.0 2 3 #define _GNU_SOURCE 4 5 #include <assert.h> 6 #include <errno.h> 7 #include <fcntl.h> 8 #include <limits.h> 9 #include <string.h> 10 #include <stdarg.h> 11 #include <stdbool.h> 12 #include <stdint.h> 13 #include <inttypes.h> 14 #include <stdio.h> 15 #include <stdlib.h> 16 #include <strings.h> 17 #include <unistd.h> 18 #include <time.h> 19 20 #include <sys/ioctl.h> 21 #include <sys/random.h> 22 #include <sys/socket.h> 23 #include <sys/types.h> 24 #include <sys/wait.h> 25 26 #include <netdb.h> 27 #include <netinet/in.h> 28 29 #include <linux/tcp.h> 30 #include <linux/sockios.h> 31 #include <linux/compiler.h> 32 33 #ifndef IPPROTO_MPTCP 34 #define IPPROTO_MPTCP 262 35 #endif 36 #ifndef SOL_MPTCP 37 #define SOL_MPTCP 284 38 #endif 39 40 static int pf = AF_INET; 41 static int proto_tx = IPPROTO_MPTCP; 42 static int proto_rx = IPPROTO_MPTCP; 43 44 static void __noreturn die_perror(const char *msg) 45 { 46 perror(msg); 47 exit(1); 48 } 49 50 static void die_usage(int r) 51 { 52 fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n"); 53 exit(r); 54 } 55 56 static void __noreturn xerror(const char *fmt, ...) 57 { 58 va_list ap; 59 60 va_start(ap, fmt); 61 vfprintf(stderr, fmt, ap); 62 va_end(ap); 63 fputc('\n', stderr); 64 exit(1); 65 } 66 67 static const char *getxinfo_strerr(int err) 68 { 69 if (err == EAI_SYSTEM) 70 return strerror(errno); 71 72 return gai_strerror(err); 73 } 74 75 static void xgetaddrinfo(const char *node, const char *service, 76 struct addrinfo *hints, 77 struct addrinfo **res) 78 { 79 int err; 80 81 again: 82 err = getaddrinfo(node, service, hints, res); 83 if (err) { 84 const char *errstr; 85 86 if (err == EAI_SOCKTYPE) { 87 hints->ai_protocol = IPPROTO_TCP; 88 goto again; 89 } 90 91 errstr = getxinfo_strerr(err); 92 93 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n", 94 node ? node : "", service ? service : "", errstr); 95 exit(1); 96 } 97 } 98 99 static int sock_listen_mptcp(const char * const listenaddr, 100 const char * const port) 101 { 102 int sock = -1; 103 struct addrinfo hints = { 104 .ai_protocol = IPPROTO_MPTCP, 105 .ai_socktype = SOCK_STREAM, 106 .ai_flags = AI_PASSIVE | AI_NUMERICHOST 107 }; 108 109 hints.ai_family = pf; 110 111 struct addrinfo *a, *addr; 112 int one = 1; 113 114 xgetaddrinfo(listenaddr, port, &hints, &addr); 115 hints.ai_family = pf; 116 117 for (a = addr; a; a = a->ai_next) { 118 sock = socket(a->ai_family, a->ai_socktype, proto_rx); 119 if (sock < 0) 120 continue; 121 122 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one, 123 sizeof(one))) 124 perror("setsockopt"); 125 126 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0) 127 break; /* success */ 128 129 perror("bind"); 130 close(sock); 131 sock = -1; 132 } 133 134 freeaddrinfo(addr); 135 136 if (sock < 0) 137 xerror("could not create listen socket"); 138 139 if (listen(sock, 20)) 140 die_perror("listen"); 141 142 return sock; 143 } 144 145 static int sock_connect_mptcp(const char * const remoteaddr, 146 const char * const port, int proto) 147 { 148 struct addrinfo hints = { 149 .ai_protocol = IPPROTO_MPTCP, 150 .ai_socktype = SOCK_STREAM, 151 }; 152 struct addrinfo *a, *addr; 153 int sock = -1; 154 155 hints.ai_family = pf; 156 157 xgetaddrinfo(remoteaddr, port, &hints, &addr); 158 for (a = addr; a; a = a->ai_next) { 159 sock = socket(a->ai_family, a->ai_socktype, proto); 160 if (sock < 0) 161 continue; 162 163 if (connect(sock, a->ai_addr, a->ai_addrlen) == 0) 164 break; /* success */ 165 166 die_perror("connect"); 167 } 168 169 if (sock < 0) 170 xerror("could not create connect socket"); 171 172 freeaddrinfo(addr); 173 return sock; 174 } 175 176 static int protostr_to_num(const char *s) 177 { 178 if (strcasecmp(s, "tcp") == 0) 179 return IPPROTO_TCP; 180 if (strcasecmp(s, "mptcp") == 0) 181 return IPPROTO_MPTCP; 182 183 die_usage(1); 184 return 0; 185 } 186 187 static void parse_opts(int argc, char **argv) 188 { 189 int c; 190 191 while ((c = getopt(argc, argv, "h6t:r:")) != -1) { 192 switch (c) { 193 case 'h': 194 die_usage(0); 195 break; 196 case '6': 197 pf = AF_INET6; 198 break; 199 case 't': 200 proto_tx = protostr_to_num(optarg); 201 break; 202 case 'r': 203 proto_rx = protostr_to_num(optarg); 204 break; 205 default: 206 die_usage(1); 207 break; 208 } 209 } 210 } 211 212 /* wait up to timeout milliseconds */ 213 static void wait_for_ack(int fd, int timeout, size_t total) 214 { 215 int i; 216 217 for (i = 0; i < timeout; i++) { 218 int nsd, ret, queued = -1; 219 struct timespec req; 220 221 ret = ioctl(fd, TIOCOUTQ, &queued); 222 if (ret < 0) 223 die_perror("TIOCOUTQ"); 224 225 ret = ioctl(fd, SIOCOUTQNSD, &nsd); 226 if (ret < 0) 227 die_perror("SIOCOUTQNSD"); 228 229 if ((size_t)queued > total) 230 xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total); 231 assert(nsd <= queued); 232 233 if (queued == 0) 234 return; 235 236 /* wait for peer to ack rx of all data */ 237 req.tv_sec = 0; 238 req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */ 239 nanosleep(&req, NULL); 240 } 241 242 xerror("still tx data queued after %u ms\n", timeout); 243 } 244 245 static void connect_one_server(int fd, int unixfd) 246 { 247 size_t len, i, total, sent; 248 char buf[4096], buf2[4096]; 249 ssize_t ret; 250 251 len = rand() % (sizeof(buf) - 1); 252 253 if (len < 128) 254 len = 128; 255 256 for (i = 0; i < len ; i++) { 257 buf[i] = rand() % 26; 258 buf[i] += 'A'; 259 } 260 261 buf[i] = '\n'; 262 263 /* un-block server */ 264 ret = read(unixfd, buf2, 4); 265 assert(ret == 4); 266 267 assert(strncmp(buf2, "xmit", 4) == 0); 268 269 ret = write(unixfd, &len, sizeof(len)); 270 assert(ret == (ssize_t)sizeof(len)); 271 272 ret = write(fd, buf, len); 273 if (ret < 0) 274 die_perror("write"); 275 276 if (ret != (ssize_t)len) 277 xerror("short write"); 278 279 ret = read(unixfd, buf2, 4); 280 assert(strncmp(buf2, "huge", 4) == 0); 281 282 total = rand() % (16 * 1024 * 1024); 283 total += (1 * 1024 * 1024); 284 sent = total; 285 286 ret = write(unixfd, &total, sizeof(total)); 287 assert(ret == (ssize_t)sizeof(total)); 288 289 wait_for_ack(fd, 5000, len); 290 291 while (total > 0) { 292 if (total > sizeof(buf)) 293 len = sizeof(buf); 294 else 295 len = total; 296 297 ret = write(fd, buf, len); 298 if (ret < 0) 299 die_perror("write"); 300 total -= ret; 301 302 /* we don't have to care about buf content, only 303 * number of total bytes sent 304 */ 305 } 306 307 ret = read(unixfd, buf2, 4); 308 assert(ret == 4); 309 assert(strncmp(buf2, "shut", 4) == 0); 310 311 wait_for_ack(fd, 5000, sent); 312 313 ret = write(fd, buf, 1); 314 assert(ret == 1); 315 close(fd); 316 ret = write(unixfd, "closed", 6); 317 assert(ret == 6); 318 319 close(unixfd); 320 } 321 322 static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv) 323 { 324 struct cmsghdr *cmsg; 325 326 for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) { 327 if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) { 328 memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv)); 329 return; 330 } 331 } 332 333 xerror("could not find TCP_CM_INQ cmsg type"); 334 } 335 336 static void process_one_client(int fd, int unixfd) 337 { 338 unsigned int tcp_inq; 339 size_t expect_len; 340 char msg_buf[4096]; 341 char buf[4096]; 342 char tmp[16]; 343 struct iovec iov = { 344 .iov_base = buf, 345 .iov_len = 1, 346 }; 347 struct msghdr msg = { 348 .msg_iov = &iov, 349 .msg_iovlen = 1, 350 .msg_control = msg_buf, 351 .msg_controllen = sizeof(msg_buf), 352 }; 353 ssize_t ret, tot; 354 355 ret = write(unixfd, "xmit", 4); 356 assert(ret == 4); 357 358 ret = read(unixfd, &expect_len, sizeof(expect_len)); 359 assert(ret == (ssize_t)sizeof(expect_len)); 360 361 if (expect_len > sizeof(buf)) 362 xerror("expect len %zu exceeds buffer size", expect_len); 363 364 for (;;) { 365 struct timespec req; 366 unsigned int queued; 367 368 ret = ioctl(fd, FIONREAD, &queued); 369 if (ret < 0) 370 die_perror("FIONREAD"); 371 if (queued > expect_len) 372 xerror("FIONREAD returned %u, but only %zu expected\n", 373 queued, expect_len); 374 if (queued == expect_len) 375 break; 376 377 req.tv_sec = 0; 378 req.tv_nsec = 1000 * 1000ul; 379 nanosleep(&req, NULL); 380 } 381 382 /* read one byte, expect cmsg to return expected - 1 */ 383 ret = recvmsg(fd, &msg, 0); 384 if (ret < 0) 385 die_perror("recvmsg"); 386 387 if (msg.msg_controllen == 0) 388 xerror("msg_controllen is 0"); 389 390 get_tcp_inq(&msg, &tcp_inq); 391 392 assert((size_t)tcp_inq == (expect_len - 1)); 393 394 iov.iov_len = sizeof(buf); 395 ret = recvmsg(fd, &msg, 0); 396 if (ret < 0) 397 die_perror("recvmsg"); 398 399 /* should have gotten exact remainder of all pending data */ 400 assert(ret == (ssize_t)tcp_inq); 401 402 /* should be 0, all drained */ 403 get_tcp_inq(&msg, &tcp_inq); 404 assert(tcp_inq == 0); 405 406 /* request a large swath of data. */ 407 ret = write(unixfd, "huge", 4); 408 assert(ret == 4); 409 410 ret = read(unixfd, &expect_len, sizeof(expect_len)); 411 assert(ret == (ssize_t)sizeof(expect_len)); 412 413 /* peer should send us a few mb of data */ 414 if (expect_len <= sizeof(buf)) 415 xerror("expect len %zu too small\n", expect_len); 416 417 tot = 0; 418 do { 419 iov.iov_len = sizeof(buf); 420 ret = recvmsg(fd, &msg, 0); 421 if (ret < 0) 422 die_perror("recvmsg"); 423 424 tot += ret; 425 426 get_tcp_inq(&msg, &tcp_inq); 427 428 if (tcp_inq > expect_len - tot) 429 xerror("inq %d, remaining %d total_len %d\n", 430 tcp_inq, expect_len - tot, (int)expect_len); 431 432 assert(tcp_inq <= expect_len - tot); 433 } while ((size_t)tot < expect_len); 434 435 ret = write(unixfd, "shut", 4); 436 assert(ret == 4); 437 438 /* wait for hangup. Should have received one more byte of data. */ 439 ret = read(unixfd, tmp, sizeof(tmp)); 440 assert(ret == 6); 441 assert(strncmp(tmp, "closed", 6) == 0); 442 443 sleep(1); 444 445 iov.iov_len = 1; 446 ret = recvmsg(fd, &msg, 0); 447 if (ret < 0) 448 die_perror("recvmsg"); 449 assert(ret == 1); 450 451 get_tcp_inq(&msg, &tcp_inq); 452 453 /* tcp_inq should be 1 due to received fin. */ 454 assert(tcp_inq == 1); 455 456 iov.iov_len = 1; 457 ret = recvmsg(fd, &msg, 0); 458 if (ret < 0) 459 die_perror("recvmsg"); 460 461 /* expect EOF */ 462 assert(ret == 0); 463 get_tcp_inq(&msg, &tcp_inq); 464 assert(tcp_inq == 1); 465 466 close(fd); 467 } 468 469 static int xaccept(int s) 470 { 471 int fd = accept(s, NULL, 0); 472 473 if (fd < 0) 474 die_perror("accept"); 475 476 return fd; 477 } 478 479 static int server(int unixfd) 480 { 481 int fd = -1, r, on = 1; 482 483 switch (pf) { 484 case AF_INET: 485 fd = sock_listen_mptcp("127.0.0.1", "15432"); 486 break; 487 case AF_INET6: 488 fd = sock_listen_mptcp("::1", "15432"); 489 break; 490 default: 491 xerror("Unknown pf %d\n", pf); 492 break; 493 } 494 495 r = write(unixfd, "conn", 4); 496 assert(r == 4); 497 498 alarm(15); 499 r = xaccept(fd); 500 501 if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on))) 502 die_perror("setsockopt"); 503 504 process_one_client(r, unixfd); 505 506 close(fd); 507 return 0; 508 } 509 510 static int client(int unixfd) 511 { 512 int fd = -1; 513 514 alarm(15); 515 516 switch (pf) { 517 case AF_INET: 518 fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx); 519 break; 520 case AF_INET6: 521 fd = sock_connect_mptcp("::1", "15432", proto_tx); 522 break; 523 default: 524 xerror("Unknown pf %d\n", pf); 525 } 526 527 connect_one_server(fd, unixfd); 528 529 return 0; 530 } 531 532 static void init_rng(void) 533 { 534 unsigned int foo; 535 536 if (getrandom(&foo, sizeof(foo), 0) == -1) { 537 perror("getrandom"); 538 exit(1); 539 } 540 541 srand(foo); 542 } 543 544 static pid_t xfork(void) 545 { 546 pid_t p = fork(); 547 548 if (p < 0) 549 die_perror("fork"); 550 else if (p == 0) 551 init_rng(); 552 553 return p; 554 } 555 556 static int rcheck(int wstatus, const char *what) 557 { 558 if (WIFEXITED(wstatus)) { 559 if (WEXITSTATUS(wstatus) == 0) 560 return 0; 561 fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus)); 562 return WEXITSTATUS(wstatus); 563 } else if (WIFSIGNALED(wstatus)) { 564 xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus)); 565 } else if (WIFSTOPPED(wstatus)) { 566 xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus)); 567 } 568 569 return 111; 570 } 571 572 int main(int argc, char *argv[]) 573 { 574 int e1, e2, wstatus; 575 pid_t s, c, ret; 576 int unixfds[2]; 577 578 parse_opts(argc, argv); 579 580 e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds); 581 if (e1 < 0) 582 die_perror("pipe"); 583 584 s = xfork(); 585 if (s == 0) { 586 close(unixfds[0]); 587 ret = server(unixfds[1]); 588 close(unixfds[1]); 589 return ret; 590 } 591 592 close(unixfds[1]); 593 594 /* wait until server bound a socket */ 595 e1 = read(unixfds[0], &e1, 4); 596 assert(e1 == 4); 597 598 c = xfork(); 599 if (c == 0) 600 return client(unixfds[0]); 601 602 close(unixfds[0]); 603 604 ret = waitpid(s, &wstatus, 0); 605 if (ret == -1) 606 die_perror("waitpid"); 607 e1 = rcheck(wstatus, "server"); 608 ret = waitpid(c, &wstatus, 0); 609 if (ret == -1) 610 die_perror("waitpid"); 611 e2 = rcheck(wstatus, "client"); 612 613 return e1 ? e1 : e2; 614 } 615