// SPDX-License-Identifier: GPL-2.0

/*
 * mptcp_inq: fork a server and a client that talk over a loopback
 * TCP/MPTCP connection and cross-check the kernel's queue accounting:
 *
 *  - the client checks TIOCOUTQ / SIOCOUTQNSD on the sending side,
 *  - the server checks FIONREAD and the TCP_CM_INQ control message
 *    (enabled via the TCP_INQ socket option) on the receiving side.
 *
 * Both children are kept in lock-step through an AF_UNIX datagram
 * socketpair carrying short text tokens ("conn", "xmit", "huge", "shut",
 * "closed") and raw size_t lengths.
 */

#define _GNU_SOURCE

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>
#include <time.h>

#include <sys/ioctl.h>
#include <sys/random.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>
#include <linux/sockios.h>

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* Address family (-6 selects AF_INET6) and per-side protocol (-t / -r). */
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP;
static int proto_rx = IPPROTO_MPTCP;

/* Print msg together with the current errno text and exit(1). */
static void die_perror(const char *msg)
{
	perror(msg);
	exit(1);
}

/* Print the usage line to stderr and exit with status r. */
static void die_usage(int r)
{
	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
	exit(r);
}

/* printf-style fatal error: message plus a trailing newline, then exit(1). */
static void xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}

/* Map a getaddrinfo() error code to text; EAI_SYSTEM defers to errno. */
static const char *getxinfo_strerr(int err)
{
	if (err == EAI_SYSTEM)
		return strerror(errno);

	return gai_strerror(err);
}

/*
 * getaddrinfo() wrapper that exits on failure.
 *
 * On EAI_SOCKTYPE the lookup is retried with hints->ai_protocol forced to
 * IPPROTO_TCP (presumably because the resolver does not know
 * IPPROTO_MPTCP - confirm against the libc in use).  Note that this
 * modifies the caller's hints.  The retry happens only while the hint is
 * not already TCP, so a persistent EAI_SOCKTYPE cannot loop forever.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err;

again:
	err = getaddrinfo(node, service, hints, res);
	if (err) {
		const char *errstr;

		if (err == EAI_SOCKTYPE && hints->ai_protocol != IPPROTO_TCP) {
			hints->ai_protocol = IPPROTO_TCP;
			goto again;
		}

		errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}

/*
 * Create a socket of protocol proto_rx bound to listenaddr:port and put it
 * into listening state.  Returns the listening fd; exits on failure.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
		.ai_family = pf,
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* Best effort only: a failure here is reported, not fatal. */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}

/*
 * Connect a socket of the given protocol to remoteaddr:port.
 * Returns the connected fd.  A connect() failure is fatal immediately;
 * only socket() failures move on to the next address.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_family = pf,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		die_perror("connect");
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}

/* Translate "tcp" / "mptcp" (case-insensitive) to the IPPROTO_* value. */
static int protostr_to_num(const char *s)
{
	if (strcasecmp(s, "tcp") == 0)
		return IPPROTO_TCP;
	if (strcasecmp(s, "mptcp") == 0)
		return IPPROTO_MPTCP;

	die_usage(1);
	return 0;
}

/* Parse the command line into pf, proto_tx and proto_rx. */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
		switch (c) {
		case 'h':
			die_usage(0);
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't':
			proto_tx = protostr_to_num(optarg);
			break;
		case 'r':
			proto_rx = protostr_to_num(optarg);
			break;
		default:
			die_usage(1);
			break;
		}
	}
}

/*
 * Poll the send queue of fd until it is empty, sleeping 1ms between polls
 * and giving up after `timeout` polls (i.e. roughly timeout milliseconds).
 *
 * `total` is the number of bytes written so far; the kernel must never
 * report more queued bytes than that.  Exits on timeout or on an
 * inconsistent report.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd = -1, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}

/*
 * Client side of the test.
 *
 * fd is the connected TCP/MPTCP socket, unixfd the control channel to the
 * server process.  Sequence:
 *   1. wait for "xmit", announce and send a small random-length message;
 *   2. wait for "huge", announce a 1..17 MiB transfer, wait until the small
 *      message has left the send queue, then stream the bulk data;
 *   3. wait for "shut", wait for the send queue to drain, send one final
 *      byte, close the socket and report "closed".
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	ret = read(unixfd, buf2, 4);
	/* make sure the token was really received before inspecting it */
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}

