// SPDX-License-Identifier: GPL-2.0

/*
 * Self-test for the TCP_INQ control message on (MP)TCP sockets.
 *
 * main() forks a server and a client that talk over a loopback stream
 * socket (port 15432).  A SOCK_DGRAM unix socketpair is used as a side
 * channel so both sides can agree on how many bytes are in flight; the
 * server then checks that the TCP_CM_INQ value attached to each recvmsg()
 * is consistent with those byte counts.
 */

#define _GNU_SOURCE

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>
#include <time.h>

#include <sys/ioctl.h>
#include <sys/random.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>
#include <linux/sockios.h>

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* Address family and per-direction protocol, set by parse_opts(). */
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP;
static int proto_rx = IPPROTO_MPTCP;

/* Print @msg together with the current errno description, then exit(1). */
static void die_perror(const char *msg)
{
	perror(msg);
	exit(1);
}

/* Print the usage line to stderr and exit with status @r. */
static void die_usage(int r)
{
	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
	exit(r);
}

/*
 * printf-style fatal error: the formatted message plus a newline goes to
 * stderr, then the process exits with status 1.
 */
static void xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}

/*
 * Map a getaddrinfo() error code to text.  EAI_SYSTEM means the real
 * reason is in errno, so strerror() is used for that case.
 */
static const char *getxinfo_strerr(int err)
{
	if (err == EAI_SYSTEM)
		return strerror(errno);

	return gai_strerror(err);
}

/*
 * getaddrinfo() wrapper that exits on failure.
 *
 * If the resolver rejects the requested protocol with EAI_SOCKTYPE, the
 * lookup is retried once with hints->ai_protocol rewritten to IPPROTO_TCP
 * (note that this modifies the caller's @hints).  Any other error, or a
 * second EAI_SOCKTYPE, is fatal.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	/* Declared before the label: a label must be followed by a statement. */
	int err;

again:
	err = getaddrinfo(node, service, hints, res);
	if (!err)
		return;

	/* Only fall back once, otherwise a persistent error would spin forever. */
	if (err == EAI_SOCKTYPE && hints->ai_protocol != IPPROTO_TCP) {
		hints->ai_protocol = IPPROTO_TCP;
		goto again;
	}

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(err));
	exit(1);
}

/*
 * Create a listening stream socket bound to @listenaddr:@port.
 *
 * The socket is created with the global proto_rx protocol and address
 * family pf.  Exits the process if no address could be bound or listen()
 * fails.  Returns the listening descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
		.ai_family = pf,
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* Failing to set SO_REUSEADDR is reported but not fatal. */
		if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
			       sizeof(one)) == -1)
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}

/*
 * Connect a stream socket of protocol @proto to @remoteaddr:@port.
 *
 * A failing connect() is fatal; a failing socket() just moves on to the
 * next resolved address.  Returns the connected descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		die_perror("connect");
	}

	if (sock < 0)
		xerror("could not create connect socket");

	freeaddrinfo(addr);
	return sock;
}

/*
 * Translate "tcp" / "mptcp" (case-insensitive) to the IPPROTO_* value.
 * Any other string prints the usage text and exits.
 */
static int protostr_to_num(const char *s)
{
	if (strcasecmp(s, "tcp") == 0)
		return IPPROTO_TCP;
	if (strcasecmp(s, "mptcp") == 0)
		return IPPROTO_MPTCP;

	die_usage(1);
	return 0;
}

/*
 * Parse the command line: -6 selects AF_INET6, -t / -r choose the
 * protocol used by the sending and receiving side, -h shows usage.
 */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
		switch (c) {
		case 'h':
			die_usage(0);
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't':
			proto_tx = protostr_to_num(optarg);
			break;
		case 'r':
			proto_rx = protostr_to_num(optarg);
			break;
		default:
			die_usage(1);
			break;
		}
	}
}

/*
 * Poll the send queue of @fd until it is empty.
 *
 * Up to @timeout polls are made, sleeping 1 ms after each non-empty one.
 * @total is the largest number of queued bytes that is plausible; a
 * larger TIOCOUTQ value, or a queue that never drains, is fatal.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}

/*
 * Client side of the test.
 *
 * Driven by 4-byte commands read from @unixfd:
 *   "xmit" - report a random length (128..4094) on @unixfd, then send
 *            that many bytes on @fd in a single write.
 *   "huge" - report a random total of 1..17 MiB, wait until the first
 *            chunk was acked, then send that total.
 *   "shut" - wait until everything was acked, send one more byte,
 *            close @fd and report "closed" on @unixfd.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	ret = read(unixfd, buf2, 4);
	/* Without this check buf2 could still hold the previous command. */
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}

