// SPDX-License-Identifier: GPL-2.0
/*
 * Check TCP_INQ / TCP_CM_INQ reporting and the FIONREAD, TIOCOUTQ and
 * SIOCOUTQNSD ioctls over TCP or MPTCP sockets.
 *
 * main() forks a server and a client. They exchange data over the socket
 * under test and are kept in lock-step by short commands ("conn", "xmit",
 * "huge", "shut", "closed") sent over an AF_UNIX SOCK_DGRAM socketpair.
 */

#define _GNU_SOURCE

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>
#include <time.h>

#include <sys/ioctl.h>
#include <sys/random.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>
#include <linux/sockios.h>

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* Address family and tx/rx socket protocols; set from the command line. */
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP;
static int proto_rx = IPPROTO_MPTCP;

/* Print msg plus the errno description and exit with status 1. */
static void die_perror(const char *msg)
{
	perror(msg);
	exit(1);
}

/* Print usage to stderr and exit with status r. */
static void die_usage(int r)
{
	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
	exit(r);
}

/* Let the compiler check every xerror() format string against its args. */
static void xerror(const char *fmt, ...) __attribute__((format(printf, 1, 2)));

/* printf-style message to stderr, followed by a newline, then exit(1). */
static void xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}

/* Map a getaddrinfo() error code to a message (errno for EAI_SYSTEM). */
static const char *getxinfo_strerr(int err)
{
	if (err == EAI_SYSTEM)
		return strerror(errno);

	return gai_strerror(err);
}

/* getaddrinfo() that exits the process on failure. */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err = getaddrinfo(node, service, hints, res);

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}

/*
 * Create a socket of protocol proto_rx, bind it to listenaddr:port (numeric
 * host) and start listening. Exits the process on failure.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}

/*
 * Create a socket of the given protocol and connect it to remoteaddr:port.
 * Exits the process on failure.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		/* a failed connect is fatal; later addresses are not tried */
		die_perror("connect");
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}

/* Map "tcp"/"mptcp" (case-insensitive) to a protocol number, else usage. */
static int protostr_to_num(const char *s)
{
	if (strcasecmp(s, "tcp") == 0)
		return IPPROTO_TCP;
	if (strcasecmp(s, "mptcp") == 0)
		return IPPROTO_MPTCP;

	die_usage(1);
	return 0;
}

/* Parse -h, -6, -t <proto> and -r <proto> into the globals above. */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
		switch (c) {
		case 'h':
			die_usage(0);
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't':
			proto_tx = protostr_to_num(optarg);
			break;
		case 'r':
			proto_rx = protostr_to_num(optarg);
			break;
		default:
			die_usage(1);
			break;
		}
	}
}

/*
 * Poll until TIOCOUTQ reports an empty send queue on fd. Polling happens
 * every 1ms, for at most `timeout` iterations (roughly timeout milliseconds).
 * Dies if the queue ever holds more than `total` bytes, or if SIOCOUTQNSD
 * exceeds TIOCOUTQ.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}

/*
 * Client side of the test, driven by commands read from unixfd:
 *  "xmit": send a random-length (128..4094 byte) burst of letters;
 *  "huge": send 1..17 MiB of data;
 *  "shut": send one last byte, close fd and reply "closed".
 * Before each burst, the byte count is sent over unixfd.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}

/* Copy the TCP_CM_INQ control message value into *inqv; die if absent. */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
			return;
		}
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}

/*
 * Server side of the test. Drives the client via unixfd and checks that
 * FIONREAD and the TCP_CM_INQ value returned with each recvmsg() match the
 * amount of data the client announced. After the client closes, it checks
 * that the final byte reports inq == 1, and that the following EOF read
 * also reports inq == 1.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/*
	 * recvmsg() overwrites msg_controllen with the amount of control data
	 * it stored, so it is reset to the full buffer size before every call.
	 */

	/* read one byte, expect cmsg to return expected - 1 */
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		size_t remaining;

		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");
		/* fail fast instead of spinning until alarm() fires */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes",
			       tot, expect_len);

		tot += ret;
		remaining = expect_len - (size_t)tot;

		get_tcp_inq(&msg, &tcp_inq);

		if (tcp_inq > remaining)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, remaining, expect_len);

		assert(tcp_inq <= remaining);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}

/* accept() that exits the process on failure. */
static int xaccept(int s)
{
	int fd = accept(s, NULL, 0);

	if (fd < 0)
		die_perror("accept");

	return fd;
}

/*
 * Server process: listen on loopback port 15432, tell the parent "conn",
 * accept one connection, enable TCP_INQ on it and run the checks.
 */
static int server(int unixfd)
{
	int fd = -1, r, on = 1;

	switch (pf) {
	case AF_INET:
		fd = sock_listen_mptcp("127.0.0.1", "15432");
		break;
	case AF_INET6:
		fd = sock_listen_mptcp("::1", "15432");
		break;
	default:
		xerror("Unknown pf %d\n", pf);
		break;
	}

	r = write(unixfd, "conn", 4);
	assert(r == 4);

	alarm(15);
	r = xaccept(fd);

	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
		die_perror("setsockopt");

	process_one_client(r, unixfd);

	return 0;
}

/* Client process: connect to loopback port 15432 and run the sender side. */
static int client(int unixfd)
{
	int fd = -1;

	alarm(15);

	switch (pf) {
	case AF_INET:
		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
		break;
	case AF_INET6:
		fd = sock_connect_mptcp("::1", "15432", proto_tx);
		break;
	default:
		xerror("Unknown pf %d\n", pf);
		break;
	}

	connect_one_server(fd, unixfd);

	return 0;
}

/* Seed rand() from getrandom(); exits on failure. */
static void init_rng(void)
{
	unsigned int foo;

	if (getrandom(&foo, sizeof(foo), 0) == -1) {
		perror("getrandom");
		exit(1);
	}

	srand(foo);
}

/* fork() that exits on failure and reseeds rand() in the child. */
static pid_t xfork(void)
{
	pid_t p = fork();

	if (p < 0)
		die_perror("fork");
	else if (p == 0)
		init_rng();

	return p;
}

/*
 * Interpret a waitpid() status: 0 on clean exit, the exit code otherwise.
 * Exits the process if the child was killed or stopped by a signal.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		if (WEXITSTATUS(wstatus) == 0)
			return 0;
		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
		return WEXITSTATUS(wstatus);
	} else if (WIFSIGNALED(wstatus)) {
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	} else if (WIFSTOPPED(wstatus)) {
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
	}

	return 111;
}

int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char sync_msg[4];

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	e1 = read(unixfds[0], sync_msg, sizeof(sync_msg));
	assert(e1 == 4);
	assert(strncmp(sync_msg, "conn", 4) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}