/*
 * Copy the value of the first IPPROTO_TCP / TCP_CM_INQ control message in
 * msgh into *inqv.  Exits if no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
			return;
		}
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}

/*
 * recvmsg() wrapper that exits on error.
 *
 * msg_controllen is a value-result field: the kernel overwrites it with
 * the amount of control data actually returned.  Re-arm it to the full
 * buffer size (ctrl_len) before every call so that a reused msghdr never
 * advertises a stale, shrunken control buffer.
 */
static ssize_t xrecvmsg(int fd, struct msghdr *msg, size_t ctrl_len)
{
	ssize_t ret;

	msg->msg_controllen = ctrl_len;
	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	return ret;
}

/*
 * Server side of the test.
 *
 * fd is the accepted TCP/MPTCP socket (TCP_INQ already enabled by the
 * caller), unixfd the control channel to the client process.  Sequence:
 *   1. send "xmit", learn the small message length, wait until FIONREAD
 *      reports exactly that many bytes, then verify TCP_CM_INQ after
 *      reading one byte and after draining the rest;
 *   2. send "huge", learn the bulk length and verify on every read that
 *      TCP_CM_INQ never exceeds the number of bytes still outstanding;
 *   3. send "shut", wait for "closed", then verify the values reported
 *      for the final data byte and for EOF.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll until the whole small message sits in the receive queue */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		size_t remaining;

		iov.iov_len = sizeof(buf);
		ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* in-queue bytes can never exceed what is still outstanding */
		remaining = expect_len - (size_t)tot;
		if (tcp_inq > remaining)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, remaining, expect_len);

		assert(tcp_inq <= remaining);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}

/* accept() one connection on s; exits on failure. */
static int xaccept(int s)
{
	int fd = accept(s, NULL, 0);

	if (fd < 0)
		die_perror("accept");

	return fd;
}

/*
 * Server process body: listen on the loopback address for the selected
 * family, signal "conn" on unixfd once the listener exists, accept one
 * connection (15s alarm as watchdog), enable TCP_INQ on it and run the
 * receive-side checks.  Returns 0; every failure exits or aborts.
 */
static int server(int unixfd)
{
	int fd = -1, r, on = 1;

	switch (pf) {
	case AF_INET:
		fd = sock_listen_mptcp("127.0.0.1", "15432");
		break;
	case AF_INET6:
		fd = sock_listen_mptcp("::1", "15432");
		break;
	default:
		xerror("Unknown pf %d\n", pf);
		break;
	}

	r = write(unixfd, "conn", 4);
	assert(r == 4);

	alarm(15);
	r = xaccept(fd);

	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
		die_perror("setsockopt");

	process_one_client(r, unixfd);

	return 0;
}

/*
 * Client process body: connect to the loopback listener using proto_tx
 * (15s alarm as watchdog) and run the send-side checks.  Returns 0; every
 * failure exits or aborts.
 */
static int client(int unixfd)
{
	int fd = -1;

	alarm(15);

	switch (pf) {
	case AF_INET:
		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
		break;
	case AF_INET6:
		fd = sock_connect_mptcp("::1", "15432", proto_tx);
		break;
	default:
		xerror("Unknown pf %d\n", pf);
	}

	connect_one_server(fd, unixfd);

	return 0;
}

/* Seed rand() from getrandom(); exits if no random bytes are available. */
static void init_rng(void)
{
	unsigned int foo;

	if (getrandom(&foo, sizeof(foo), 0) == -1) {
		perror("getrandom");
		exit(1);
	}

	srand(foo);
}

/* fork() that exits on failure and re-seeds the RNG in the child. */
static pid_t xfork(void)
{
	pid_t p = fork();

	if (p < 0)
		die_perror("fork");
	else if (p == 0)
		init_rng();

	return p;
}

/*
 * Turn a waitpid() status into an exit code for `what` ("server" or
 * "client"): 0 on clean exit, the child's exit status otherwise.  A child
 * that was killed or stopped by a signal is fatal (xerror exits).  111 is
 * returned for any status that matches none of the above.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		if (WEXITSTATUS(wstatus) == 0)
			return 0;
		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
		return WEXITSTATUS(wstatus);
	} else if (WIFSIGNALED(wstatus)) {
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	} else if (WIFSTOPPED(wstatus)) {
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
	}

	return 111;
}

/*
 * Fork the server, wait for its 4-byte ready token, fork the client, then
 * reap both.  The exit status is the server's if non-zero, else the
 * client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == 4);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}