/*
 * Extract the TCP_CM_INQ value from the control data of @msgh into
 * @inqv.  Exits if no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
			return;
		}
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}

/*
 * recvmsg() wrapper that exits on error.
 *
 * recvmsg() overwrites msg_controllen with the length of the control
 * data it returned, so the field is re-armed with the full buffer size
 * @controllen before every call.
 */
static ssize_t xrecvmsg(int fd, struct msghdr *msg, size_t controllen)
{
	ssize_t ret;

	msg->msg_controllen = controllen;
	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	return ret;
}

/*
 * Server side of the test; @fd is the accepted connection (TCP_INQ
 * already enabled by the caller), @unixfd the side channel to the client.
 *
 * Phases, each started by a 4-byte command written to @unixfd:
 *   "xmit" - wait until FIONREAD reports the announced length, read one
 *            byte and expect inq == length - 1, then drain the rest and
 *            expect inq == 0.
 *   "huge" - receive the announced amount and check on every read that
 *            inq never exceeds what is still outstanding.
 *   "shut" - after the client reports "closed", read the final byte and
 *            then EOF; inq is expected to be 1 both times.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* Poll every 1 ms until exactly expect_len bytes are readable. */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		size_t remaining;

		iov.iov_len = sizeof(buf);
		ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* inq may never claim more than what is still outstanding */
		remaining = expect_len - (size_t)tot;
		if (tcp_inq > remaining)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, remaining, expect_len);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}

/* accept() wrapper that exits on failure. */
static int xaccept(int s)
{
	int fd = accept(s, NULL, 0);

	if (fd < 0)
		die_perror("accept");

	return fd;
}

/*
 * Server process: listen on the loopback address for the selected
 * family, tell the parent via "conn" on @unixfd that the socket is
 * bound, accept one connection, enable TCP_INQ on it and run the checks.
 * A 15 second alarm bounds the run time.  Returns 0.
 */
static int server(int unixfd)
{
	int listen_fd = -1, conn_fd, on = 1;
	ssize_t n;

	switch (pf) {
	case AF_INET:
		listen_fd = sock_listen_mptcp("127.0.0.1", "15432");
		break;
	case AF_INET6:
		listen_fd = sock_listen_mptcp("::1", "15432");
		break;
	default:
		xerror("Unknown pf %d\n", pf);
		break;
	}

	n = write(unixfd, "conn", 4);
	assert(n == 4);

	alarm(15);
	conn_fd = xaccept(listen_fd);

	if (setsockopt(conn_fd, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)) == -1)
		die_perror("setsockopt");

	process_one_client(conn_fd, unixfd);

	return 0;
}

/*
 * Client process: connect to the loopback server with protocol proto_tx
 * and run the sending side of the test.  A 15 second alarm bounds the
 * run time.  Returns 0.
 */
static int client(int unixfd)
{
	int fd = -1;

	alarm(15);

	switch (pf) {
	case AF_INET:
		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
		break;
	case AF_INET6:
		fd = sock_connect_mptcp("::1", "15432", proto_tx);
		break;
	default:
		xerror("Unknown pf %d\n", pf);
		break;
	}

	connect_one_server(fd, unixfd);

	return 0;
}

/* Seed rand() from getrandom(); exits if getrandom() fails. */
static void init_rng(void)
{
	unsigned int foo;

	if (getrandom(&foo, sizeof(foo), 0) == -1) {
		perror("getrandom");
		exit(1);
	}

	srand(foo);
}

/*
 * fork() wrapper: exits on failure and reseeds the RNG in the child.
 * Returns what fork() returned.
 */
static pid_t xfork(void)
{
	pid_t p = fork();

	if (p < 0)
		die_perror("fork");
	else if (p == 0)
		init_rng();

	return p;
}

/*
 * Turn a waitpid() status into an exit code for @what.
 *
 * Returns the child's exit status when it exited (printing a note if it
 * was non-zero).  A child that was killed or stopped by a signal is
 * fatal.  Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		if (WEXITSTATUS(wstatus) == 0)
			return 0;
		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
		return WEXITSTATUS(wstatus);
	} else if (WIFSIGNALED(wstatus)) {
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	} else if (WIFSTOPPED(wstatus)) {
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
	}

	return 111;
}

/*
 * Fork the server, wait for its "conn" notification, fork the client,
 * then reap both.  The exit status is the server's if it is non-zero,
 * otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}